Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/google/adk/a2a/executor/a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,11 @@ async def _handle_request(
await event_queue.enqueue_event(a2a_event)

# publish the task result event - this is final
agg_message = task_result_aggregator.task_status_message
if (
task_result_aggregator.task_state == TaskState.working
and task_result_aggregator.task_status_message is not None
and task_result_aggregator.task_status_message.parts
and agg_message is not None
and agg_message.parts
):
# if task is still working properly, publish the artifact update event as
# the final result according to a2a protocol.
Expand All @@ -287,7 +288,8 @@ async def _handle_request(
context_id=context.context_id,
artifact=Artifact(
artifact_id=platform_uuid.new_uuid(),
parts=task_result_aggregator.task_status_message.parts,
parts=agg_message.parts,
metadata=agg_message.metadata or None,
),
)
)
Expand All @@ -304,14 +306,22 @@ async def _handle_request(
final=True,
)
else:
# Resolve terminal state: working → completed (agent finished
# without error); other states (failed, auth_required, etc.)
# are preserved as-is.
final_state = (
TaskState.completed
if task_result_aggregator.task_state == TaskState.working
else task_result_aggregator.task_state
)
final_event = TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=task_result_aggregator.task_state,
state=final_state,
timestamp=datetime.fromtimestamp(
platform_time.get_time(), tz=timezone.utc
).isoformat(),
message=task_result_aggregator.task_status_message,
message=agg_message,
),
context_id=context.context_id,
final=True,
Expand Down
44 changes: 43 additions & 1 deletion src/google/adk/a2a/executor/task_result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from a2a.types import Message
from a2a.types import TaskState
from a2a.types import TaskStatusUpdateEvent
from a2a.types import TextPart

from ..experimental import a2a_experimental

Expand Down Expand Up @@ -59,9 +60,50 @@ def process_event(self, event: Event):
# always working because other state may terminate the event aggregation
# in a2a request handler
elif self._task_state == TaskState.working:
self._task_status_message = event.status.message
self._accumulate_message(event.status.message)
event.status.state = TaskState.working

def _accumulate_message(self, new_message: Message | None):
"""Accumulate content from a new message into the running result.

For delta-style streaming, successive TextPart texts are concatenated
rather than replaced. Metadata dicts are merged (later values win).
"""
if new_message is None:
return

if self._task_status_message is None:
self._task_status_message = new_message
return

# Accumulate parts
if new_message.parts:
if not self._task_status_message.parts:
self._task_status_message.parts = list(new_message.parts)
else:
for new_part in new_message.parts:
new_root = getattr(new_part, 'root', new_part)
if isinstance(new_root, TextPart):
# Concatenate into the last existing TextPart if one exists
appended = False
for existing_part in reversed(self._task_status_message.parts):
existing_root = getattr(existing_part, 'root', existing_part)
if isinstance(existing_root, TextPart):
existing_root.text += new_root.text
appended = True
break
if not appended:
self._task_status_message.parts.append(new_part)
else:
self._task_status_message.parts.append(new_part)

# Merge metadata
if new_message.metadata:
if self._task_status_message.metadata is None:
self._task_status_message.metadata = dict(new_message.metadata)
else:
self._task_status_message.metadata.update(new_message.metadata)

@property
def task_state(self) -> TaskState:
return self._task_state
Expand Down
220 changes: 210 additions & 10 deletions tests/unittests/a2a/executor/test_a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from a2a.types import Part
from a2a.types import Role
from a2a.types import TaskState
from a2a.types import TaskStatus
from a2a.types import TaskStatusUpdateEvent
from a2a.types import TextPart
from google.adk.a2a.converters.request_converter import AgentRunRequest
from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
Expand Down Expand Up @@ -143,9 +145,9 @@ async def mock_run_async(**kwargs):
final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0]
assert final_event.final == True
# The TaskResultAggregator is created with default state (working), and since no messages
# are processed, it will publish a status event with the current state
# are processed, the agent completed normally so the terminal state is completed
assert hasattr(final_event.status, "message")
assert final_event.status.state == TaskState.working
assert final_event.status.state == TaskState.completed

@pytest.mark.asyncio
async def test_execute_no_message_error(self):
Expand Down Expand Up @@ -218,9 +220,9 @@ async def mock_run_async(**kwargs):
final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0]
assert final_event.final == True
# The TaskResultAggregator is created with default state (working), and since no messages
# are processed, it will publish a status event with the current state
# are processed, the agent completed normally so the terminal state is completed
assert hasattr(final_event.status, "message")
assert final_event.status.state == TaskState.working
assert final_event.status.state == TaskState.completed

@pytest.mark.asyncio
async def test_prepare_session_new_session(self):
Expand Down Expand Up @@ -443,9 +445,9 @@ async def mock_run_async(**kwargs):
final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0]
assert final_event.final == True
# The TaskResultAggregator is created with default state (working), and since no messages
# are processed, it will publish a status event with the current state
# are processed, the agent completed normally so the terminal state is completed
assert hasattr(final_event.status, "message")
assert final_event.status.state == TaskState.working
assert final_event.status.state == TaskState.completed

@pytest.mark.asyncio
async def test_execute_with_async_callable_runner(self):
Expand Down Expand Up @@ -502,9 +504,9 @@ async def mock_run_async(**kwargs):
final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0]
assert final_event.final == True
# The TaskResultAggregator is created with default state (working), and since no messages
# are processed, it will publish a status event with the current state
# are processed, the agent completed normally so the terminal state is completed
assert hasattr(final_event.status, "message")
assert final_event.status.state == TaskState.working
assert final_event.status.state == TaskState.completed

@pytest.mark.asyncio
async def test_handle_request_integration(self):
Expand Down Expand Up @@ -580,8 +582,8 @@ async def mock_run_async(**kwargs):
assert len(final_events) >= 1
final_event = final_events[-1] # Get the last final event
assert final_event.status.message == mock_aggregator.task_status_message
# When aggregator state is working but no message, final event should be working
assert final_event.status.state == TaskState.working
# When aggregator state is working but no message, final event should be completed
assert final_event.status.state == TaskState.completed

@pytest.mark.asyncio
async def test_cancel_with_task_id(self):
Expand Down Expand Up @@ -803,6 +805,7 @@ async def test_handle_request_with_working_state_publishes_artifact_and_complete
test_message.message_id = "test-message-id"
test_message.role = Role.agent
test_message.parts = [Part(root=TextPart(text="test content"))]
test_message.metadata = None

# Setup detailed mocks
self.mock_request_converter.return_value = AgentRunRequest(
Expand Down Expand Up @@ -1072,3 +1075,200 @@ async def mock_run_async(**kwargs):
assert (
modified_a2a_event in enqueued_events
), "The modified event should have been enqueued"

# ── Regression tests for issue #5188 ──────────────────────────────────

@pytest.mark.asyncio
async def test_metadata_only_response_completes(self):
"""A response with metadata but no text parts should finalize as completed, not working."""
from a2a.types import TaskArtifactUpdateEvent

self.mock_context.task_id = "task-meta"
self.mock_context.context_id = "ctx-meta"
self.mock_context.current_task = Mock()

self.mock_request_converter.return_value = AgentRunRequest(
user_id="u",
session_id="s",
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)

mock_session = Mock()
mock_session.id = "s"
self.mock_runner.session_service.get_session = AsyncMock(
return_value=mock_session
)
self.mock_runner._new_invocation_context.return_value = Mock()

# The agent yields one event whose converted A2A event has a message
# with metadata but NO parts (metadata-only response).
meta_message = Message(
message_id="m1",
role=Role.agent,
parts=[],
metadata={"intent": "greeting"},
)
a2a_status_event = TaskStatusUpdateEvent(
task_id="task-meta",
context_id="ctx-meta",
status=TaskStatus(state=TaskState.working, message=meta_message),
final=False,
)
self.mock_event_converter.return_value = [a2a_status_event]

adk_event = Mock(spec=Event)

async def mock_run_async(**kwargs):
yield adk_event

self.mock_runner.run_async = mock_run_async

await self.executor.execute(self.mock_context, self.mock_event_queue)

enqueued = [
c[0][0] for c in self.mock_event_queue.enqueue_event.call_args_list
]
final_events = [
e for e in enqueued
if isinstance(e, TaskStatusUpdateEvent) and e.final
]
assert len(final_events) == 1
assert final_events[0].status.state == TaskState.completed

# No artifact event should be emitted (no parts to wrap).
artifact_events = [
e for e in enqueued if isinstance(e, TaskArtifactUpdateEvent)
]
assert len(artifact_events) == 0

@pytest.mark.asyncio
async def test_streamed_text_accumulated_in_final_artifact(self):
"""Delta text chunks should be concatenated in the synthesized final artifact."""
from a2a.types import TaskArtifactUpdateEvent

self.mock_context.task_id = "task-stream"
self.mock_context.context_id = "ctx-stream"
self.mock_context.current_task = Mock()

self.mock_request_converter.return_value = AgentRunRequest(
user_id="u",
session_id="s",
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)
mock_session = Mock()
mock_session.id = "s"
self.mock_runner.session_service.get_session = AsyncMock(
return_value=mock_session
)
self.mock_runner._new_invocation_context.return_value = Mock()

chunks = ["Hello", " world", "!"]
adk_events = [Mock(spec=Event) for _ in chunks]

# Each ADK event converts to a status update with one text chunk.
call_index = 0

def event_converter(adk_ev, inv_ctx, task_id, ctx_id, converter):
nonlocal call_index
text = chunks[call_index]
call_index += 1
return [
TaskStatusUpdateEvent(
task_id=task_id,
context_id=ctx_id,
status=TaskStatus(
state=TaskState.working,
message=Message(
message_id="m",
role=Role.agent,
parts=[Part(root=TextPart(text=text))],
),
),
final=False,
)
]

self.mock_event_converter.side_effect = event_converter

async def mock_run_async(**kwargs):
for ev in adk_events:
yield ev

self.mock_runner.run_async = mock_run_async

await self.executor.execute(self.mock_context, self.mock_event_queue)

enqueued = [
c[0][0] for c in self.mock_event_queue.enqueue_event.call_args_list
]
artifacts = [
e for e in enqueued if isinstance(e, TaskArtifactUpdateEvent)
]
assert len(artifacts) == 1
assert artifacts[0].artifact.parts[0].root.text == "Hello world!"

finals = [
e for e in enqueued
if isinstance(e, TaskStatusUpdateEvent) and e.final
]
assert len(finals) == 1
assert finals[0].status.state == TaskState.completed

@pytest.mark.asyncio
async def test_metadata_propagated_to_synthesized_artifact(self):
"""Message metadata should be carried into the synthesized Artifact."""
from a2a.types import TaskArtifactUpdateEvent

self.mock_context.task_id = "task-mp"
self.mock_context.context_id = "ctx-mp"
self.mock_context.current_task = Mock()

self.mock_request_converter.return_value = AgentRunRequest(
user_id="u",
session_id="s",
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)
mock_session = Mock()
mock_session.id = "s"
self.mock_runner.session_service.get_session = AsyncMock(
return_value=mock_session
)
self.mock_runner._new_invocation_context.return_value = Mock()

msg_with_meta = Message(
message_id="m1",
role=Role.agent,
parts=[Part(root=TextPart(text="result"))],
metadata={"source": "agent-x", "confidence": "0.95"},
)
a2a_event = TaskStatusUpdateEvent(
task_id="task-mp",
context_id="ctx-mp",
status=TaskStatus(state=TaskState.working, message=msg_with_meta),
final=False,
)
self.mock_event_converter.return_value = [a2a_event]

adk_event = Mock(spec=Event)

async def mock_run_async(**kwargs):
yield adk_event

self.mock_runner.run_async = mock_run_async

await self.executor.execute(self.mock_context, self.mock_event_queue)

enqueued = [
c[0][0] for c in self.mock_event_queue.enqueue_event.call_args_list
]
artifacts = [
e for e in enqueued if isinstance(e, TaskArtifactUpdateEvent)
]
assert len(artifacts) == 1
assert artifacts[0].artifact.metadata == {
"source": "agent-x",
"confidence": "0.95",
}
Loading