Skip to content

Commit 4fc6b54

Browse files
authored
feat: Unhandled exception in AgentExecutor marks task as failed (#943)
Fixes #869 🦕
1 parent a61f6d4 commit 4fc6b54

2 files changed

Lines changed: 97 additions & 79 deletions

File tree

src/a2a/server/agent_execution/active_task.py

Lines changed: 81 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
Message,
3333
Task,
3434
TaskState,
35+
TaskStatus,
36+
TaskStatusUpdateEvent,
3537
)
3638
from a2a.utils.errors import (
3739
InvalidParamsError,
@@ -252,80 +254,75 @@ async def _run_producer(self) -> None:
252254
"""
253255
logger.debug('Producer[%s]: Started', self._task_id)
254256
try:
255-
try:
256-
try:
257-
while True:
258-
(
259-
request_context,
260-
request_id,
261-
) = await self._request_queue.get()
262-
await self._request_lock.acquire()
263-
# TODO: Should we create task manager every time?
264-
self._task_manager._call_context = (
265-
request_context.call_context
266-
)
267-
request_context.current_task = (
268-
await self._task_manager.get_task()
269-
)
257+
active = True
258+
while active:
259+
(
260+
request_context,
261+
request_id,
262+
) = await self._request_queue.get()
263+
await self._request_lock.acquire()
264+
# TODO: Should we create task manager every time?
265+
self._task_manager._call_context = request_context.call_context
266+
request_context.current_task = (
267+
await self._task_manager.get_task()
268+
)
270269

271-
message = request_context.message
272-
if message:
273-
request_context.current_task = (
274-
self._task_manager.update_with_message(
275-
message,
276-
cast('Task', request_context.current_task),
277-
)
278-
)
279-
await self._task_manager.save_task_event(
280-
request_context.current_task
281-
)
282-
self._task_created.set()
283-
logger.debug(
284-
'Producer[%s]: Executing agent task %s',
285-
self._task_id,
286-
request_context.current_task,
270+
message = request_context.message
271+
if message:
272+
request_context.current_task = (
273+
self._task_manager.update_with_message(
274+
message,
275+
cast('Task', request_context.current_task),
287276
)
277+
)
278+
await self._task_manager.save_task_event(
279+
request_context.current_task
280+
)
281+
self._task_created.set()
282+
logger.debug(
283+
'Producer[%s]: Executing agent task %s',
284+
self._task_id,
285+
request_context.current_task,
286+
)
288287

289-
try:
290-
await self._agent_executor.execute(
291-
request_context, self._event_queue_agent
292-
)
293-
logger.debug(
294-
'Producer[%s]: Execution finished successfully',
295-
self._task_id,
296-
)
297-
except Exception as e:
298-
async with self._lock:
299-
if self._exception is None:
300-
self._exception = e
301-
raise
302-
finally:
303-
logger.debug(
304-
'Producer[%s]: Enqueuing request completed event',
305-
self._task_id,
306-
)
307-
# TODO: Hide from external consumers
308-
await self._event_queue_agent.enqueue_event(
309-
cast('Event', _RequestCompleted(request_id))
310-
)
311-
self._request_queue.task_done()
288+
try:
289+
await self._agent_executor.execute(
290+
request_context, self._event_queue_agent
291+
)
292+
logger.debug(
293+
'Producer[%s]: Execution finished successfully',
294+
self._task_id,
295+
)
312296
except QueueShutDown:
313297
logger.debug(
314298
'Producer[%s]: Request queue shut down', self._task_id
315299
)
316-
except asyncio.CancelledError:
317-
logger.debug('Producer[%s]: Cancelled', self._task_id)
318-
raise
319-
except Exception as e:
320-
logger.exception('Producer[%s]: Failed', self._task_id)
321-
async with self._lock:
322-
if self._exception is None:
323-
self._exception = e
324-
finally:
325-
self._request_queue.shutdown(immediate=True)
326-
await self._event_queue_agent.close(immediate=False)
327-
await self._event_queue_subscribers.close(immediate=False)
300+
raise
301+
except asyncio.CancelledError:
302+
logger.debug('Producer[%s]: Cancelled', self._task_id)
303+
raise
304+
except Exception as e:
305+
logger.exception(
306+
'Producer[%s]: Execution failed',
307+
self._task_id,
308+
)
309+
async with self._lock:
310+
await self._mark_task_as_failed(e)
311+
active = False
312+
finally:
313+
logger.debug(
314+
'Producer[%s]: Enqueuing request completed event',
315+
self._task_id,
316+
)
317+
# TODO: Hide from external consumers
318+
await self._event_queue_agent.enqueue_event(
319+
cast('Event', _RequestCompleted(request_id))
320+
)
321+
self._request_queue.task_done()
328322
finally:
323+
self._request_queue.shutdown(immediate=True)
324+
await self._event_queue_agent.close(immediate=False)
325+
await self._event_queue_subscribers.close(immediate=False)
329326
logger.debug('Producer[%s]: Completed', self._task_id)
330327

331328
async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
@@ -443,8 +440,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
443440
except Exception as e:
444441
logger.exception('Consumer[%s]: Failed', self._task_id)
445442
async with self._lock:
446-
if self._exception is None:
447-
self._exception = e
443+
await self._mark_task_as_failed(e)
448444
finally:
449445
# The consumer is dead. The ActiveTask is permanently finished.
450446
self._is_finished.set()
@@ -581,9 +577,7 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
581577
logger.exception(
582578
'Cancel[%s]: Agent cancel failed', self._task_id
583579
)
584-
if not self._exception:
585-
self._exception = e
586-
580+
await self._mark_task_as_failed(e)
587581
raise
588582
else:
589583
logger.debug(
@@ -619,6 +613,22 @@ async def _maybe_cleanup(self) -> None:
619613
logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id)
620614
self._on_cleanup(self)
621615

616+
async def _mark_task_as_failed(self, exception: Exception) -> None:
# Store `exception` on this object and publish a failed-status update for
# the task. Called from the producer, consumer and cancel error handlers
# shown elsewhere in this diff.
# NOTE(review): indentation is lost in this rendering, so whether the
# `_task_created` check below is nested under the `_exception is None`
# check cannot be confirmed from here - verify against the real file.
617+
# Only the first exception is kept; a later one does not overwrite it.
if self._exception is None:
618+
self._exception = exception
619+
# No event is emitted unless the task-created flag is set.
if self._task_created.is_set():
620+
task = await self._task_manager.get_task()
621+
# get_task() may return None; in that case nothing is enqueued.
if task is not None:
622+
# Enqueue a status update with state TASK_STATE_FAILED on the agent
# event queue, reusing the id and context_id of the loaded task.
await self._event_queue_agent.enqueue_event(
623+
TaskStatusUpdateEvent(
624+
task_id=task.id,
625+
context_id=task.context_id,
626+
status=TaskStatus(
627+
state=TaskState.TASK_STATE_FAILED,
628+
),
629+
)
630+
)
631+
622632
async def get_task(self) -> Task:
623633
"""Get task from db."""
624634
# TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except initial task creation).

tests/integration/test_scenarios.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,8 @@ async def cancel(
437437
# Legacy is not creating tasks for agent failures.
438438
assert len((await client.list_tasks(ListTasksRequest())).tasks) == 0
439439
else:
440-
# TODO: should it be TASK_STATE_FAILED ?
441440
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
442-
assert task.status.state == TaskState.TASK_STATE_SUBMITTED
441+
assert task.status.state == TaskState.TASK_STATE_FAILED
443442

444443

445444
# Scenario 12/13: Exception after initial event
@@ -503,9 +502,12 @@ async def release_agent():
503502

504503
await asyncio.gather(*tasks)
505504

506-
# TODO: should it be TASK_STATE_FAILED ?
507505
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
508-
assert task.status.state == TaskState.TASK_STATE_WORKING
506+
if use_legacy:
507+
# Legacy does not update task state on exception.
508+
assert task.status.state == TaskState.TASK_STATE_WORKING
509+
else:
510+
assert task.status.state == TaskState.TASK_STATE_FAILED
509511

510512

511513
# Scenario 14: Exception in Cancel
@@ -563,9 +565,12 @@ async def cancel(
563565
with pytest.raises(A2AClientError, match='TEST_ERROR_IN_CANCEL'):
564566
await client.cancel_task(CancelTaskRequest(id=task_id))
565567

566-
# TODO: should it be TASK_STATE_CANCELED or TASK_STATE_FAILED?
567568
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
568-
assert task.status.state == TaskState.TASK_STATE_WORKING
569+
if use_legacy:
570+
# Legacy does not update task state on exception.
571+
assert task.status.state == TaskState.TASK_STATE_WORKING
572+
else:
573+
assert task.status.state == TaskState.TASK_STATE_FAILED
569574

570575

571576
# Scenario 15: Subscribe to task that errors out
@@ -632,9 +637,12 @@ async def consume_events():
632637
with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'):
633638
await consume_task
634639

635-
# TODO: should it be TASK_STATE_FAILED?
636640
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
637-
assert task.status.state == TaskState.TASK_STATE_WORKING
641+
if use_legacy:
642+
# Legacy does not update task state on exception.
643+
assert task.status.state == TaskState.TASK_STATE_WORKING
644+
else:
645+
assert task.status.state == TaskState.TASK_STATE_FAILED
638646

639647

640648
# Scenario 16: Slow execution and return_immediately=True

0 commit comments

Comments
 (0)