77
88import httpx
99
10+ from packaging .version import InvalidVersion , Version
11+
1012from a2a .client .base_client import BaseClient
1113from a2a .client .card_resolver import A2ACardResolver
1214from a2a .client .client import Client , ClientConfig , Consumer
2123 AgentInterface ,
2224)
2325from a2a .utils .constants import (
26+ PROTOCOL_VERSION_0_3 ,
27+ PROTOCOL_VERSION_1_0 ,
2428 PROTOCOL_VERSION_CURRENT ,
2529 VERSION_HEADER ,
2630 TransportProtocol ,
3337 GrpcTransport = None # type: ignore # pyright: ignore
3438
3539
40+ try :
41+ from a2a .compat .v0_3 .grpc_transport import CompatGrpcTransport
42+ except ImportError :
43+ CompatGrpcTransport = None # type: ignore # pyright: ignore
44+
45+
3646logger = logging .getLogger (__name__ )
3747
3848
@@ -109,10 +119,102 @@ def _register_defaults(self, supported: list[str]) -> None:
109119 'To use GrpcClient, its dependencies must be installed. '
110120 'You can install them with \' pip install "a2a-sdk[grpc]"\' '
111121 )
122+
123+ def grpc_transport_producer (
124+ card : AgentCard ,
125+ url : str ,
126+ config : ClientConfig ,
127+ interceptors : list [ClientCallInterceptor ],
128+ ) -> ClientTransport :
129+ # The interface has already been selected and passed as `url`.
130+ # We determine its version to use the appropriate transport implementation.
131+ interface = ClientFactory ._find_best_interface (
132+ list (card .supported_interfaces ),
133+ protocol_bindings = [TransportProtocol .GRPC ],
134+ url = url ,
135+ )
136+ version = (
137+ interface .protocol_version
138+ if interface
139+ else PROTOCOL_VERSION_CURRENT
140+ )
141+
142+ compat_transport = CompatGrpcTransport
143+ if version and compat_transport is not None :
144+ try :
145+ v = Version (version )
146+ if (
147+ Version (PROTOCOL_VERSION_0_3 )
148+ <= v
149+ < Version (PROTOCOL_VERSION_1_0 )
150+ ):
151+ return compat_transport .create (
152+ card , url , config , interceptors
153+ )
154+ except InvalidVersion :
155+ pass
156+
157+ grpc_transport = GrpcTransport
158+ if grpc_transport is not None :
159+ return grpc_transport .create (
160+ card , url , config , interceptors
161+ )
162+
163+ raise ImportError (
164+ 'GrpcTransport is not available. '
165+ 'You can install it with \' pip install "a2a-sdk[grpc]"\' '
166+ )
167+
112168 self .register (
113169 TransportProtocol .GRPC ,
114- GrpcTransport .create ,
170+ grpc_transport_producer ,
171+ )
172+
173+ @staticmethod
174+ def _find_best_interface (
175+ interfaces : list [AgentInterface ],
176+ protocol_bindings : list [str ] | None = None ,
177+ url : str | None = None ,
178+ ) -> AgentInterface | None :
179+ """Finds the best interface based on protocol version priorities."""
180+ candidates = [
181+ i
182+ for i in interfaces
183+ if (
184+ protocol_bindings is None
185+ or i .protocol_binding in protocol_bindings
115186 )
187+ and (url is None or i .url == url )
188+ ]
189+
190+ if not candidates :
191+ return None
192+
193+ # Prefer interface with version 1.0
194+ for i in candidates :
195+ if i .protocol_version == PROTOCOL_VERSION_1_0 :
196+ return i
197+
198+ best_gt_1_0 = None
199+ best_ge_0_3 = None
200+ best_no_version = None
201+
202+ for i in candidates :
203+ if not i .protocol_version :
204+ if best_no_version is None :
205+ best_no_version = i
206+ continue
207+
208+ try :
209+ v = Version (i .protocol_version )
210+ if best_gt_1_0 is None and v > Version (PROTOCOL_VERSION_1_0 ):
211+ best_gt_1_0 = i
212+ if best_ge_0_3 is None and v >= Version (PROTOCOL_VERSION_0_3 ):
213+ best_ge_0_3 = i
214+ except InvalidVersion :
215+ pass
216+
217+ return best_gt_1_0 or best_ge_0_3 or best_no_version
116218
117219 @classmethod
118220 async def connect ( # noqa: PLR0913
@@ -220,13 +322,9 @@ def create(
220322 selected_interface = None
221323 if self ._config .use_client_preference :
222324 for protocol_binding in client_set :
223- selected_interface = next (
224- (
225- si
226- for si in card .supported_interfaces
227- if si .protocol_binding == protocol_binding
228- ),
229- None ,
325+ selected_interface = ClientFactory ._find_best_interface (
326+ list (card .supported_interfaces ),
327+ protocol_bindings = [protocol_binding ],
230328 )
231329 if selected_interface :
232330 transport_protocol = protocol_binding
@@ -235,7 +333,10 @@ def create(
235333 for supported_interface in card .supported_interfaces :
236334 if supported_interface .protocol_binding in client_set :
237335 transport_protocol = supported_interface .protocol_binding
238- selected_interface = supported_interface
336+ selected_interface = ClientFactory ._find_best_interface (
337+ list (card .supported_interfaces ),
338+ protocol_bindings = [transport_protocol ],
339+ )
239340 break
240341 if not transport_protocol or not selected_interface :
241342 raise ValueError ('no compatible transports found.' )
0 commit comments