Skip to content

Commit a236d4d

Browse files
kevinlu310 and ishymko authored
fix: return background task from consume_and_break_on_interrupt to prevent GC (#775)
# Description

ResultAggregator.consume_and_break_on_interrupt creates a background asyncio.Task to continue consuming events after an interruption (non-blocking or auth_required), but discards the task reference. On Python 3.12+ the event loop only holds weak references to tasks, so the garbage collector can silently collect the task before it completes — dropping remaining events (completed/failed status) and push notification callbacks.

Return the background task as a third tuple element so callers can hold a strong reference. DefaultRequestHandler.on_message_send now tracks it via _track_background_task(), the same mechanism already used for other background work.

- [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md).
- [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification.
  - Important Prefixes for [release-please](https://github.com/googleapis/release-please):
    - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch.
    - `feat:` represents a new feature, and correlates to a SemVer minor.
    - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major.
- [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format)

Fixes #774 🦕

---------

Co-authored-by: Ivan Shymko <ishymko@google.com>
1 parent fa14dbf commit a236d4d

4 files changed

Lines changed: 39 additions & 7 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,17 @@ async def push_notification_callback() -> None:
322322
(
323323
result,
324324
interrupted_or_non_blocking,
325+
bg_consume_task,
325326
) = await result_aggregator.consume_and_break_on_interrupt(
326327
consumer,
327328
blocking=blocking,
328329
event_callback=push_notification_callback,
329330
)
330331

332+
if bg_consume_task is not None:
333+
bg_consume_task.set_name(f'continue_consuming:{task_id}')
334+
self._track_background_task(bg_consume_task)
335+
331336
except Exception:
332337
logger.exception('Agent execution failed')
333338
producer_task.cancel()

src/a2a/server/tasks/result_aggregator.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def consume_and_break_on_interrupt(
9999
consumer: EventConsumer,
100100
blocking: bool = True,
101101
event_callback: Callable[[], Awaitable[None]] | None = None,
102-
) -> tuple[Task | Message | None, bool]:
102+
) -> tuple[Task | Message | None, bool, asyncio.Task | None]:
103103
"""Processes the event stream until completion or an interruptable state is encountered.
104104
105105
If `blocking` is False, it returns after the first event that creates a Task or Message.
@@ -119,16 +119,23 @@ async def consume_and_break_on_interrupt(
119119
A tuple containing:
120120
- The current aggregated result (`Task` or `Message`) at the point of completion or interruption.
121121
- A boolean indicating whether the consumption was interrupted (`True`) or completed naturally (`False`).
122+
- The background ``asyncio.Task`` that continues consuming events
123+
after an interruption, or ``None`` when no background work was
124+
spawned. **Callers must hold a strong reference** to this task
125+
(e.g. in a ``set``) to prevent the garbage collector from
126+
collecting it before it finishes — the event loop only keeps
127+
weak references to tasks.
122128
123129
Raises:
124130
BaseException: If the `EventConsumer` raises an exception during consumption.
125131
"""
126132
event_stream = consumer.consume_all()
127133
interrupted = False
134+
bg_task: asyncio.Task | None = None
128135
async for event in event_stream:
129136
if isinstance(event, Message):
130137
self._message = event
131-
return event, False
138+
return event, False, None
132139
await self.task_manager.process(event)
133140

134141
should_interrupt = False
@@ -158,13 +165,13 @@ async def consume_and_break_on_interrupt(
158165

159166
if should_interrupt:
160167
# Continue consuming the rest of the events in the background.
161-
# TODO: We should track all outstanding tasks to ensure they eventually complete.
162-
asyncio.create_task( # noqa: RUF006
168+
# The caller is responsible for tracking this task to prevent GC.
169+
bg_task = asyncio.create_task(
163170
self._continue_consuming(event_stream, event_callback)
164171
)
165172
interrupted = True
166173
break
167-
return await self.task_manager.get_task(), interrupted
174+
return await self.task_manager.get_task(), interrupted, bg_task
168175

169176
async def _continue_consuming(
170177
self,

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ async def test_on_message_send_with_push_notification():
421421
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
422422
final_task_result,
423423
False,
424+
None,
424425
)
425426

426427
# Mock the current_result property to return the final task result
@@ -520,6 +521,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request():
520521
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
521522
initial_task,
522523
True, # interrupted = True for non-blocking
524+
MagicMock(spec=asyncio.Task), # background task
523525
)
524526

525527
# Mock the current_result property to return the final task
@@ -540,7 +542,11 @@ async def mock_consume_and_break_on_interrupt(
540542
nonlocal event_callback_passed, event_callback_received
541543
event_callback_passed = event_callback is not None
542544
event_callback_received = event_callback
543-
return initial_task, True # interrupted = True for non-blocking
545+
return (
546+
initial_task,
547+
True,
548+
MagicMock(spec=asyncio.Task),
549+
) # interrupted = True for non-blocking
544550

545551
mock_result_aggregator_instance.consume_and_break_on_interrupt = (
546552
mock_consume_and_break_on_interrupt
@@ -631,6 +637,7 @@ async def test_on_message_send_with_push_notification_no_existing_Task():
631637
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
632638
final_task_result,
633639
False,
640+
None,
634641
)
635642

636643
# Mock the current_result property to return the final task result
@@ -689,6 +696,7 @@ async def test_on_message_send_no_result_from_aggregator():
689696
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
690697
None,
691698
False,
699+
None,
692700
)
693701

694702
from a2a.utils.errors import ServerError # Local import
@@ -740,6 +748,7 @@ async def test_on_message_send_task_id_mismatch():
740748
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
741749
mismatched_task,
742750
False,
751+
None,
743752
)
744753

745754
from a2a.utils.errors import ServerError # Local import
@@ -950,6 +959,7 @@ async def test_on_message_send_interrupted_flow():
950959
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
951960
interrupt_task_result,
952961
True,
962+
MagicMock(spec=asyncio.Task), # background task
953963
) # Interrupted = True
954964

955965
# Patch asyncio.create_task to verify _cleanup_producer is scheduled

tests/server/tasks/test_result_aggregator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,14 @@ async def mock_consume_generator():
228228
(
229229
result,
230230
interrupted,
231+
bg_task,
231232
) = await self.aggregator.consume_and_break_on_interrupt(
232233
self.mock_event_consumer
233234
)
234235

235236
self.assertEqual(result, sample_message)
236237
self.assertFalse(interrupted)
238+
self.assertIsNone(bg_task)
237239
self.mock_task_manager.process.assert_not_called() # Process is not called for the Message if returned directly
238240
# _continue_consuming should not be called if it's a message interrupt
239241
# and no auth_required state.
@@ -265,12 +267,14 @@ async def mock_consume_generator():
265267
(
266268
result,
267269
interrupted,
270+
bg_task,
268271
) = await self.aggregator.consume_and_break_on_interrupt(
269272
self.mock_event_consumer
270273
)
271274

272275
self.assertEqual(result, auth_task)
273276
self.assertTrue(interrupted)
277+
self.assertIsNotNone(bg_task)
274278
self.mock_task_manager.process.assert_called_once_with(auth_task)
275279
mock_create_task.assert_called_once() # Check that create_task was called
276280
# self.aggregator._continue_consuming is an AsyncMock.
@@ -317,12 +321,14 @@ async def mock_consume_generator():
317321
(
318322
result,
319323
interrupted,
324+
bg_task,
320325
) = await self.aggregator.consume_and_break_on_interrupt(
321326
self.mock_event_consumer
322327
)
323328

324329
self.assertEqual(result, current_task_state_after_update)
325330
self.assertTrue(interrupted)
331+
self.assertIsNotNone(bg_task)
326332
self.mock_task_manager.process.assert_called_once_with(
327333
auth_status_update
328334
)
@@ -353,13 +359,15 @@ async def mock_consume_generator():
353359
(
354360
result,
355361
interrupted,
362+
bg_task,
356363
) = await self.aggregator.consume_and_break_on_interrupt(
357364
self.mock_event_consumer
358365
)
359366

360367
# If the first event is a Message, it's returned directly.
361368
self.assertEqual(result, event1)
362369
self.assertFalse(interrupted)
370+
self.assertIsNone(bg_task)
363371
# process() is NOT called for the Message if it's the one causing the return
364372
self.mock_task_manager.process.assert_not_called()
365373
self.mock_task_manager.get_task.assert_not_called()
@@ -415,12 +423,14 @@ async def mock_consume_generator():
415423
(
416424
result,
417425
interrupted,
426+
bg_task,
418427
) = await self.aggregator.consume_and_break_on_interrupt(
419428
self.mock_event_consumer, blocking=False
420429
)
421430

422431
self.assertEqual(result, first_event)
423432
self.assertTrue(interrupted)
433+
self.assertIsNotNone(bg_task)
424434
self.mock_task_manager.process.assert_called_once_with(first_event)
425435
mock_create_task.assert_called_once()
426436
# The background task should be created with the remaining stream
@@ -468,7 +478,7 @@ async def initial_consume_generator():
468478
mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro)
469479

470480
# Call the main method that triggers _continue_consuming via create_task
471-
_, _ = await self.aggregator.consume_and_break_on_interrupt(
481+
_, _, _ = await self.aggregator.consume_and_break_on_interrupt(
472482
self.mock_event_consumer
473483
)
474484

0 commit comments

Comments
 (0)