Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 110 additions & 9 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import httpx

from packaging.version import InvalidVersion, Version

from a2a.client.base_client import BaseClient
from a2a.client.card_resolver import A2ACardResolver
from a2a.client.client import Client, ClientConfig, Consumer
Expand All @@ -21,6 +23,8 @@
AgentInterface,
)
from a2a.utils.constants import (
PROTOCOL_VERSION_0_3,
PROTOCOL_VERSION_1_0,
PROTOCOL_VERSION_CURRENT,
VERSION_HEADER,
TransportProtocol,
Expand All @@ -33,6 +37,12 @@
GrpcTransport = None # type: ignore # pyright: ignore


try:
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
except ImportError:
CompatGrpcTransport = None # type: ignore # pyright: ignore


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -109,10 +119,102 @@ def _register_defaults(self, supported: list[str]) -> None:
'To use GrpcClient, its dependencies must be installed. '
'You can install them with \'pip install "a2a-sdk[grpc]"\''
)

def grpc_transport_producer(
card: AgentCard,
url: str,
config: ClientConfig,
interceptors: list[ClientCallInterceptor],
) -> ClientTransport:
# The interface has already been selected and passed as `url`.
# We determine its version to use the appropriate transport implementation.
interface = ClientFactory._find_best_interface(
list(card.supported_interfaces),
protocol_bindings=[TransportProtocol.GRPC],
url=url,
)
version = (
interface.protocol_version
if interface
else PROTOCOL_VERSION_CURRENT
)

compat_transport = CompatGrpcTransport
if version and compat_transport is not None:
try:
v = Version(version)
if (
Version(PROTOCOL_VERSION_0_3)
<= v
< Version(PROTOCOL_VERSION_1_0)
):
return compat_transport.create(
card, url, config, interceptors
)
except InvalidVersion:
pass

grpc_transport = GrpcTransport
if grpc_transport is not None:
return grpc_transport.create(
card, url, config, interceptors
)

raise ImportError(
'GrpcTransport is not available. '
'You can install it with \'pip install "a2a-sdk[grpc]"\''
)

self.register(
TransportProtocol.GRPC,
GrpcTransport.create,
grpc_transport_producer,
)

@staticmethod
def _find_best_interface(
interfaces: list[AgentInterface],
protocol_bindings: list[str] | None = None,
url: str | None = None,
) -> AgentInterface | None:
"""Finds the best interface based on protocol version priorities."""
candidates = [
i
for i in interfaces
if (
protocol_bindings is None
or i.protocol_binding in protocol_bindings
)
and (url is None or i.url == url)
]

if not candidates:
return None

# Prefer interface with version 1.0
for i in candidates:
if i.protocol_version == PROTOCOL_VERSION_1_0:
return i

best_gt_1_0 = None
best_ge_0_3 = None
best_no_version = None

for i in candidates:
if not i.protocol_version:
if best_no_version is None:
best_no_version = i
continue

try:
v = Version(i.protocol_version)
if best_gt_1_0 is None and v > Version(PROTOCOL_VERSION_1_0):
best_gt_1_0 = i
if best_ge_0_3 is None and v >= Version(PROTOCOL_VERSION_0_3):
best_ge_0_3 = i
except InvalidVersion:
pass

return best_gt_1_0 or best_ge_0_3 or best_no_version

@classmethod
async def connect( # noqa: PLR0913
Expand Down Expand Up @@ -220,13 +322,9 @@ def create(
selected_interface = None
if self._config.use_client_preference:
for protocol_binding in client_set:
selected_interface = next(
(
si
for si in card.supported_interfaces
if si.protocol_binding == protocol_binding
),
None,
selected_interface = ClientFactory._find_best_interface(
list(card.supported_interfaces),
protocol_bindings=[protocol_binding],
)
if selected_interface:
transport_protocol = protocol_binding
Expand All @@ -235,7 +333,10 @@ def create(
for supported_interface in card.supported_interfaces:
if supported_interface.protocol_binding in client_set:
transport_protocol = supported_interface.protocol_binding
selected_interface = supported_interface
selected_interface = ClientFactory._find_best_interface(
list(card.supported_interfaces),
protocol_bindings=[transport_protocol],
)
break
if not transport_protocol or not selected_interface:
raise ValueError('no compatible transports found.')
Expand Down
Loading
Loading