diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 9f0ca69163..57cdf8f72e 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -373,7 +373,9 @@ async def _map_a2a_stream( update = AgentResponseUpdate( contents=contents, role="assistant" if item.role == A2ARole.agent else "user", - response_id=str(getattr(item, "message_id", uuid.uuid4())), + response_id=str(getattr(item, "task_id", None) or uuid.uuid4()), + message_id=str(getattr(item, "message_id", uuid.uuid4())), + additional_properties=item.metadata, raw_representation=item, ) all_updates.append(update) @@ -508,6 +510,7 @@ def _updates_from_task_update_event( role="assistant", response_id=update_event.task_id, message_id=update_event.artifact.artifact_id, + additional_properties=update_event.artifact.metadata, raw_representation=update_event, ) ] @@ -528,6 +531,7 @@ def _updates_from_task_update_event( contents=contents, role="assistant" if message.role == A2ARole.agent else "user", response_id=update_event.task_id, + additional_properties=message.metadata, raw_representation=update_event, ) ] diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 442960a7ee..afd6e87e11 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -47,7 +47,15 @@ def __init__(self) -> None: self.resubscribe_responses: list[Any] = [] self.get_task_response: Task | None = None - def add_message_response(self, message_id: str, text: str, role: str = "agent") -> None: + def add_message_response( + self, + message_id: str, + text: str, + role: str = "agent", + *, + task_id: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: """Add a mock Message response.""" # Create actual TextPart instance and wrap it in Part @@ -55,7 +63,11 @@ def add_message_response(self, message_id: str, text: str, role: str = "agent") # Create actual Message instance message = A2AMessage( - message_id=message_id, role=A2ARole.agent if role == "agent" else A2ARole.user, parts=[text_part] + message_id=message_id, + role=A2ARole.agent if role == "agent" else A2ARole.user, + parts=[text_part], + task_id=task_id, + metadata=metadata, ) self.responses.append(message) @@ -216,7 +228,7 @@ def test_a2a_agent_initialization_without_client_raises_error() -> None: async def test_run_with_message_response(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: """Test run() method with immediate Message response.""" - mock_a2a_client.add_message_response("msg-123", "Hello from agent!", "agent") + mock_a2a_client.add_message_response("msg-123", "Hello from agent!", "agent", task_id="task-100") response = await a2a_agent.run("Hello agent") @@ -224,7 +236,7 @@ async def test_run_with_message_response(a2a_agent: A2AAgent, mock_a2a_client: M assert len(response.messages) == 1 assert response.messages[0].role == "assistant" assert response.messages[0].text == "Hello from agent!" - assert response.response_id == "msg-123" + assert response.response_id == "task-100" assert mock_a2a_client.call_count == 1 @@ -443,7 +455,9 @@ def test_prepare_message_for_a2a_empty_contents_raises_error(a2a_agent: A2AAgent async def test_run_streaming_with_message_response(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: """Test run(stream=True) method with immediate Message response.""" - mock_a2a_client.add_message_response("msg-stream-123", "Streaming response from agent!", "agent") + mock_a2a_client.add_message_response( + "msg-stream-123", "Streaming response from agent!", "agent", task_id="task-stream-100" + ) # Collect streaming updates updates: list[AgentResponseUpdate] = [] @@ -460,7 +474,8 @@ async def test_run_streaming_with_message_response(a2a_agent: A2AAgent, mock_a2a assert content.type == "text" assert content.text == "Streaming response from agent!" - assert updates[0].response_id == "msg-stream-123" + assert updates[0].response_id == "task-stream-100" + assert updates[0].message_id == "msg-stream-123" assert mock_a2a_client.call_count == 1 @@ -1385,3 +1400,119 @@ async def test_streaming_terminal_task_only_emits_unstreamed_artifacts( # endregion + + +# region response_id consistency and metadata forwarding (#5263, #5240) + + +async def test_message_response_id_uses_task_id(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: + """Test that response_id is derived from task_id, not message_id (#5263).""" + mock_a2a_client.add_message_response("msg-abc", "Hello", task_id="task-xyz") + + response = await a2a_agent.run("Hi") + + assert response.response_id == "task-xyz" + + +async def test_message_response_id_fallback_without_task_id( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that response_id falls back to a UUID when task_id is absent (#5263).""" + mock_a2a_client.add_message_response("msg-no-task", "Hello") + + response = await a2a_agent.run("Hi") + + # Should be a valid UUID string, not message_id + assert response.response_id != "msg-no-task" + from uuid import UUID + + UUID(response.response_id) # raises ValueError if not a valid UUID + + +async def test_message_metadata_forwarded_as_additional_properties( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that A2AMessage metadata is forwarded as additional_properties (#5240).""" + mock_a2a_client.add_message_response( + "msg-meta", + "Hello", + task_id="task-meta", + metadata={"custom_key": "custom_value", "priority": "high"}, + ) + + updates: list[AgentResponseUpdate] = [] + async for update in a2a_agent.run("Hi", stream=True): + updates.append(update) + + assert len(updates) == 1 + assert updates[0].additional_properties == {"custom_key": "custom_value", "priority": "high"} + + +async def test_artifact_update_event_metadata_forwarded( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that TaskArtifactUpdateEvent metadata is forwarded (#5240).""" + task = Task(id="task-art-meta", context_id="ctx", status=TaskStatus(state=TaskState.working, message=None)) + artifact = Artifact( + artifact_id="artifact-meta", + parts=[Part(root=TextPart(text="Content"))], + metadata={"source": "tool", "version": "2"}, + ) + update_event = TaskArtifactUpdateEvent( + task_id="task-art-meta", context_id="ctx", artifact=artifact, append=False + ) + mock_a2a_client.responses.append((task, update_event)) + + updates: list[AgentResponseUpdate] = [] + async for update in a2a_agent.run("Hello", stream=True): + updates.append(update) + + assert len(updates) == 1 + assert updates[0].additional_properties == {"source": "tool", "version": "2"} + + +async def test_status_update_event_metadata_forwarded( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that TaskStatusUpdateEvent message metadata is forwarded (#5240).""" + update_event = TaskStatusUpdateEvent( + task_id="task-status-meta", + context_id="ctx", + status=TaskStatus( + state=TaskState.working, + message=A2AMessage( + message_id=str(uuid4()), + role=A2ARole.agent, + parts=[Part(root=TextPart(text="Processing"))], + metadata={"step": "3", "progress": "75%"}, + ), + ), + final=False, + ) + task = Task(id="task-status-meta", context_id="ctx", status=TaskStatus(state=TaskState.working, message=None)) + mock_a2a_client.responses.append((task, update_event)) + + updates: list[AgentResponseUpdate] = [] + async for update in a2a_agent.run("Hello", stream=True): + updates.append(update) + + assert len(updates) == 1 + assert updates[0].additional_properties == {"step": "3", "progress": "75%"} + + +async def test_message_id_preserved_separately_from_response_id( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that message_id is set separately from response_id (#5263).""" + mock_a2a_client.add_message_response("msg-unique", "Hello", task_id="task-parent") + + updates: list[AgentResponseUpdate] = [] + async for update in a2a_agent.run("Hi", stream=True): + updates.append(update) + + assert len(updates) == 1 + assert updates[0].response_id == "task-parent" + assert updates[0].message_id == "msg-unique" + + +# endregion