From 7b15660878620aa5f23b9cc1af9f7abc3c9600e3 Mon Sep 17 00:00:00 2001 From: "Kevin(Kefa) Lu" Date: Thu, 5 Mar 2026 16:58:23 -0800 Subject: [PATCH 1/3] fix: return background task from consume_and_break_on_interrupt to prevent GC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ResultAggregator.consume_and_break_on_interrupt creates a background asyncio.Task to continue consuming events after an interruption (non-blocking or auth_required), but discards the task reference. The asyncio event loop only holds weak references to tasks (reported here against Python 3.12+), so the garbage collector can silently collect the task before it completes — dropping remaining events (completed/failed status) and push notification callbacks. Return the background task as a third tuple element so callers can hold a strong reference. DefaultRequestHandler.on_message_send now tracks it via _track_background_task(), the same mechanism already used for other background work. --- .../request_handlers/default_request_handler.py | 7 +++++++ src/a2a/server/tasks/result_aggregator.py | 17 ++++++++++++----- .../test_default_request_handler.py | 8 +++++++- tests/server/tasks/test_result_aggregator.py | 14 +++++++++++++- 4 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index cb002569e..b8bc00dce 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -322,12 +322,19 @@ async def push_notification_callback() -> None: ( result, interrupted_or_non_blocking, + bg_consume_task, ) = await result_aggregator.consume_and_break_on_interrupt( consumer, blocking=blocking, event_callback=push_notification_callback, ) + if bg_consume_task is not None: + bg_consume_task.set_name( + f'continue_consuming:{task_id}' + ) + self._track_background_task(bg_consume_task) + except 
Exception: logger.exception('Agent execution failed') producer_task.cancel() diff --git a/src/a2a/server/tasks/result_aggregator.py b/src/a2a/server/tasks/result_aggregator.py index fb1ab62ef..8c424bda7 100644 --- a/src/a2a/server/tasks/result_aggregator.py +++ b/src/a2a/server/tasks/result_aggregator.py @@ -99,7 +99,7 @@ async def consume_and_break_on_interrupt( consumer: EventConsumer, blocking: bool = True, event_callback: Callable[[], Awaitable[None]] | None = None, - ) -> tuple[Task | Message | None, bool]: + ) -> tuple[Task | Message | None, bool, asyncio.Task | None]: """Processes the event stream until completion or an interruptable state is encountered. If `blocking` is False, it returns after the first event that creates a Task or Message. @@ -119,16 +119,23 @@ async def consume_and_break_on_interrupt( A tuple containing: - The current aggregated result (`Task` or `Message`) at the point of completion or interruption. - A boolean indicating whether the consumption was interrupted (`True`) or completed naturally (`False`). + - The background ``asyncio.Task`` that continues consuming events + after an interruption, or ``None`` when no background work was + spawned. **Callers must hold a strong reference** to this task + (e.g. in a ``set``) to prevent the garbage collector from + collecting it before it finishes — the event loop only keeps + weak references to tasks. Raises: BaseException: If the `EventConsumer` raises an exception during consumption. """ event_stream = consumer.consume_all() interrupted = False + bg_task: asyncio.Task | None = None async for event in event_stream: if isinstance(event, Message): self._message = event - return event, False + return event, False, None await self.task_manager.process(event) should_interrupt = False @@ -158,13 +165,13 @@ async def consume_and_break_on_interrupt( if should_interrupt: # Continue consuming the rest of the events in the background. 
- # TODO: We should track all outstanding tasks to ensure they eventually complete. - asyncio.create_task( # noqa: RUF006 + # The caller is responsible for tracking this task to prevent GC. + bg_task = asyncio.create_task( self._continue_consuming(event_stream, event_callback) ) interrupted = True break - return await self.task_manager.get_task(), interrupted + return await self.task_manager.get_task(), interrupted, bg_task async def _continue_consuming( self, diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 067b8bb57..4d0d4b12c 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -421,6 +421,7 @@ async def test_on_message_send_with_push_notification(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( final_task_result, False, + None, ) # Mock the current_result property to return the final task result @@ -520,6 +521,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( initial_task, True, # interrupted = True for non-blocking + MagicMock(spec=asyncio.Task), # background task ) # Mock the current_result property to return the final task @@ -540,7 +542,7 @@ async def mock_consume_and_break_on_interrupt( nonlocal event_callback_passed, event_callback_received event_callback_passed = event_callback is not None event_callback_received = event_callback - return initial_task, True # interrupted = True for non-blocking + return initial_task, True, MagicMock(spec=asyncio.Task) # interrupted = True for non-blocking mock_result_aggregator_instance.consume_and_break_on_interrupt = ( mock_consume_and_break_on_interrupt @@ -631,6 +633,7 @@ async def test_on_message_send_with_push_notification_no_existing_Task(): 
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( final_task_result, False, + None, ) # Mock the current_result property to return the final task result @@ -689,6 +692,7 @@ async def test_on_message_send_no_result_from_aggregator(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( None, False, + None, ) from a2a.utils.errors import ServerError # Local import @@ -740,6 +744,7 @@ async def test_on_message_send_task_id_mismatch(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( mismatched_task, False, + None, ) from a2a.utils.errors import ServerError # Local import @@ -950,6 +955,7 @@ async def test_on_message_send_interrupted_flow(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( interrupt_task_result, True, + MagicMock(spec=asyncio.Task), # background task ) # Interrupted = True # Patch asyncio.create_task to verify _cleanup_producer is scheduled diff --git a/tests/server/tasks/test_result_aggregator.py b/tests/server/tasks/test_result_aggregator.py index bc970246b..d06d27db8 100644 --- a/tests/server/tasks/test_result_aggregator.py +++ b/tests/server/tasks/test_result_aggregator.py @@ -228,12 +228,14 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) self.assertEqual(result, sample_message) self.assertFalse(interrupted) + self.assertIsNone(bg_task) self.mock_task_manager.process.assert_not_called() # Process is not called for the Message if returned directly # _continue_consuming should not be called if it's a message interrupt # and no auth_required state. 
@@ -260,17 +262,21 @@ async def mock_consume_generator(): # Mock _continue_consuming to check if it's called by create_task self.aggregator._continue_consuming = AsyncMock() + sentinel_task = asyncio.ensure_future(asyncio.sleep(0)) + mock_create_task.return_value = sentinel_task mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro) ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) self.assertEqual(result, auth_task) self.assertTrue(interrupted) + self.assertIsNotNone(bg_task) self.mock_task_manager.process.assert_called_once_with(auth_task) mock_create_task.assert_called_once() # Check that create_task was called # self.aggregator._continue_consuming is an AsyncMock. @@ -317,12 +323,14 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) self.assertEqual(result, current_task_state_after_update) self.assertTrue(interrupted) + self.assertIsNotNone(bg_task) self.mock_task_manager.process.assert_called_once_with( auth_status_update ) @@ -353,6 +361,7 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) @@ -360,6 +369,7 @@ async def mock_consume_generator(): # If the first event is a Message, it's returned directly. 
self.assertEqual(result, event1) self.assertFalse(interrupted) + self.assertIsNone(bg_task) # process() is NOT called for the Message if it's the one causing the return self.mock_task_manager.process.assert_not_called() self.mock_task_manager.get_task.assert_not_called() @@ -415,12 +425,14 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer, blocking=False ) self.assertEqual(result, first_event) self.assertTrue(interrupted) + self.assertIsNotNone(bg_task) self.mock_task_manager.process.assert_called_once_with(first_event) mock_create_task.assert_called_once() # The background task should be created with the remaining stream @@ -468,7 +480,7 @@ async def initial_consume_generator(): mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro) # Call the main method that triggers _continue_consuming via create_task - _, _ = await self.aggregator.consume_and_break_on_interrupt( + _, _, _ = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) From 4b3828b3bc1fdc10d329d7bc56a4d01dcd1cd96c Mon Sep 17 00:00:00 2001 From: "Kevin(Kefa) Lu" Date: Thu, 5 Mar 2026 18:00:46 -0800 Subject: [PATCH 2/3] fix formatting --- src/a2a/server/request_handlers/default_request_handler.py | 4 +--- .../server/request_handlers/test_default_request_handler.py | 6 +++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index b8bc00dce..3bd6a0dc2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -330,9 +330,7 @@ async def push_notification_callback() -> None: ) if bg_consume_task is not None: - bg_consume_task.set_name( - f'continue_consuming:{task_id}' - ) + bg_consume_task.set_name(f'continue_consuming:{task_id}') 
self._track_background_task(bg_consume_task) except Exception: diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 4d0d4b12c..ec2956fa2 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -542,7 +542,11 @@ async def mock_consume_and_break_on_interrupt( nonlocal event_callback_passed, event_callback_received event_callback_passed = event_callback is not None event_callback_received = event_callback - return initial_task, True, MagicMock(spec=asyncio.Task) # interrupted = True for non-blocking + return ( + initial_task, + True, + MagicMock(spec=asyncio.Task), + ) # interrupted = True for non-blocking mock_result_aggregator_instance.consume_and_break_on_interrupt = ( mock_consume_and_break_on_interrupt From a22e2c73fc52752523b6c59a975ad362d34ea069 Mon Sep 17 00:00:00 2001 From: "Kevin(Kefa) Lu" Date: Thu, 5 Mar 2026 18:07:29 -0800 Subject: [PATCH 3/3] Removed the redundant sentinel_task and mock_create_task.return_value assignment --- tests/server/tasks/test_result_aggregator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/server/tasks/test_result_aggregator.py b/tests/server/tasks/test_result_aggregator.py index d06d27db8..7b29ea4c8 100644 --- a/tests/server/tasks/test_result_aggregator.py +++ b/tests/server/tasks/test_result_aggregator.py @@ -262,8 +262,6 @@ async def mock_consume_generator(): # Mock _continue_consuming to check if it's called by create_task self.aggregator._continue_consuming = AsyncMock() - sentinel_task = asyncio.ensure_future(asyncio.sleep(0)) - mock_create_task.return_value = sentinel_task mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro) (