|
8 | 8 | from a2a.types.a2a_pb2 import ( |
9 | 9 | Artifact, |
10 | 10 | Part, |
| 11 | + TaskArtifactUpdateEvent, |
11 | 12 | ) |
12 | 13 | from a2a.utils.artifact import ( |
| 14 | + ArtifactStreamer, |
13 | 15 | get_artifact_text, |
14 | 16 | new_artifact, |
15 | 17 | new_data_artifact, |
@@ -157,5 +159,105 @@ def test_get_artifact_text_empty_parts(self): |
157 | 159 | assert result == '' |
158 | 160 |
|
159 | 161 |
|
| 162 | +class TestArtifactStreamer(unittest.TestCase): |
| 163 | + def setUp(self): |
| 164 | + self.context_id = 'ctx-123' |
| 165 | + self.task_id = 'task-456' |
| 166 | + |
| 167 | + def test_generates_stable_artifact_id(self): |
| 168 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 169 | + e1 = streamer.append('hello ') |
| 170 | + e2 = streamer.append('world') |
| 171 | + self.assertEqual(e1.artifact.artifact_id, e2.artifact.artifact_id) |
| 172 | + |
| 173 | + def test_uses_explicit_artifact_id(self): |
| 174 | + streamer = ArtifactStreamer( |
| 175 | + self.context_id, self.task_id, artifact_id='my-fixed-id' |
| 176 | + ) |
| 177 | + event = streamer.append('chunk') |
| 178 | + self.assertEqual(event.artifact.artifact_id, 'my-fixed-id') |
| 179 | + |
| 180 | + @patch('a2a.utils.artifact.uuid.uuid4') |
| 181 | + def test_generated_id_comes_from_uuid4(self, mock_uuid4): |
| 182 | + mock_uuid = uuid.UUID('abcdef12-1234-5678-1234-567812345678') |
| 183 | + mock_uuid4.return_value = mock_uuid |
| 184 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 185 | + self.assertEqual(streamer._artifact_id, str(mock_uuid)) |
| 186 | + |
| 187 | + def test_default_name_is_response(self): |
| 188 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 189 | + event = streamer.append('text') |
| 190 | + self.assertEqual(event.artifact.name, 'response') |
| 191 | + |
| 192 | + def test_custom_name(self): |
| 193 | + streamer = ArtifactStreamer( |
| 194 | + self.context_id, self.task_id, name='summary' |
| 195 | + ) |
| 196 | + event = streamer.append('text') |
| 197 | + self.assertEqual(event.artifact.name, 'summary') |
| 198 | + |
| 199 | + def test_append_returns_task_artifact_update_event(self): |
| 200 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 201 | + event = streamer.append('chunk') |
| 202 | + self.assertIsInstance(event, TaskArtifactUpdateEvent) |
| 203 | + |
| 204 | + def test_append_sets_correct_context_and_task(self): |
| 205 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 206 | + event = streamer.append('chunk') |
| 207 | + self.assertEqual(event.context_id, self.context_id) |
| 208 | + self.assertEqual(event.task_id, self.task_id) |
| 209 | + |
| 210 | + def test_append_sets_append_true_last_chunk_false(self): |
| 211 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 212 | + event = streamer.append('chunk') |
| 213 | + self.assertTrue(event.append) |
| 214 | + self.assertFalse(event.last_chunk) |
| 215 | + |
| 216 | + def test_append_creates_single_text_part(self): |
| 217 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 218 | + event = streamer.append('hello') |
| 219 | + self.assertEqual(len(event.artifact.parts), 1) |
| 220 | + self.assertTrue(event.artifact.parts[0].HasField('text')) |
| 221 | + self.assertEqual(event.artifact.parts[0].text, 'hello') |
| 222 | + |
| 223 | + def test_finalize_returns_task_artifact_update_event(self): |
| 224 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 225 | + event = streamer.finalize() |
| 226 | + self.assertIsInstance(event, TaskArtifactUpdateEvent) |
| 227 | + |
| 228 | + def test_finalize_sets_append_true_last_chunk_true(self): |
| 229 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 230 | + event = streamer.finalize() |
| 231 | + self.assertTrue(event.append) |
| 232 | + self.assertTrue(event.last_chunk) |
| 233 | + |
| 234 | + def test_finalize_has_empty_parts(self): |
| 235 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 236 | + event = streamer.finalize() |
| 237 | + self.assertEqual(len(event.artifact.parts), 0) |
| 238 | + |
| 239 | + def test_finalize_uses_same_artifact_id_as_append(self): |
| 240 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 241 | + append_event = streamer.append('text') |
| 242 | + finalize_event = streamer.finalize() |
| 243 | + self.assertEqual( |
| 244 | + append_event.artifact.artifact_id, |
| 245 | + finalize_event.artifact.artifact_id, |
| 246 | + ) |
| 247 | + |
| 248 | + def test_multiple_appends_all_share_artifact_id(self): |
| 249 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 250 | + events = [streamer.append(f'chunk-{i}') for i in range(5)] |
| 251 | + ids = {e.artifact.artifact_id for e in events} |
| 252 | + self.assertEqual(len(ids), 1) |
| 253 | + |
| 254 | + def test_multiple_appends_carry_distinct_text(self): |
| 255 | + streamer = ArtifactStreamer(self.context_id, self.task_id) |
| 256 | + texts = ['Hello, ', 'world', '!'] |
| 257 | + events = [streamer.append(t) for t in texts] |
| 258 | + result_texts = [e.artifact.parts[0].text for e in events] |
| 259 | + self.assertEqual(result_texts, texts) |
| 260 | + |
| 261 | + |
160 | 262 | if __name__ == '__main__': |
161 | 263 | unittest.main() |
0 commit comments