Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
0479790
feat: Add utilities for managing service parameters and A2A extensions.
guglielmo-san Mar 6, 2026
9e5761a
wip
guglielmo-san Mar 6, 2026
abf24a1
feat: Add utilities for managing service parameters and A2A extensions.
guglielmo-san Mar 6, 2026
506598c
wip
guglielmo-san Mar 6, 2026
a86c2f4
Merge branch 'guglielmoc/refactor_base_client' of https://github.com/…
guglielmo-san Mar 6, 2026
b0d41b9
wip refactoring
guglielmo-san Mar 6, 2026
eae38e9
fix tests
guglielmo-san Mar 6, 2026
4c23416
refactor: use `ClientCallContext` for HTTP arguments in stream reques…
guglielmo-san Mar 7, 2026
b0f2033
Refactor transport request methods to use explicit `json` and `params…
guglielmo-san Mar 7, 2026
11eecb9
refactor: Extract common HTTP argument parsing logic into a shared he…
guglielmo-san Mar 8, 2026
6186a9e
refactor: qualify ParseDict call with json_format module
guglielmo-san Mar 9, 2026
729f8b4
Merge branch '1.0-dev' into guglielmoc/refactor_base_client
guglielmo-san Mar 9, 2026
3eced82
refactor: Migrate gRPC metadata handling from direct extensions param…
guglielmo-san Mar 9, 2026
816c512
refactor: Remove `extensions` handling from `grpc_transport` by utili…
guglielmo-san Mar 9, 2026
17302b2
style: remove trailing comma from agent_card type hint in `CompatGrpc…
guglielmo-san Mar 9, 2026
36e818a
refactor: remove query parameter conversion utilities from REST trans…
guglielmo-san Mar 9, 2026
5b3d711
refactor: reformat `CompatGrpcTransport` constructor parameters for c…
guglielmo-san Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 39 additions & 71 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from collections.abc import AsyncGenerator, AsyncIterator, Callable
from typing import Any

from a2a.client.client import (
Client,
ClientCallContext,
ClientConfig,
ClientEvent,
Consumer,
)
from a2a.client.client_task_manager import ClientTaskManager
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.types.a2a_pb2 import (
AgentCard,
Expand All @@ -23,8 +21,6 @@
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
Message,
SendMessageConfiguration,
SendMessageRequest,
StreamResponse,
SubscribeToTaskRequest,
Expand All @@ -51,12 +47,9 @@ def __init__(

async def send_message(
self,
request: Message,
request: SendMessageRequest,
*,
configuration: SendMessageConfiguration | None = None,
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Sends a message to the agent.

Expand All @@ -66,35 +59,15 @@ async def send_message(

Args:
request: The message to send to the agent.
configuration: Optional per-call overrides for message sending behavior.
context: The client call context.
request_metadata: Extensions Metadata attached to the request.
extensions: List of extensions to be activated.
context: Optional client call context.

Yields:
An async iterator of `ClientEvent`
"""
config = SendMessageConfiguration(
accepted_output_modes=self._config.accepted_output_modes,
blocking=not self._config.polling,
push_notification_config=(
self._config.push_notification_configs[0]
if self._config.push_notification_configs
else None
),
)

if configuration:
config.MergeFrom(configuration)
config.blocking = configuration.blocking

send_message_request = SendMessageRequest(
message=request, configuration=config, metadata=request_metadata
)

self._apply_client_config(request)
if not self._config.streaming or not self._card.capabilities.streaming:
response = await self._transport.send_message(
send_message_request, context=context, extensions=extensions
request, context=context
)

# In non-streaming case we convert to a StreamResponse so that the
Expand All @@ -116,11 +89,29 @@ async def send_message(
return

stream = self._transport.send_message_streaming(
send_message_request, context=context, extensions=extensions
request, context=context
)
async for client_event in self._process_stream(stream):
yield client_event

def _apply_client_config(self, request: SendMessageRequest) -> None:
if not request.configuration.blocking and self._config.polling:
request.configuration.blocking = not self._config.polling
if (
not request.configuration.HasField('push_notification_config')
and self._config.push_notification_configs
):
request.configuration.push_notification_config.CopyFrom(
self._config.push_notification_configs[0]
)
if (
not request.configuration.accepted_output_modes
and self._config.accepted_output_modes
):
request.configuration.accepted_output_modes.extend(
self._config.accepted_output_modes
)

async def _process_stream(
self, stream: AsyncIterator[StreamResponse]
) -> AsyncGenerator[ClientEvent]:
Expand All @@ -147,21 +138,17 @@ async def get_task(
request: GetTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task.

Args:
request: The `GetTaskRequest` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `Task` object representing the current state of the task.
"""
return await self._transport.get_task(
request, context=context, extensions=extensions
)
return await self._transport.get_task(request, context=context)

async def list_tasks(
self,
Expand All @@ -177,118 +164,104 @@ async def cancel_task(
request: CancelTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task.

Args:
request: The `CancelTaskRequest` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `Task` object containing the updated task status.
"""
return await self._transport.cancel_task(
request, context=context, extensions=extensions
)
return await self._transport.cancel_task(request, context=context)

async def create_task_push_notification_config(
self,
request: CreateTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task.

Args:
request: The `TaskPushNotificationConfig` object with the new configuration.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
The created or updated `TaskPushNotificationConfig` object.
"""
return await self._transport.create_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def get_task_push_notification_config(
self,
request: GetTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task.

Args:
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `TaskPushNotificationConfig` object containing the configuration.
"""
return await self._transport.get_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def list_task_push_notification_configs(
self,
request: ListTaskPushNotificationConfigsRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task.

Args:
request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `ListTaskPushNotificationConfigsResponse` object.
"""
return await self._transport.list_task_push_notification_configs(
request, context=context, extensions=extensions
request, context=context
)

async def delete_task_push_notification_config(
self,
request: DeleteTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> None:
"""Deletes the push notification configuration for a specific task.

Args:
request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.
"""
await self._transport.delete_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def subscribe(
self,
request: SubscribeToTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream.

This is only available if both the client and server support streaming.

Args:
request: Parameters to identify the task to resubscribe to.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Yields:
An async iterator of `ClientEvent` objects.
Expand All @@ -304,9 +277,7 @@ async def subscribe(
# Note: resubscribe can only be called on an existing task. As such,
# we should never see Message updates, despite the typing of the service
# definition indicating it may be possible.
stream = self._transport.subscribe(
request, context=context, extensions=extensions
)
stream = self._transport.subscribe(request, context=context)
async for client_event in self._process_stream(stream):
yield client_event

Expand All @@ -315,7 +286,6 @@ async def get_extended_agent_card(
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card.
Expand All @@ -325,8 +295,7 @@ async def get_extended_agent_card(

Args:
request: The `GetExtendedAgentCardRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.
signature_verifier: A callable used to verify the agent card's signatures.

Returns:
Expand All @@ -335,7 +304,6 @@ async def get_extended_agent_card(
card = await self._transport.get_extended_agent_card(
request,
context=context,
extensions=extensions,
signature_verifier=signature_verifier,
)
self._card = card
Expand Down
19 changes: 2 additions & 17 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
Message,
PushNotificationConfig,
SendMessageConfiguration,
SendMessageRequest,
StreamResponse,
SubscribeToTaskRequest,
Task,
Expand Down Expand Up @@ -77,9 +76,6 @@ class ClientConfig:
)
"""Push notification configurations to use for every request."""

extensions: list[str] = dataclasses.field(default_factory=list)
"""A list of extension URIs the client supports."""


ClientEvent = tuple[StreamResponse, Task | None]

Expand Down Expand Up @@ -130,12 +126,9 @@ async def __aexit__(
@abstractmethod
async def send_message(
self,
request: Message,
request: SendMessageRequest,
*,
configuration: SendMessageConfiguration | None = None,
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Sends a message to the server.

Expand All @@ -154,7 +147,6 @@ async def get_task(
request: GetTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""

Expand All @@ -173,7 +165,6 @@ async def cancel_task(
request: CancelTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""

Expand All @@ -183,7 +174,6 @@ async def create_task_push_notification_config(
request: CreateTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""

Expand All @@ -193,7 +183,6 @@ async def get_task_push_notification_config(
request: GetTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""

Expand All @@ -203,7 +192,6 @@ async def list_task_push_notification_configs(
request: ListTaskPushNotificationConfigsRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task."""

Expand All @@ -213,7 +201,6 @@ async def delete_task_push_notification_config(
request: DeleteTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> None:
"""Deletes the push notification configuration for a specific task."""

Expand All @@ -223,7 +210,6 @@ async def subscribe(
request: SubscribeToTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream."""
return
Expand All @@ -235,7 +221,6 @@ async def get_extended_agent_card(
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
Expand Down
Loading
Loading