Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,17 @@ async def push_notification_callback() -> None:
(
result,
interrupted_or_non_blocking,
bg_consume_task,
) = await result_aggregator.consume_and_break_on_interrupt(
consumer,
blocking=blocking,
event_callback=push_notification_callback,
)

if bg_consume_task is not None:
bg_consume_task.set_name(f'continue_consuming:{task_id}')
self._track_background_task(bg_consume_task)

except Exception:
logger.exception('Agent execution failed')
producer_task.cancel()
Expand Down
17 changes: 12 additions & 5 deletions src/a2a/server/tasks/result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def consume_and_break_on_interrupt(
consumer: EventConsumer,
blocking: bool = True,
event_callback: Callable[[], Awaitable[None]] | None = None,
) -> tuple[Task | Message | None, bool]:
) -> tuple[Task | Message | None, bool, asyncio.Task | None]:
"""Processes the event stream until completion or an interruptable state is encountered.

If `blocking` is False, it returns after the first event that creates a Task or Message.
Expand All @@ -119,16 +119,23 @@ async def consume_and_break_on_interrupt(
A tuple containing:
- The current aggregated result (`Task` or `Message`) at the point of completion or interruption.
- A boolean indicating whether the consumption was interrupted (`True`) or completed naturally (`False`).
- The background ``asyncio.Task`` that continues consuming events
after an interruption, or ``None`` when no background work was
spawned. **Callers must hold a strong reference** to this task
(e.g. in a ``set``) to prevent the garbage collector from
collecting it before it finishes — the event loop only keeps
weak references to tasks.

Raises:
BaseException: If the `EventConsumer` raises an exception during consumption.
"""
event_stream = consumer.consume_all()
interrupted = False
bg_task: asyncio.Task | None = None
async for event in event_stream:
if isinstance(event, Message):
self._message = event
return event, False
return event, False, None
await self.task_manager.process(event)

should_interrupt = False
Expand Down Expand Up @@ -158,13 +165,13 @@ async def consume_and_break_on_interrupt(

if should_interrupt:
# Continue consuming the rest of the events in the background.
# TODO: We should track all outstanding tasks to ensure they eventually complete.
asyncio.create_task( # noqa: RUF006
# The caller is responsible for tracking this task to prevent GC.
bg_task = asyncio.create_task(
self._continue_consuming(event_stream, event_callback)
)
interrupted = True
break
return await self.task_manager.get_task(), interrupted
return await self.task_manager.get_task(), interrupted, bg_task

async def _continue_consuming(
self,
Expand Down
12 changes: 11 additions & 1 deletion tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ async def test_on_message_send_with_push_notification():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
final_task_result,
False,
None,
)

# Mock the current_result property to return the final task result
Expand Down Expand Up @@ -520,6 +521,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
initial_task,
True, # interrupted = True for non-blocking
MagicMock(spec=asyncio.Task), # background task
)

# Mock the current_result property to return the final task
Expand All @@ -540,7 +542,11 @@ async def mock_consume_and_break_on_interrupt(
nonlocal event_callback_passed, event_callback_received
event_callback_passed = event_callback is not None
event_callback_received = event_callback
return initial_task, True # interrupted = True for non-blocking
return (
initial_task,
True,
MagicMock(spec=asyncio.Task),
) # interrupted = True for non-blocking

mock_result_aggregator_instance.consume_and_break_on_interrupt = (
mock_consume_and_break_on_interrupt
Expand Down Expand Up @@ -631,6 +637,7 @@ async def test_on_message_send_with_push_notification_no_existing_Task():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
final_task_result,
False,
None,
)

# Mock the current_result property to return the final task result
Expand Down Expand Up @@ -689,6 +696,7 @@ async def test_on_message_send_no_result_from_aggregator():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
None,
False,
None,
)

from a2a.utils.errors import ServerError # Local import
Expand Down Expand Up @@ -740,6 +748,7 @@ async def test_on_message_send_task_id_mismatch():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
mismatched_task,
False,
None,
)

from a2a.utils.errors import ServerError # Local import
Expand Down Expand Up @@ -950,6 +959,7 @@ async def test_on_message_send_interrupted_flow():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
interrupt_task_result,
True,
MagicMock(spec=asyncio.Task), # background task
) # Interrupted = True

# Patch asyncio.create_task to verify _cleanup_producer is scheduled
Expand Down
12 changes: 11 additions & 1 deletion tests/server/tasks/test_result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, sample_message)
self.assertFalse(interrupted)
self.assertIsNone(bg_task)
self.mock_task_manager.process.assert_not_called() # Process is not called for the Message if returned directly
# _continue_consuming should not be called if it's a message interrupt
# and no auth_required state.
Expand Down Expand Up @@ -265,12 +267,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, auth_task)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(auth_task)
mock_create_task.assert_called_once() # Check that create_task was called
# self.aggregator._continue_consuming is an AsyncMock.
Expand Down Expand Up @@ -317,12 +321,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, current_task_state_after_update)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(
auth_status_update
)
Expand Down Expand Up @@ -353,13 +359,15 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

# If the first event is a Message, it's returned directly.
self.assertEqual(result, event1)
self.assertFalse(interrupted)
self.assertIsNone(bg_task)
# process() is NOT called for the Message if it's the one causing the return
self.mock_task_manager.process.assert_not_called()
self.mock_task_manager.get_task.assert_not_called()
Expand Down Expand Up @@ -415,12 +423,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer, blocking=False
)

self.assertEqual(result, first_event)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(first_event)
mock_create_task.assert_called_once()
# The background task should be created with the remaining stream
Expand Down Expand Up @@ -468,7 +478,7 @@ async def initial_consume_generator():
mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro)

# Call the main method that triggers _continue_consuming via create_task
_, _ = await self.aggregator.consume_and_break_on_interrupt(
_, _, _ = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

Expand Down
Loading