|
32 | 32 | Message, |
33 | 33 | Task, |
34 | 34 | TaskState, |
| 35 | + TaskStatus, |
| 36 | + TaskStatusUpdateEvent, |
35 | 37 | ) |
36 | 38 | from a2a.utils.errors import ( |
37 | 39 | InvalidParamsError, |
@@ -252,80 +254,75 @@ async def _run_producer(self) -> None: |
252 | 254 | """ |
253 | 255 | logger.debug('Producer[%s]: Started', self._task_id) |
254 | 256 | try: |
255 | | - try: |
256 | | - try: |
257 | | - while True: |
258 | | - ( |
259 | | - request_context, |
260 | | - request_id, |
261 | | - ) = await self._request_queue.get() |
262 | | - await self._request_lock.acquire() |
263 | | - # TODO: Should we create task manager every time? |
264 | | - self._task_manager._call_context = ( |
265 | | - request_context.call_context |
266 | | - ) |
267 | | - request_context.current_task = ( |
268 | | - await self._task_manager.get_task() |
269 | | - ) |
| 257 | + active = True |
| 258 | + while active: |
| 259 | + ( |
| 260 | + request_context, |
| 261 | + request_id, |
| 262 | + ) = await self._request_queue.get() |
| 263 | + await self._request_lock.acquire() |
| 264 | + # TODO: Should we create task manager every time? |
| 265 | + self._task_manager._call_context = request_context.call_context |
| 266 | + request_context.current_task = ( |
| 267 | + await self._task_manager.get_task() |
| 268 | + ) |
270 | 269 |
|
271 | | - message = request_context.message |
272 | | - if message: |
273 | | - request_context.current_task = ( |
274 | | - self._task_manager.update_with_message( |
275 | | - message, |
276 | | - cast('Task', request_context.current_task), |
277 | | - ) |
278 | | - ) |
279 | | - await self._task_manager.save_task_event( |
280 | | - request_context.current_task |
281 | | - ) |
282 | | - self._task_created.set() |
283 | | - logger.debug( |
284 | | - 'Producer[%s]: Executing agent task %s', |
285 | | - self._task_id, |
286 | | - request_context.current_task, |
| 270 | + message = request_context.message |
| 271 | + if message: |
| 272 | + request_context.current_task = ( |
| 273 | + self._task_manager.update_with_message( |
| 274 | + message, |
| 275 | + cast('Task', request_context.current_task), |
287 | 276 | ) |
| 277 | + ) |
| 278 | + await self._task_manager.save_task_event( |
| 279 | + request_context.current_task |
| 280 | + ) |
| 281 | + self._task_created.set() |
| 282 | + logger.debug( |
| 283 | + 'Producer[%s]: Executing agent task %s', |
| 284 | + self._task_id, |
| 285 | + request_context.current_task, |
| 286 | + ) |
288 | 287 |
|
289 | | - try: |
290 | | - await self._agent_executor.execute( |
291 | | - request_context, self._event_queue_agent |
292 | | - ) |
293 | | - logger.debug( |
294 | | - 'Producer[%s]: Execution finished successfully', |
295 | | - self._task_id, |
296 | | - ) |
297 | | - except Exception as e: |
298 | | - async with self._lock: |
299 | | - if self._exception is None: |
300 | | - self._exception = e |
301 | | - raise |
302 | | - finally: |
303 | | - logger.debug( |
304 | | - 'Producer[%s]: Enqueuing request completed event', |
305 | | - self._task_id, |
306 | | - ) |
307 | | - # TODO: Hide from external consumers |
308 | | - await self._event_queue_agent.enqueue_event( |
309 | | - cast('Event', _RequestCompleted(request_id)) |
310 | | - ) |
311 | | - self._request_queue.task_done() |
| 288 | + try: |
| 289 | + await self._agent_executor.execute( |
| 290 | + request_context, self._event_queue_agent |
| 291 | + ) |
| 292 | + logger.debug( |
| 293 | + 'Producer[%s]: Execution finished successfully', |
| 294 | + self._task_id, |
| 295 | + ) |
312 | 296 | except QueueShutDown: |
313 | 297 | logger.debug( |
314 | 298 | 'Producer[%s]: Request queue shut down', self._task_id |
315 | 299 | ) |
316 | | - except asyncio.CancelledError: |
317 | | - logger.debug('Producer[%s]: Cancelled', self._task_id) |
318 | | - raise |
319 | | - except Exception as e: |
320 | | - logger.exception('Producer[%s]: Failed', self._task_id) |
321 | | - async with self._lock: |
322 | | - if self._exception is None: |
323 | | - self._exception = e |
324 | | - finally: |
325 | | - self._request_queue.shutdown(immediate=True) |
326 | | - await self._event_queue_agent.close(immediate=False) |
327 | | - await self._event_queue_subscribers.close(immediate=False) |
| 300 | + raise |
| 301 | + except asyncio.CancelledError: |
| 302 | + logger.debug('Producer[%s]: Cancelled', self._task_id) |
| 303 | + raise |
| 304 | + except Exception as e: |
| 305 | + logger.exception( |
| 306 | + 'Producer[%s]: Execution failed', |
| 307 | + self._task_id, |
| 308 | + ) |
| 309 | + async with self._lock: |
| 310 | + await self._mark_task_as_failed(e) |
| 311 | + active = False |
| 312 | + finally: |
| 313 | + logger.debug( |
| 314 | + 'Producer[%s]: Enqueuing request completed event', |
| 315 | + self._task_id, |
| 316 | + ) |
| 317 | + # TODO: Hide from external consumers |
| 318 | + await self._event_queue_agent.enqueue_event( |
| 319 | + cast('Event', _RequestCompleted(request_id)) |
| 320 | + ) |
| 321 | + self._request_queue.task_done() |
328 | 322 | finally: |
| 323 | + self._request_queue.shutdown(immediate=True) |
| 324 | + await self._event_queue_agent.close(immediate=False) |
| 325 | + await self._event_queue_subscribers.close(immediate=False) |
329 | 326 | logger.debug('Producer[%s]: Completed', self._task_id) |
330 | 327 |
|
331 | 328 | async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 |
@@ -443,8 +440,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 |
443 | 440 | except Exception as e: |
444 | 441 | logger.exception('Consumer[%s]: Failed', self._task_id) |
445 | 442 | async with self._lock: |
446 | | - if self._exception is None: |
447 | | - self._exception = e |
| 443 | + await self._mark_task_as_failed(e) |
448 | 444 | finally: |
449 | 445 | # The consumer is dead. The ActiveTask is permanently finished. |
450 | 446 | self._is_finished.set() |
@@ -581,9 +577,7 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message: |
581 | 577 | logger.exception( |
582 | 578 | 'Cancel[%s]: Agent cancel failed', self._task_id |
583 | 579 | ) |
584 | | - if not self._exception: |
585 | | - self._exception = e |
586 | | - |
| 580 | + await self._mark_task_as_failed(e) |
587 | 581 | raise |
588 | 582 | else: |
589 | 583 | logger.debug( |
@@ -619,6 +613,22 @@ async def _maybe_cleanup(self) -> None: |
619 | 613 | logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id) |
620 | 614 | self._on_cleanup(self) |
621 | 615 |
|
| 616 | + async def _mark_task_as_failed(self, exception: Exception) -> None: |
| 617 | + if self._exception is None: |
| 618 | + self._exception = exception |
| 619 | + if self._task_created.is_set(): |
| 620 | + task = await self._task_manager.get_task() |
| 621 | + if task is not None: |
| 622 | + await self._event_queue_agent.enqueue_event( |
| 623 | + TaskStatusUpdateEvent( |
| 624 | + task_id=task.id, |
| 625 | + context_id=task.context_id, |
| 626 | + status=TaskStatus( |
| 627 | + state=TaskState.TASK_STATE_FAILED, |
| 628 | + ), |
| 629 | + ) |
| 630 | + ) |
| 631 | + |
622 | 632 | async def get_task(self) -> Task: |
623 | 633 | """Get task from db.""" |
624 | 634 | # TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation). |
|
0 commit comments