diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index cb002569e..3bd6a0dc2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -322,12 +322,17 @@ async def push_notification_callback() -> None: ( result, interrupted_or_non_blocking, + bg_consume_task, ) = await result_aggregator.consume_and_break_on_interrupt( consumer, blocking=blocking, event_callback=push_notification_callback, ) + if bg_consume_task is not None: + bg_consume_task.set_name(f'continue_consuming:{task_id}') + self._track_background_task(bg_consume_task) + except Exception: logger.exception('Agent execution failed') producer_task.cancel() diff --git a/src/a2a/server/tasks/result_aggregator.py b/src/a2a/server/tasks/result_aggregator.py index fb1ab62ef..8c424bda7 100644 --- a/src/a2a/server/tasks/result_aggregator.py +++ b/src/a2a/server/tasks/result_aggregator.py @@ -99,7 +99,7 @@ async def consume_and_break_on_interrupt( consumer: EventConsumer, blocking: bool = True, event_callback: Callable[[], Awaitable[None]] | None = None, - ) -> tuple[Task | Message | None, bool]: + ) -> tuple[Task | Message | None, bool, asyncio.Task | None]: """Processes the event stream until completion or an interruptable state is encountered. If `blocking` is False, it returns after the first event that creates a Task or Message. @@ -119,16 +119,23 @@ async def consume_and_break_on_interrupt( A tuple containing: - The current aggregated result (`Task` or `Message`) at the point of completion or interruption. - A boolean indicating whether the consumption was interrupted (`True`) or completed naturally (`False`). + - The background ``asyncio.Task`` that continues consuming events + after an interruption, or ``None`` when no background work was + spawned. **Callers must hold a strong reference** to this task + (e.g. in a ``set``) to prevent the garbage collector from + collecting it before it finishes — the event loop only keeps + weak references to tasks. Raises: BaseException: If the `EventConsumer` raises an exception during consumption. """ event_stream = consumer.consume_all() interrupted = False + bg_task: asyncio.Task | None = None async for event in event_stream: if isinstance(event, Message): self._message = event - return event, False + return event, False, None await self.task_manager.process(event) should_interrupt = False @@ -158,13 +165,13 @@ async def consume_and_break_on_interrupt( if should_interrupt: # Continue consuming the rest of the events in the background. - # TODO: We should track all outstanding tasks to ensure they eventually complete. - asyncio.create_task( # noqa: RUF006 + # The caller is responsible for tracking this task to prevent GC. + bg_task = asyncio.create_task( self._continue_consuming(event_stream, event_callback) ) interrupted = True break - return await self.task_manager.get_task(), interrupted + return await self.task_manager.get_task(), interrupted, bg_task async def _continue_consuming( self, diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 067b8bb57..ec2956fa2 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -421,6 +421,7 @@ async def test_on_message_send_with_push_notification(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( final_task_result, False, + None, ) # Mock the current_result property to return the final task result @@ -520,6 +521,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( initial_task, True, # interrupted = True for non-blocking + MagicMock(spec=asyncio.Task), # background task ) # Mock the current_result property to return the final task @@ -540,7 +542,11 @@ async def mock_consume_and_break_on_interrupt( nonlocal event_callback_passed, event_callback_received event_callback_passed = event_callback is not None event_callback_received = event_callback - return initial_task, True # interrupted = True for non-blocking + return ( + initial_task, + True, + MagicMock(spec=asyncio.Task), + ) # interrupted = True for non-blocking mock_result_aggregator_instance.consume_and_break_on_interrupt = ( mock_consume_and_break_on_interrupt @@ -631,6 +637,7 @@ async def test_on_message_send_with_push_notification_no_existing_Task(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( final_task_result, False, + None, ) # Mock the current_result property to return the final task result @@ -689,6 +696,7 @@ async def test_on_message_send_no_result_from_aggregator(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( None, False, + None, ) from a2a.utils.errors import ServerError # Local import @@ -740,6 +748,7 @@ async def test_on_message_send_task_id_mismatch(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( mismatched_task, False, + None, ) from a2a.utils.errors import ServerError # Local import @@ -950,6 +959,7 @@ async def test_on_message_send_interrupted_flow(): mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( interrupt_task_result, True, + MagicMock(spec=asyncio.Task), # background task ) # Interrupted = True # Patch asyncio.create_task to verify _cleanup_producer is scheduled diff --git a/tests/server/tasks/test_result_aggregator.py b/tests/server/tasks/test_result_aggregator.py index bc970246b..7b29ea4c8 100644 --- a/tests/server/tasks/test_result_aggregator.py +++ b/tests/server/tasks/test_result_aggregator.py @@ -228,12 +228,14 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) self.assertEqual(result, sample_message) self.assertFalse(interrupted) + self.assertIsNone(bg_task) self.mock_task_manager.process.assert_not_called() # Process is not called for the Message if returned directly # _continue_consuming should not be called if it's a message interrupt # and no auth_required state. @@ -265,12 +267,14 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) self.assertEqual(result, auth_task) self.assertTrue(interrupted) + self.assertIsNotNone(bg_task) self.mock_task_manager.process.assert_called_once_with(auth_task) mock_create_task.assert_called_once() # Check that create_task was called # self.aggregator._continue_consuming is an AsyncMock. @@ -317,12 +321,14 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) self.assertEqual(result, current_task_state_after_update) self.assertTrue(interrupted) + self.assertIsNotNone(bg_task) self.mock_task_manager.process.assert_called_once_with( auth_status_update ) @@ -353,6 +359,7 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer ) @@ -360,6 +367,7 @@ async def mock_consume_generator(): # If the first event is a Message, it's returned directly. self.assertEqual(result, event1) self.assertFalse(interrupted) + self.assertIsNone(bg_task) # process() is NOT called for the Message if it's the one causing the return self.mock_task_manager.process.assert_not_called() self.mock_task_manager.get_task.assert_not_called() @@ -415,12 +423,14 @@ async def mock_consume_generator(): ( result, interrupted, + bg_task, ) = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer, blocking=False ) self.assertEqual(result, first_event) self.assertTrue(interrupted) + self.assertIsNotNone(bg_task) self.mock_task_manager.process.assert_called_once_with(first_event) mock_create_task.assert_called_once() # The background task should be created with the remaining stream @@ -468,7 +478,7 @@ async def initial_consume_generator(): mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro) # Call the main method that triggers _continue_consuming via create_task - _, _ = await self.aggregator.consume_and_break_on_interrupt( + _, _, _ = await self.aggregator.consume_and_break_on_interrupt( self.mock_event_consumer )