diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 30006568..6a67d19e 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -7,6 +7,8 @@ import httpx +from packaging.version import InvalidVersion, Version + from a2a.client.base_client import BaseClient from a2a.client.card_resolver import A2ACardResolver from a2a.client.client import Client, ClientConfig, Consumer @@ -21,6 +23,8 @@ AgentInterface, ) from a2a.utils.constants import ( + PROTOCOL_VERSION_0_3, + PROTOCOL_VERSION_1_0, PROTOCOL_VERSION_CURRENT, VERSION_HEADER, TransportProtocol, @@ -33,6 +37,12 @@ GrpcTransport = None # type: ignore # pyright: ignore +try: + from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport +except ImportError: + CompatGrpcTransport = None # type: ignore # pyright: ignore + + logger = logging.getLogger(__name__) @@ -109,10 +119,102 @@ def _register_defaults(self, supported: list[str]) -> None: 'To use GrpcClient, its dependencies must be installed. ' 'You can install them with \'pip install "a2a-sdk[grpc]"\'' ) + + def grpc_transport_producer( + card: AgentCard, + url: str, + config: ClientConfig, + interceptors: list[ClientCallInterceptor], + ) -> ClientTransport: + # The interface has already been selected and passed as `url`. + # We determine its version to use the appropriate transport implementation. + interface = ClientFactory._find_best_interface( + list(card.supported_interfaces), + protocol_bindings=[TransportProtocol.GRPC], + url=url, + ) + version = ( + interface.protocol_version + if interface + else PROTOCOL_VERSION_CURRENT + ) + + compat_transport = CompatGrpcTransport + if version and compat_transport is not None: + try: + v = Version(version) + if ( + Version(PROTOCOL_VERSION_0_3) + <= v + < Version(PROTOCOL_VERSION_1_0) + ): + return compat_transport.create( + card, url, config, interceptors + ) + except InvalidVersion: + pass + + grpc_transport = GrpcTransport + if grpc_transport is not None: + return grpc_transport.create( + card, url, config, interceptors + ) + + raise ImportError( + 'GrpcTransport is not available. ' + 'You can install it with \'pip install "a2a-sdk[grpc]"\'' + ) + self.register( TransportProtocol.GRPC, - GrpcTransport.create, + grpc_transport_producer, + ) + + @staticmethod + def _find_best_interface( + interfaces: list[AgentInterface], + protocol_bindings: list[str] | None = None, + url: str | None = None, + ) -> AgentInterface | None: + """Finds the best interface based on protocol version priorities.""" + candidates = [ + i + for i in interfaces + if ( + protocol_bindings is None + or i.protocol_binding in protocol_bindings ) + and (url is None or i.url == url) + ] + + if not candidates: + return None + + # Prefer interface with version 1.0 + for i in candidates: + if i.protocol_version == PROTOCOL_VERSION_1_0: + return i + + best_gt_1_0 = None + best_ge_0_3 = None + best_no_version = None + + for i in candidates: + if not i.protocol_version: + if best_no_version is None: + best_no_version = i + continue + + try: + v = Version(i.protocol_version) + if best_gt_1_0 is None and v > Version(PROTOCOL_VERSION_1_0): + best_gt_1_0 = i + if best_ge_0_3 is None and v >= Version(PROTOCOL_VERSION_0_3): + best_ge_0_3 = i + except InvalidVersion: + pass + + return best_gt_1_0 or best_ge_0_3 or best_no_version @classmethod async def connect( # noqa: PLR0913 @@ -220,13 +322,9 @@ def create( selected_interface = None if self._config.use_client_preference: for protocol_binding in client_set: - selected_interface = next( - ( - si - for si in card.supported_interfaces - if si.protocol_binding == protocol_binding - ), - None, + selected_interface = ClientFactory._find_best_interface( + list(card.supported_interfaces), + protocol_bindings=[protocol_binding], ) if selected_interface: transport_protocol = protocol_binding @@ -235,7 +333,10 @@ def create( for supported_interface in card.supported_interfaces: if supported_interface.protocol_binding in client_set: transport_protocol = supported_interface.protocol_binding - selected_interface = supported_interface + selected_interface = ClientFactory._find_best_interface( + list(card.supported_interfaces), + protocol_bindings=[transport_protocol], + ) break if not transport_protocol or not selected_interface: raise ValueError('no compatible transports found.') diff --git a/src/a2a/compat/v0_3/grpc_transport.py b/src/a2a/compat/v0_3/grpc_transport.py new file mode 100644 index 00000000..b37a704b --- /dev/null +++ b/src/a2a/compat/v0_3/grpc_transport.py @@ -0,0 +1,394 @@ +import logging + +from collections.abc import AsyncGenerator, Callable +from functools import wraps +from typing import Any, NoReturn + +from a2a.client.errors import A2AClientError, A2AClientTimeoutError +from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP + + +try: + import grpc # type: ignore[reportMissingModuleSource] +except ImportError as e: + raise ImportError( + 'A2AGrpcClient requires grpcio and grpcio-tools to be installed. ' + 'Install with: ' + "'pip install a2a-sdk[grpc]'" + ) from e + + +from a2a.client.client import ClientConfig +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.optionals import Channel +from a2a.client.transports.base import ClientTransport +from a2a.compat.v0_3 import ( + a2a_v0_3_pb2, + a2a_v0_3_pb2_grpc, + conversions, + proto_utils, +) +from a2a.compat.v0_3 import ( + types as types_v03, +) +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.types import a2a_pb2 +from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + +_A2A_ERROR_NAME_TO_CLS = { + error_type.__name__: error_type for error_type in JSON_RPC_ERROR_CODE_MAP +} + + +def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + raise A2AClientTimeoutError('Client Request timed out') from e + + details = e.details() + if isinstance(details, str) and ': ' in details: + error_type_name, error_message = details.split(': ', 1) + exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type_name) + if exception_cls: + raise exception_cls(error_message) from e + raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e + + +def _handle_grpc_exception(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return await func(*args, **kwargs) + except grpc.aio.AioRpcError as e: + _map_grpc_error(e) + + return wrapper + + +def _handle_grpc_stream_exception( + func: Callable[..., Any], +) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + async for item in func(*args, **kwargs): + yield item + except grpc.aio.AioRpcError as e: + _map_grpc_error(e) + + return wrapper + + +@trace_class(kind=SpanKind.CLIENT) +class CompatGrpcTransport(ClientTransport): + """A backward compatible gRPC transport for A2A v0.3.""" + + def __init__( + self, + channel: Channel, + agent_card: a2a_pb2.AgentCard | None, + extensions: list[str] | None = None, + ): + """Initializes the CompatGrpcTransport.""" + self.agent_card = agent_card + self.channel = channel + self.stub = a2a_v0_3_pb2_grpc.A2AServiceStub(channel) + self.extensions = extensions + + @classmethod + def create( + cls, + card: a2a_pb2.AgentCard, + url: str, + config: ClientConfig, + interceptors: list[ClientCallInterceptor], + ) -> 'CompatGrpcTransport': + """Creates a gRPC transport for the A2A client.""" + if config.grpc_channel_factory is None: + raise ValueError('grpc_channel_factory is required when using gRPC') + return cls(config.grpc_channel_factory(url), card, config.extensions) + + @_handle_grpc_exception + async def send_message( + self, + request: a2a_pb2.SendMessageRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> a2a_pb2.SendMessageResponse: + """Sends a non-streaming message request to the agent (v0.3).""" + req_v03 = conversions.to_compat_send_message_request( + request, request_id=0 + ) + req_proto = a2a_v0_3_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(req_v03.params.message), + configuration=proto_utils.ToProto.message_send_configuration( + req_v03.params.configuration + ), + metadata=proto_utils.ToProto.metadata(req_v03.params.metadata), + ) + + resp_proto = await self.stub.SendMessage( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + + which = resp_proto.WhichOneof('payload') + if which == 'task': + return a2a_pb2.SendMessageResponse( + task=conversions.to_core_task( + proto_utils.FromProto.task(resp_proto.task) + ) + ) + if which == 'message': + return a2a_pb2.SendMessageResponse( + message=conversions.to_core_message( + proto_utils.FromProto.message(resp_proto.message) + ) + ) + return a2a_pb2.SendMessageResponse() + + @_handle_grpc_stream_exception + async def send_message_streaming( + self, + request: a2a_pb2.SendMessageRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> AsyncGenerator[a2a_pb2.StreamResponse]: + """Sends a streaming message request to the agent (v0.3).""" + req_v03 = conversions.to_compat_send_message_request( + request, request_id=0 + ) + req_proto = a2a_v0_3_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(req_v03.params.message), + configuration=proto_utils.ToProto.message_send_configuration( + req_v03.params.configuration + ), + metadata=proto_utils.ToProto.metadata(req_v03.params.metadata), + ) + + stream = self.stub.SendStreamingMessage( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: # type: ignore[attr-defined] + break + yield conversions.to_core_stream_response( + types_v03.SendStreamingMessageSuccessResponse( + result=proto_utils.FromProto.stream_response(response) + ) + ) + + @_handle_grpc_stream_exception + async def subscribe( + self, + request: a2a_pb2.SubscribeToTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> AsyncGenerator[a2a_pb2.StreamResponse]: + """Reconnects to get task updates (v0.3).""" + req_proto = a2a_v0_3_pb2.TaskSubscriptionRequest( + name=f'tasks/{request.id}' + ) + + stream = self.stub.TaskSubscription( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: # type: ignore[attr-defined] + break + yield conversions.to_core_stream_response( + types_v03.SendStreamingMessageSuccessResponse( + result=proto_utils.FromProto.stream_response(response) + ) + ) + + @_handle_grpc_exception + async def get_task( + self, + request: a2a_pb2.GetTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> a2a_pb2.Task: + """Retrieves the current state and history of a specific task (v0.3).""" + req_proto = a2a_v0_3_pb2.GetTaskRequest( + name=f'tasks/{request.id}', + history_length=request.history_length, + ) + resp_proto = await self.stub.GetTask( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + return conversions.to_core_task(proto_utils.FromProto.task(resp_proto)) + + @_handle_grpc_exception + async def list_tasks( + self, + request: a2a_pb2.ListTasksRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> a2a_pb2.ListTasksResponse: + """Retrieves tasks for an agent (v0.3 - NOT SUPPORTED in v0.3).""" + # v0.3 proto doesn't have ListTasks. + raise NotImplementedError( + 'ListTasks is not supported in A2A v0.3 gRPC.' + ) + + @_handle_grpc_exception + async def cancel_task( + self, + request: a2a_pb2.CancelTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> a2a_pb2.Task: + """Requests the agent to cancel a specific task (v0.3).""" + req_proto = a2a_v0_3_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + resp_proto = await self.stub.CancelTask( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + return conversions.to_core_task(proto_utils.FromProto.task(resp_proto)) + + @_handle_grpc_exception + async def create_task_push_notification_config( + self, + request: a2a_pb2.CreateTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> a2a_pb2.TaskPushNotificationConfig: + """Sets or updates the push notification configuration (v0.3).""" + req_v03 = ( + conversions.to_compat_create_task_push_notification_config_request( + request, request_id=0 + ) + ) + req_proto = a2a_v0_3_pb2.CreateTaskPushNotificationConfigRequest( + parent=f'tasks/{request.task_id}', + config_id=req_v03.params.push_notification_config.id, + config=proto_utils.ToProto.task_push_notification_config( + req_v03.params + ), + ) + resp_proto = await self.stub.CreateTaskPushNotificationConfig( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + return conversions.to_core_task_push_notification_config( + proto_utils.FromProto.task_push_notification_config(resp_proto) + ) + + @_handle_grpc_exception + async def get_task_push_notification_config( + self, + request: a2a_pb2.GetTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> a2a_pb2.TaskPushNotificationConfig: + """Retrieves the push notification configuration (v0.3).""" + req_proto = a2a_v0_3_pb2.GetTaskPushNotificationConfigRequest( + name=f'tasks/{request.task_id}/pushNotificationConfigs/{request.id}' + ) + resp_proto = await self.stub.GetTaskPushNotificationConfig( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + return conversions.to_core_task_push_notification_config( + proto_utils.FromProto.task_push_notification_config(resp_proto) + ) + + @_handle_grpc_exception + async def list_task_push_notification_configs( + self, + request: a2a_pb2.ListTaskPushNotificationConfigsRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> a2a_pb2.ListTaskPushNotificationConfigsResponse: + """Lists push notification configurations for a specific task (v0.3).""" + req_proto = a2a_v0_3_pb2.ListTaskPushNotificationConfigRequest( + parent=f'tasks/{request.task_id}' + ) + resp_proto = await self.stub.ListTaskPushNotificationConfig( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + return conversions.to_core_list_task_push_notification_config_response( + proto_utils.FromProto.list_task_push_notification_config_response( + resp_proto + ) + ) + + @_handle_grpc_exception + async def delete_task_push_notification_config( + self, + request: a2a_pb2.DeleteTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> None: + """Deletes the push notification configuration (v0.3).""" + req_proto = a2a_v0_3_pb2.DeleteTaskPushNotificationConfigRequest( + name=f'tasks/{request.task_id}/pushNotificationConfigs/{request.id}' + ) + await self.stub.DeleteTaskPushNotificationConfig( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + + @_handle_grpc_exception + async def get_extended_agent_card( + self, + request: a2a_pb2.GetExtendedAgentCardRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + signature_verifier: Callable[[a2a_pb2.AgentCard], None] | None = None, + ) -> a2a_pb2.AgentCard: + """Retrieves the agent's card (v0.3).""" + req_proto = a2a_v0_3_pb2.GetAgentCardRequest() + resp_proto = await self.stub.GetAgentCard( + req_proto, + metadata=self._get_grpc_metadata(extensions), + ) + card = conversions.to_core_agent_card( + proto_utils.FromProto.agent_card(resp_proto) + ) + + if signature_verifier: + signature_verifier(card) + + self.agent_card = card + return card + + async def close(self) -> None: + """Closes the gRPC channel.""" + await self.channel.close() + + def _get_grpc_metadata( + self, + extensions: list[str] | None = None, + ) -> list[tuple[str, str]]: + """Creates gRPC metadata for extensions.""" + metadata = [(VERSION_HEADER.lower(), PROTOCOL_VERSION_0_3)] + + extensions_to_use = extensions or self.extensions + if extensions_to_use: + metadata.append( + (HTTP_EXTENSION_HEADER.lower(), ','.join(extensions_to_use)) + ) + + return metadata diff --git a/src/a2a/compat/v0_3/proto_utils.py b/src/a2a/compat/v0_3/proto_utils.py index 61fa76cd..d9c5688d 100644 --- a/src/a2a/compat/v0_3/proto_utils.py +++ b/src/a2a/compat/v0_3/proto_utils.py @@ -1062,6 +1062,20 @@ def stream_response( return cls.task_artifact_update_event(response.artifact_update) raise ValueError('Unsupported StreamResponse type') + @classmethod + def list_task_push_notification_config_response( + cls, response: a2a_pb2.ListTaskPushNotificationConfigResponse + ) -> types.ListTaskPushNotificationConfigResponse: + return types.ListTaskPushNotificationConfigResponse( + root=types.ListTaskPushNotificationConfigSuccessResponse( + result=[ + cls.task_push_notification_config(c) + for c in response.configs + ], + id=None, + ) + ) + @classmethod def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: return types.AgentSkill( diff --git a/src/a2a/utils/constants.py b/src/a2a/utils/constants.py index 65d6598f..6cee2a05 100644 --- a/src/a2a/utils/constants.py +++ b/src/a2a/utils/constants.py @@ -25,4 +25,5 @@ class TransportProtocol(str, Enum): VERSION_HEADER = 'A2A-Version' PROTOCOL_VERSION_1_0 = '1.0' +PROTOCOL_VERSION_0_3 = '0.3' PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_1_0 diff --git a/tests/client/test_client_factory_grpc.py b/tests/client/test_client_factory_grpc.py new file mode 100644 index 00000000..1e756324 --- /dev/null +++ b/tests/client/test_client_factory_grpc.py @@ -0,0 +1,175 @@ +"""Tests for GRPC transport selection in ClientFactory.""" + +from unittest.mock import MagicMock, patch +import pytest + +from a2a.client import ClientConfig, ClientFactory +from a2a.types.a2a_pb2 import AgentCard, AgentInterface, AgentCapabilities +from a2a.utils.constants import TransportProtocol + + +@pytest.fixture +def grpc_agent_card() -> AgentCard: + """Provides an AgentCard with GRPC interfaces for tests.""" + return AgentCard( + supported_interfaces=[], + capabilities=AgentCapabilities(), + skills=[], + default_input_modes=[], + default_output_modes=[], + name='GRPC Agent', + version='1.0.0', + description='Test agent', + ) + + +def test_grpc_priority_1_0(grpc_agent_card): + """Verify that protocol version 1.0 has the highest priority and uses GrpcTransport.""" + grpc_agent_card.supported_interfaces.extend( + [ + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url03', + protocol_version='0.3', + ), + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url11', + protocol_version='1.1', + ), + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url10', + protocol_version='1.0', + ), + ] + ) + + config = ClientConfig( + supported_protocol_bindings=[TransportProtocol.GRPC], + grpc_channel_factory=MagicMock(), + ) + + # We patch GrpcTransport and CompatGrpcTransport in the client_factory module + with ( + patch('a2a.client.client_factory.GrpcTransport') as mock_grpc, + patch('a2a.client.client_factory.CompatGrpcTransport') as mock_compat, + ): + factory = ClientFactory(config) + factory.create(grpc_agent_card) + + # Priority 1: 1.0 -> GrpcTransport + mock_grpc.create.assert_called_once_with( + grpc_agent_card, 'url10', config, [] + ) + mock_compat.create.assert_not_called() + + +def test_grpc_priority_gt_1_0(grpc_agent_card): + """Verify that protocol version > 1.0 uses GrpcTransport (first one found).""" + grpc_agent_card.supported_interfaces.extend( + [ + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url03', + protocol_version='0.3', + ), + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url11', + protocol_version='1.1', + ), + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url12', + protocol_version='1.2', + ), + ] + ) + + config = ClientConfig( + supported_protocol_bindings=[TransportProtocol.GRPC], + grpc_channel_factory=MagicMock(), + ) + + with ( + patch('a2a.client.client_factory.GrpcTransport') as mock_grpc, + patch('a2a.client.client_factory.CompatGrpcTransport') as mock_compat, + ): + factory = ClientFactory(config) + factory.create(grpc_agent_card) + + # Priority 2: > 1.0 -> GrpcTransport (first matching is 1.1) + mock_grpc.create.assert_called_once_with( + grpc_agent_card, 'url11', config, [] + ) + mock_compat.create.assert_not_called() + + +def test_grpc_priority_lt_0_3_raises_value_error(grpc_agent_card): + """Verify that if the only available interface has version < 0.3, it raises a ValueError.""" + grpc_agent_card.supported_interfaces.extend( + [ + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url02', + protocol_version='0.2', + ), + ] + ) + + config = ClientConfig( + supported_protocol_bindings=[TransportProtocol.GRPC], + grpc_channel_factory=MagicMock(), + ) + + factory = ClientFactory(config) + with pytest.raises(ValueError, match='no compatible transports found'): + factory.create(grpc_agent_card) + + +def test_grpc_invalid_version_raises_value_error(grpc_agent_card): + """Verify that if only an invalid version is available, it raises a ValueError (it's ignored).""" + grpc_agent_card.supported_interfaces.extend( + [ + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url_invalid', + protocol_version='invalid_version_string', + ), + ] + ) + + config = ClientConfig( + supported_protocol_bindings=[TransportProtocol.GRPC], + grpc_channel_factory=MagicMock(), + ) + + factory = ClientFactory(config) + with pytest.raises(ValueError, match='no compatible transports found'): + factory.create(grpc_agent_card) + + +def test_grpc_unspecified_version_uses_grpc_transport(grpc_agent_card): + """Verify that if no version is specified, it defaults to GrpcTransport.""" + grpc_agent_card.supported_interfaces.extend( + [ + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='url_no_version', + ), + ] + ) + + config = ClientConfig( + supported_protocol_bindings=[TransportProtocol.GRPC], + grpc_channel_factory=MagicMock(), + ) + + with patch('a2a.client.client_factory.GrpcTransport') as mock_grpc: + factory = ClientFactory(config) + factory.create(grpc_agent_card) + + mock_grpc.create.assert_called_once_with( + grpc_agent_card, 'url_no_version', config, [] + ) diff --git a/tests/integration/cross_version/client_server/client_1_0.py b/tests/integration/cross_version/client_server/client_1_0.py new file mode 100644 index 00000000..264b53c6 --- /dev/null +++ b/tests/integration/cross_version/client_server/client_1_0.py @@ -0,0 +1,188 @@ +import argparse +import asyncio +import grpc +import httpx +import sys +from uuid import uuid4 + +from a2a.client import ClientFactory, ClientConfig +from a2a.utils import TransportProtocol +from a2a.types import ( + Message, + Part, + Role, + GetTaskRequest, + CancelTaskRequest, + SubscribeToTaskRequest, + GetExtendedAgentCardRequest, +) + + +async def test_send_message_stream(client): + print('Testing send_message (streaming)...') + msg = Message( + role=Role.ROLE_USER, + message_id=f'stream-{uuid4()}', + parts=[Part(text='stream')], + metadata={'test_key': 'test_value'}, + ) + events = [] + + async for event in client.send_message(request=msg): + events.append(event) + break + + assert len(events) > 0, 'Expected at least one event' + first_event = events[0] + + # In v1.0 SDK, send_message returns tuple[StreamResponse, Task | None] + stream_response = first_event[0] + + # Try to find task_id in the oneof fields of StreamResponse + task_id = 'unknown' + if stream_response.HasField('task'): + task_id = stream_response.task.id + elif stream_response.HasField('message'): + task_id = stream_response.message.task_id + elif stream_response.HasField('status_update'): + task_id = stream_response.status_update.task_id + elif stream_response.HasField('artifact_update'): + task_id = stream_response.artifact_update.task_id + + print(f'Success: send_message (streaming) passed. Task ID: {task_id}') + return task_id + + +async def test_send_message_sync(url, protocol_enum): + print('Testing send_message (synchronous)...') + config = ClientConfig() + config.httpx_client = httpx.AsyncClient(timeout=30.0) + config.grpc_channel_factory = grpc.aio.insecure_channel + config.supported_protocol_bindings = [protocol_enum] + config.streaming = False + + client = await ClientFactory.connect(url, client_config=config) + msg = Message( + role=Role.ROLE_USER, + message_id=f'sync-{uuid4()}', + parts=[Part(text='sync')], + metadata={'test_key': 'test_value'}, + ) + + async for event in client.send_message(request=msg): + assert event is not None + stream_response = event[0] + + # In v1.0, check task status in StreamResponse + if stream_response.HasField('task'): + task = stream_response.task + if task.status.state == 3: # TASK_STATE_COMPLETED + metadata = dict(task.status.message.metadata) + assert metadata.get('response_key') == 'response_value', ( + f'Missing response metadata: {metadata}' + ) + elif stream_response.HasField('status_update'): + status_update = stream_response.status_update + if status_update.status.state == 3: # TASK_STATE_COMPLETED + metadata = dict(status_update.status.message.metadata) + assert metadata.get('response_key') == 'response_value', ( + f'Missing response metadata: {metadata}' + ) + break + + print(f'Success: send_message (synchronous) passed.') + + +async def test_get_task(client, task_id): + print(f'Testing get_task ({task_id})...') + task = await client.get_task(request=GetTaskRequest(id=task_id)) + assert task.id == task_id + print('Success: get_task passed.') + + +async def test_cancel_task(client, task_id): + print(f'Testing cancel_task ({task_id})...') + await client.cancel_task(request=CancelTaskRequest(id=task_id)) + print('Success: cancel_task passed.') + + +async def test_subscribe(client, task_id): + print(f'Testing subscribe ({task_id})...') + async for event in client.subscribe( + request=SubscribeToTaskRequest(id=task_id) + ): + print(f'Received event: {event}') + break + print('Success: subscribe passed.') + + +async def test_get_extended_agent_card(client): + print('Testing get_extended_agent_card...') + card = await client.get_extended_agent_card( + request=GetExtendedAgentCardRequest() + ) + assert card is not None + print(f'Success: get_extended_agent_card passed.') + + +async def run_client(url: str, protocol: str): + protocol_enum_map = { + 'jsonrpc': TransportProtocol.JSONRPC, + 'rest': TransportProtocol.HTTP_JSON, + 'grpc': TransportProtocol.GRPC, + } + protocol_enum = protocol_enum_map[protocol] + + config = ClientConfig() + config.httpx_client = httpx.AsyncClient(timeout=30.0) + config.grpc_channel_factory = grpc.aio.insecure_channel + config.supported_protocol_bindings = [protocol_enum] + config.streaming = True + + client = await ClientFactory.connect(url, client_config=config) + + # 1. Get Extended Agent Card + await test_get_extended_agent_card(client) + + # 2. Send Streaming Message + task_id = await test_send_message_stream(client) + + # 3. Get Task + await test_get_task(client, task_id) + + # 4. Subscribe to Task + await test_subscribe(client, task_id) + + # 5. Cancel Task + await test_cancel_task(client, task_id) + + # 6. Send Sync Message + await test_send_message_sync(url, protocol_enum) + + +def main(): + print('Starting client_1_0...') + + parser = argparse.ArgumentParser() + parser.add_argument('--url', type=str, required=True) + parser.add_argument('--protocols', type=str, nargs='+', required=True) + args = parser.parse_args() + + failed = False + for protocol in args.protocols: + print(f'\n=== Testing protocol: {protocol} ===') + try: + asyncio.run(run_client(args.url, protocol)) + except Exception as e: + import traceback + + traceback.print_exc() + print(f'FAILED protocol {protocol}: {e}') + failed = True + + if failed: + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/tests/integration/cross_version/client_server/test_client_server.py b/tests/integration/cross_version/client_server/test_client_server.py index e4a835c0..df6749a5 100644 --- a/tests/integration/cross_version/client_server/test_client_server.py +++ b/tests/integration/cross_version/client_server/test_client_server.py @@ -178,24 +178,40 @@ def running_servers(): @pytest.mark.timeout(10) @pytest.mark.parametrize( - 'server_script, client_script, client_deps', + 'server_script, client_script, client_deps, protocols', [ # Run 0.3 Server <-> 0.3 Client ( 'server_0_3.py', 'client_0_3.py', ['--with', 'a2a-sdk[grpc]==0.3.24', '--no-project'], + ['grpc', 'jsonrpc', 'rest'], ), # Run 1.0 Server <-> 0.3 Client ( 'server_1_0.py', 'client_0_3.py', ['--with', 'a2a-sdk[grpc]==0.3.24', '--no-project'], + ['grpc'], + ), + # Run 1.0 Server <-> 1.0 Client + ( + 'server_1_0.py', + 'client_1_0.py', + [], + ['grpc', 'jsonrpc', 'rest'], + ), + # Run 0.3 Server <-> 1.0 Client + ( + 'server_0_3.py', + 'client_1_0.py', + [], + ['grpc'], ), ], ) def test_cross_version( - running_servers, server_script, client_script, client_deps + running_servers, server_script, client_script, client_deps, protocols ): http_port = running_servers[server_script] uv_path = running_servers['uv_path'] @@ -210,8 +226,8 @@ def test_cross_version( '--url', card_url, '--protocols', - 'grpc', # "rest", "grpc" ] + + protocols ) client_result = subprocess.Popen(