Skip to content

Commit c99d3fe

Browse files
committed
feat(compat): GRPC client compatible with 0.3 servers.
1 parent 80d827a commit c99d3fe

7 files changed

Lines changed: 901 additions & 12 deletions

File tree

src/a2a/client/client_factory.py

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import httpx
99

10+
from packaging.version import InvalidVersion, Version
11+
1012
from a2a.client.base_client import BaseClient
1113
from a2a.client.card_resolver import A2ACardResolver
1214
from a2a.client.client import Client, ClientConfig, Consumer
@@ -21,6 +23,8 @@
2123
AgentInterface,
2224
)
2325
from a2a.utils.constants import (
26+
PROTOCOL_VERSION_0_3,
27+
PROTOCOL_VERSION_1_0,
2428
PROTOCOL_VERSION_CURRENT,
2529
VERSION_HEADER,
2630
TransportProtocol,
@@ -33,6 +37,12 @@
3337
GrpcTransport = None # type: ignore # pyright: ignore
3438

3539

40+
try:
41+
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
42+
except ImportError:
43+
CompatGrpcTransport = None # type: ignore # pyright: ignore
44+
45+
3646
logger = logging.getLogger(__name__)
3747

3848

@@ -109,10 +119,102 @@ def _register_defaults(self, supported: list[str]) -> None:
109119
'To use GrpcClient, its dependencies must be installed. '
110120
'You can install them with \'pip install "a2a-sdk[grpc]"\''
111121
)
122+
123+
def grpc_transport_producer(
124+
card: AgentCard,
125+
url: str,
126+
config: ClientConfig,
127+
interceptors: list[ClientCallInterceptor],
128+
) -> ClientTransport:
129+
# The interface has already been selected and passed as `url`.
130+
# We determine its version to use the appropriate transport implementation.
131+
interface = ClientFactory._find_best_interface(
132+
list(card.supported_interfaces),
133+
protocol_bindings=[TransportProtocol.GRPC],
134+
url=url,
135+
)
136+
version = (
137+
interface.protocol_version
138+
if interface
139+
else PROTOCOL_VERSION_CURRENT
140+
)
141+
142+
compat_transport = CompatGrpcTransport
143+
if version and compat_transport is not None:
144+
try:
145+
v = Version(version)
146+
if (
147+
Version(PROTOCOL_VERSION_0_3)
148+
<= v
149+
< Version(PROTOCOL_VERSION_1_0)
150+
):
151+
return compat_transport.create(
152+
card, url, config, interceptors
153+
)
154+
except InvalidVersion:
155+
pass
156+
157+
grpc_transport = GrpcTransport
158+
if grpc_transport is not None:
159+
return grpc_transport.create(
160+
card, url, config, interceptors
161+
)
162+
163+
raise ImportError(
164+
'GrpcTransport is not available. '
165+
'You can install it with \'pip install "a2a-sdk[grpc]"\''
166+
)
167+
112168
self.register(
113169
TransportProtocol.GRPC,
114-
GrpcTransport.create,
170+
grpc_transport_producer,
171+
)
172+
173+
@staticmethod
174+
def _find_best_interface(
175+
interfaces: list[AgentInterface],
176+
protocol_bindings: list[str] | None = None,
177+
url: str | None = None,
178+
) -> AgentInterface | None:
179+
"""Finds the best interface based on protocol version priorities."""
180+
candidates = [
181+
i
182+
for i in interfaces
183+
if (
184+
protocol_bindings is None
185+
or i.protocol_binding in protocol_bindings
115186
)
187+
and (url is None or i.url == url)
188+
]
189+
190+
if not candidates:
191+
return None
192+
193+
# Prefer interface with version 1.0
194+
for i in candidates:
195+
if i.protocol_version == PROTOCOL_VERSION_1_0:
196+
return i
197+
198+
best_gt_1_0 = None
199+
best_ge_0_3 = None
200+
best_no_version = None
201+
202+
for i in candidates:
203+
if not i.protocol_version:
204+
if best_no_version is None:
205+
best_no_version = i
206+
continue
207+
208+
try:
209+
v = Version(i.protocol_version)
210+
if best_gt_1_0 is None and v > Version(PROTOCOL_VERSION_1_0):
211+
best_gt_1_0 = i
212+
if best_ge_0_3 is None and v >= Version(PROTOCOL_VERSION_0_3):
213+
best_ge_0_3 = i
214+
except InvalidVersion:
215+
pass
216+
217+
return best_gt_1_0 or best_ge_0_3 or best_no_version
116218

117219
@classmethod
118220
async def connect( # noqa: PLR0913
@@ -220,13 +322,9 @@ def create(
220322
selected_interface = None
221323
if self._config.use_client_preference:
222324
for protocol_binding in client_set:
223-
selected_interface = next(
224-
(
225-
si
226-
for si in card.supported_interfaces
227-
if si.protocol_binding == protocol_binding
228-
),
229-
None,
325+
selected_interface = ClientFactory._find_best_interface(
326+
list(card.supported_interfaces),
327+
protocol_bindings=[protocol_binding],
230328
)
231329
if selected_interface:
232330
transport_protocol = protocol_binding
@@ -235,7 +333,10 @@ def create(
235333
for supported_interface in card.supported_interfaces:
236334
if supported_interface.protocol_binding in client_set:
237335
transport_protocol = supported_interface.protocol_binding
238-
selected_interface = supported_interface
336+
selected_interface = ClientFactory._find_best_interface(
337+
list(card.supported_interfaces),
338+
protocol_bindings=[transport_protocol],
339+
)
239340
break
240341
if not transport_protocol or not selected_interface:
241342
raise ValueError('no compatible transports found.')

0 commit comments

Comments
 (0)