Skip to content

Commit 62e5e59

Browse files
authored
feat: Simplify ActiveTask.subscribe() (#958)
Simplify ActiveTask.subscribe() and remove race condition between _is_finished and slow enqueue. Fixes #869 🦕
1 parent 354fdfb commit 62e5e59

1 file changed

Lines changed: 42 additions & 48 deletions

File tree

src/a2a/server/agent_execution/active_task.py

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -511,12 +511,14 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
511511
)
512512
except Exception as e:
513513
logger.exception('Consumer[%s]: Failed', self._task_id)
514+
# TODO: Make the task in database as failed.
514515
async with self._lock:
515516
await self._mark_task_as_failed(e)
516517
finally:
517518
# The consumer is dead. The ActiveTask is permanently finished.
518519
self._is_finished.set()
519520
self._request_queue.shutdown(immediate=True)
521+
await self._event_queue_agent.close(immediate=True)
520522

521523
logger.debug('Consumer[%s]: Finishing', self._task_id)
522524
await self._maybe_cleanup()
@@ -574,53 +576,42 @@ async def subscribe( # noqa: PLR0912, PLR0915
574576
if self._exception:
575577
raise self._exception
576578

577-
# Wait for next event or task completion
578-
try:
579-
dequeued = await asyncio.wait_for(
580-
tapped_queue.dequeue_event(), timeout=0.1
581-
)
582-
event, updated_task = cast('Any', dequeued)
579+
dequeued = await tapped_queue.dequeue_event()
580+
event, updated_task = cast('Any', dequeued)
581+
logger.debug(
582+
'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n',
583+
self._task_id,
584+
event,
585+
updated_task,
586+
)
587+
if replace_status_update_with_task and isinstance(
588+
event, TaskStatusUpdateEvent
589+
):
583590
logger.debug(
584-
'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n',
591+
'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s',
585592
self._task_id,
586-
event,
587593
updated_task,
588594
)
589-
if replace_status_update_with_task and isinstance(
590-
event, TaskStatusUpdateEvent
595+
event = updated_task
596+
if self._exception:
597+
raise self._exception from None
598+
if isinstance(event, _RequestCompleted):
599+
if (
600+
request_id is not None
601+
and event.request_id == request_id
591602
):
592603
logger.debug(
593-
'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s',
604+
'Subscriber[%s]: Request completed',
594605
self._task_id,
595-
updated_task,
596606
)
597-
event = updated_task
598-
if self._exception:
599-
raise self._exception from None
600-
if isinstance(event, _RequestCompleted):
601-
if (
602-
request_id is not None
603-
and event.request_id == request_id
604-
):
605-
logger.debug(
606-
'Subscriber[%s]: Request completed',
607-
self._task_id,
608-
)
609-
return
610-
continue
611-
elif isinstance(event, _RequestStarted):
612-
logger.debug(
613-
'Subscriber[%s]: Request started',
614-
self._task_id,
615-
)
616-
continue
617-
except (asyncio.TimeoutError, TimeoutError):
618-
if self._is_finished.is_set():
619-
if self._exception:
620-
raise self._exception from None
621-
break
607+
return
608+
continue
609+
elif isinstance(event, _RequestStarted):
610+
logger.debug(
611+
'Subscriber[%s]: Request started',
612+
self._task_id,
613+
)
622614
continue
623-
624615
try:
625616
yield event
626617
finally:
@@ -715,17 +706,20 @@ async def _mark_task_as_failed(self, exception: Exception) -> None:
715706
if self._exception is None:
716707
self._exception = exception
717708
if self._task_created.is_set():
718-
task = await self._task_manager.get_task()
719-
if task is not None:
720-
await self._event_queue_agent.enqueue_event(
721-
TaskStatusUpdateEvent(
722-
task_id=task.id,
723-
context_id=task.context_id,
724-
status=TaskStatus(
725-
state=TaskState.TASK_STATE_FAILED,
726-
),
709+
try:
710+
task = await self._task_manager.get_task()
711+
if task is not None:
712+
await self._event_queue_agent.enqueue_event(
713+
TaskStatusUpdateEvent(
714+
task_id=task.id,
715+
context_id=task.context_id,
716+
status=TaskStatus(
717+
state=TaskState.TASK_STATE_FAILED,
718+
),
719+
)
727720
)
728-
)
721+
except QueueShutDown:
722+
pass
729723

730724
async def get_task(self) -> Task:
731725
"""Get task from db."""

0 commit comments

Comments
 (0)