diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b3a5cc0..1e20b1d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: - name: Install package and test deps run: | pip install -e . - pip install pytest pytest-cov pydantic-ai crewai + pip install pytest pytest-cov pydantic-ai crewai "a2a-sdk[http-server]>=0.3.0" - name: Run unit tests with coverage run: | diff --git a/.gitignore b/.gitignore index 90645e7..81f7c3a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ .Python env/ venv/ +.venv/ ENV/ env.bak/ venv.bak/ diff --git a/README.md b/README.md index c64f00b..c278078 100644 --- a/README.md +++ b/README.md @@ -325,6 +325,9 @@ export GRADIENT_MODEL_ACCESS_KEY=your_gradient_key # Optional: Enable verbose trace logging export GRADIENT_VERBOSE=1 + +# Optional: A2A protocol — base URL for AgentCard discovery +export A2A_BASE_URL=https://your-app.ondigitalocean.app ``` ## Project Structure @@ -350,6 +353,112 @@ The Gradient ADK is designed to work with any Python-based AI agent framework: - ✅ **CrewAI** - Use trace decorators for agent and task execution - ✅ **Custom Frameworks** - Use trace decorators for any function +## A2A Protocol Support + +The Gradient ADK supports the [Agent-to-Agent (A2A) protocol v0.3.0](https://github.com/google/A2A), enabling any `@entrypoint` agent to communicate with A2A-compatible clients. Install with `pip install gradient-adk[a2a]`. + +### Wrapping an Agent with A2A + +Any `@entrypoint` agent can be exposed as an A2A server with no code changes: + +```python +from gradient_adk import entrypoint +from gradient_adk.a2a import create_a2a_server + +@entrypoint +async def my_agent(data: dict, context) -> dict: + return {"output": f"You said: {data.get('prompt', '')}"} + +app = create_a2a_server(my_agent) +``` + +Run with `uvicorn my_module:app --host 0.0.0.0 --port 8000`. The agent is discoverable at `/.well-known/agent-card.json` and accepts JSON-RPC calls (`message/send`, `tasks/get`, `tasks/cancel`). + +### How the Protocol Works + +A2A uses a discover-then-call pattern over JSON-RPC. Here is the full client-server flow: + +1. **Discover** — The client fetches the AgentCard at `GET /.well-known/agent-card.json`. This returns the agent's name, transport URL, supported capabilities, and input/output modes. The client uses this to decide whether it can talk to this agent. + +2. **Send** — The client sends a message via `POST /` with JSON-RPC method `message/send`. The server validates the message (text-only in MVP), creates a task, executes the agent, and returns a `Task` object with a `taskId` and current status. + +3. **Poll** — The client checks task progress via `tasks/get` with the `taskId`. Once the task reaches a terminal state (`completed`, `failed`, or `canceled`), the response includes the agent's output in the task artifacts. The `historyLength` parameter controls how much conversation history is returned. + +4. **Cancel** (optional) — The client can request cancellation via `tasks/cancel`. This is best-effort and idempotent — if the agent already finished, the cancel is a no-op. + +``` +Client Server + │ │ + ├── GET /.well-known/agent-card.json ──► AgentCard (capabilities, URL) + │ │ + ├── POST / message/send ──────────────► Create task → Execute agent + │◄─────────────────── Task {id, status} │ + │ │ + ├── POST / tasks/get ─────────────────► Return task state + artifacts + │◄──────────── Task {id, status, result} │ + │ │ + └── POST / tasks/cancel ──────────────► Best-effort cancellation +``` + +### Deploying to DigitalOcean App Platform + +When you deploy to App Platform, the public URL is assigned after deployment. The A2A server needs this URL for the AgentCard so that clients know where to send requests. The workflow is: + +1. **Deploy your agent** to App Platform as usual with `gradient agent deploy` +2. **Get your app's public URL** from the App Platform dashboard (e.g., `https://your-agent-abc123.ondigitalocean.app`) +3. **Set the environment variable** in your app's settings: + ```bash + A2A_BASE_URL=https://your-agent-abc123.ondigitalocean.app + ``` +4. **Redeploy** — the agent restarts and the AgentCard now advertises the correct public URL + +For local development, no configuration is needed — it defaults to `http://localhost:8000`. + +### Calling a Remote A2A Agent from Another Agent + +Once deployed, any A2A-compatible agent or client can call your agent: + +```python +import httpx + +# Discover the remote agent +card = httpx.get("https://your-agent.ondigitalocean.app/.well-known/agent-card.json").json() +rpc_url = card["url"] + +# Send a message +response = httpx.post(rpc_url, json={ + "jsonrpc": "2.0", "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Hello from another agent!"}], + "message_id": "msg-1", + "kind": "message", + } + }, +}) +task = response.json()["result"] + +# Poll until done +result = httpx.post(rpc_url, json={ + "jsonrpc": "2.0", "id": "2", + "method": "tasks/get", + "params": {"id": task["id"]}, +}).json()["result"] +``` + +See `examples/a2a/client.py` for a complete async client with discovery, send, poll, and cancel. + +### Supported Operations + +- **`message/send`**: Send a message to the agent, creates or continues a task +- **`tasks/get`**: Poll task state and retrieve results (supports `historyLength`) +- **`tasks/cancel`**: Best-effort task cancellation (idempotent) +- **Agent Discovery**: `GET /.well-known/agent-card.json` for capabilities and transport URL + +Text-only input/output (`text/plain`) in the current release. Streaming, push notifications, and authenticated extended cards are explicitly disabled via AgentCard capability flags. + ## Support - **Templates/Examples**: [https://github.com/digitalocean/gradient-adk-templates](https://github.com/digitalocean/gradient-adk-templates) diff --git a/examples/a2a/client.py b/examples/a2a/client.py new file mode 100644 index 0000000..ce10a1e --- /dev/null +++ b/examples/a2a/client.py @@ -0,0 +1,131 @@ +"""Sample A2A client demonstrating discovery, send, polling, and cancel operations. + +Usage: + pip install gradient-adk[a2a] httpx + python examples/a2a/client.py +""" + +import asyncio +import httpx + + +async def discover_agent(base_url: str) -> dict: + """Discover agent capabilities via AgentCard.""" + async with httpx.AsyncClient() as client: + response = await client.get(f"{base_url}/.well-known/agent-card.json") + response.raise_for_status() + agent_card = response.json() + print(f"Discovered agent: {agent_card['name']}") + print(f"Transport URL: {agent_card['url']}") + return agent_card + + +async def send_message(rpc_url: str, message_text: str) -> dict: + """Send a message to the agent.""" + async with httpx.AsyncClient() as client: + response = await client.post( + rpc_url, + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": message_text}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + response.raise_for_status() + result = response.json() + if "error" in result: + raise Exception(f"Error: {result['error']}") + return result["result"] + + +async def get_task(rpc_url: str, task_id: str) -> dict: + """Poll task status.""" + async with httpx.AsyncClient() as client: + response = await client.post( + rpc_url, + json={ + "jsonrpc": "2.0", + "id": "2", + "method": "tasks/get", + "params": {"id": task_id}, + }, + ) + response.raise_for_status() + result = response.json() + if "error" in result: + raise Exception(f"Error: {result['error']}") + return result["result"] + + +async def cancel_task(rpc_url: str, task_id: str) -> dict: + """Cancel a task.""" + async with httpx.AsyncClient() as client: + response = await client.post( + rpc_url, + json={ + "jsonrpc": "2.0", + "id": "3", + "method": "tasks/cancel", + "params": {"id": task_id}, + }, + ) + response.raise_for_status() + result = response.json() + if "error" in result: + raise Exception(f"Error: {result['error']}") + return result["result"] + + +async def main(): + """Demonstrate A2A client operations.""" + base_url = "http://localhost:8000" + rpc_url = f"{base_url}/" + + print("=== A2A Client Demo ===\n") + + # 1. Discover agent + print("1. Discovering agent...") + agent_card = await discover_agent(base_url) + print() + + # 2. Send message + print("2. Sending message...") + send_result = await send_message(rpc_url, "Hello, A2A!") + task_id = send_result["id"] + print(f" Task ID: {task_id}") + print(f" Status: {send_result['status']['state']}") + print() + + # 3. Poll task status + print("3. Polling task status...") + task = await get_task(rpc_url, task_id) + print(f" Task ID: {task['id']}") + print(f" Status: {task['status']['state']}") + if task.get("response"): + parts = task["response"].get("parts", []) + if parts: + print(f" Response: {parts[0].get('text', 'N/A')}") + print() + + # 4. Cancel task (example) + print("4. Canceling task (example)...") + send_result2 = await send_message(rpc_url, "This will be canceled") + task_id2 = send_result2["id"] + cancel_result = await cancel_task(rpc_url, task_id2) + print(f" Task ID: {cancel_result['id']}") + print(f" Status: {cancel_result['status']['state']}") + print() + + print("=== Demo Complete ===") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/a2a/server.py b/examples/a2a/server.py new file mode 100644 index 0000000..8571502 --- /dev/null +++ b/examples/a2a/server.py @@ -0,0 +1,70 @@ +"""Example A2A server agent using Gradient ADK @entrypoint decorator. + +This example demonstrates how to create an A2A-compatible agent that can be +accessed via the A2A protocol using the Gradient ADK. + +To run this agent: +1. Ensure you have gradient-adk and a2a-sdk installed: + pip install gradient-adk[a2a] +2. Set the base URL for discovery (optional, defaults to localhost for dev): + export A2A_BASE_URL=https://your-app.ondigitalocean.app +3. Run: gradient agent run +4. The agent will be available with both Gradient and A2A protocols +""" + +from gradient_adk import entrypoint + + +@entrypoint +async def echo_agent(data: dict, context) -> dict: + """A simple echo agent that repeats the user's input. + + This agent: + - Receives text input via A2A protocol + - Echoes it back with a prefix + - Works with both Gradient /run endpoint and A2A protocol + + Args: + data: Dictionary containing the user's input. + For A2A protocol, this will contain {"prompt": "user text"} + context: Request context (not used in this simple example) + + Returns: + Dictionary with the agent's response. + For A2A protocol, the "output" key will be extracted as the response. + """ + user_input = data.get("prompt", "") + + if not user_input: + return {"output": "No input provided. Please send a message."} + + response = f"Echo: {user_input}" + + return {"output": response} + + +@entrypoint +async def greeting_agent(data: dict, context) -> dict: + """A greeting agent that responds with a personalized message. + + This is an alternative example showing how to create a more + sophisticated agent that still works with A2A protocol. + + Args: + data: Dictionary containing the user's input + context: Request context + + Returns: + Dictionary with the agent's response + """ + user_input = data.get("prompt", "").strip() + + if not user_input: + return {"output": "Hello! What's your name?"} + + if user_input.lower().startswith("hello"): + return {"output": f"Hello! Nice to meet you. You said: {user_input}"} + elif user_input.lower().startswith("hi"): + return {"output": f"Hi there! You said: {user_input}"} + else: + return {"output": f"Greetings! You said: {user_input}"} diff --git a/gradient_adk/a2a/__init__.py b/gradient_adk/a2a/__init__.py new file mode 100644 index 0000000..ec9ace9 --- /dev/null +++ b/gradient_adk/a2a/__init__.py @@ -0,0 +1,11 @@ +"""A2A protocol v0.3.0 integration for Gradient agents. + +Public API: + create_a2a_server: Create an A2A-enabled FastAPI server for a Gradient agent +""" + +from gradient_adk.a2a.infrastructure.server import create_a2a_server + +__all__ = [ + "create_a2a_server", +] diff --git a/gradient_adk/a2a/adapters/__init__.py b/gradient_adk/a2a/adapters/__init__.py new file mode 100644 index 0000000..781d456 --- /dev/null +++ b/gradient_adk/a2a/adapters/__init__.py @@ -0,0 +1 @@ +"""Adapters layer - integrates with external systems (SDK, Gradient).""" diff --git a/gradient_adk/a2a/adapters/primary/__init__.py b/gradient_adk/a2a/adapters/primary/__init__.py new file mode 100644 index 0000000..d313d4c --- /dev/null +++ b/gradient_adk/a2a/adapters/primary/__init__.py @@ -0,0 +1 @@ +"""Primary adapters - implement external interfaces (A2A SDK).""" diff --git a/gradient_adk/a2a/adapters/primary/a2a_executor.py b/gradient_adk/a2a/adapters/primary/a2a_executor.py new file mode 100644 index 0000000..0c1509e --- /dev/null +++ b/gradient_adk/a2a/adapters/primary/a2a_executor.py @@ -0,0 +1,122 @@ +"""Primary adapter - implements A2A SDK interface.""" + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue + +from gradient_adk.a2a.domain.models import GradientOutput, TransformationError +from gradient_adk.a2a.domain.transformation import MessageTransformer, OutputExtractor +from gradient_adk.a2a.domain.ports import AgentPort +from gradient_adk.logging import get_logger + +logger = get_logger(__name__) + + +class A2AExecutorAdapter(AgentExecutor): + """ + Implements SDK's AgentExecutor interface. + + Orchestrates transformation and execution. + Note: Validation happens in ValidatingRequestHandler BEFORE tasks are created. + """ + + def __init__( + self, + transformer: MessageTransformer, + output_extractor: OutputExtractor, + agent_client: AgentPort, + event_publisher_factory, + ): + """All dependencies injected via constructor.""" + self.transformer = transformer + self.output_extractor = output_extractor + self.agent_client = agent_client + self.event_publisher_factory = event_publisher_factory + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """ + Execute agent for A2A request. + + Flow: + 1. Extract user input (already validated by ValidatingRequestHandler) + 2. Transform to Gradient format + 3. Execute agent + 4. Extract output + 5. Publish result + + Note: No validation here - ValidatingRequestHandler validates BEFORE + creating tasks, so we can assume input is valid. + + Exception Handling: + Catches all exceptions from agent execution (agent functions can raise + any exception type). Logs full exception with traceback for debugging + and publishes failure event before re-raising to maintain exception + propagation. + """ + event_publisher = self.event_publisher_factory( + event_queue, + context.context_id, + context.task_id + ) + + try: + user_text = context.get_user_input() + + gradient_input = self.transformer.to_gradient_input_from_text( + user_text, + session_id=context.context_id, + ) + + agent_result = await self.agent_client.execute(gradient_input) + + text_result = self.output_extractor.extract_text(agent_result) + if text_result.is_err(): + await event_publisher.publish_failed( + context.task_id, + text_result.error, + ) + return + + output = GradientOutput(text=text_result.value) + response_message = self.transformer.from_gradient_output(output) + + await event_publisher.publish_completed( + context.task_id, + response_message, + ) + + except Exception as e: + logger.error( + "Agent execution failed", + task_id=context.task_id, + context_id=context.context_id, + exception_type=type(e).__name__, + exc_info=True, + ) + + exception_type = type(e).__name__ + error_message = f"Agent execution failed: {exception_type}: {str(e)}" + await event_publisher.publish_failed( + context.task_id, + TransformationError( + message=error_message, + code="AGENT_EXECUTION_ERROR" + ), + ) + raise + + async def cancel( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Cancel task execution (best-effort).""" + event_publisher = self.event_publisher_factory( + event_queue, + context.context_id, + context.task_id + ) + await event_publisher.publish_canceled(context.task_id) diff --git a/gradient_adk/a2a/adapters/secondary/__init__.py b/gradient_adk/a2a/adapters/secondary/__init__.py new file mode 100644 index 0000000..bdc4e36 --- /dev/null +++ b/gradient_adk/a2a/adapters/secondary/__init__.py @@ -0,0 +1 @@ +"""Secondary adapters - implement domain ports for external integrations.""" diff --git a/gradient_adk/a2a/adapters/secondary/agent_client.py b/gradient_adk/a2a/adapters/secondary/agent_client.py new file mode 100644 index 0000000..28d1f65 --- /dev/null +++ b/gradient_adk/a2a/adapters/secondary/agent_client.py @@ -0,0 +1,47 @@ +"""Secondary adapter - calls Gradient agent.""" + +import inspect +from typing import Any + +from gradient_adk.decorator import RequestContext as GradientRequestContext + +from gradient_adk.a2a.domain.models import GradientInput +from gradient_adk.a2a.domain.ports import AgentPort + + +class GradientAgentClient: + """Executes Gradient agents. Implements AgentPort. + + Only supports non-streaming agent functions. Async generators + (streaming agents) are rejected at construction time. + """ + + def __init__(self, agent_func: Any, input_key: str = "prompt"): + """ + Args: + agent_func: The @entrypoint decorated function (non-streaming) + input_key: Key to use when passing prompt to agent + """ + if inspect.isasyncgenfunction(agent_func): + raise TypeError( + f"Streaming agent functions are not supported in A2A mode. " + f"Agent '{getattr(agent_func, '__name__', agent_func)}' is an async generator." + ) + self.agent_func = agent_func + self.input_key = input_key + + async def execute(self, agent_input: GradientInput) -> Any: + """Execute Gradient agent with given input.""" + gradient_data = {self.input_key: agent_input.prompt} + + gradient_context = GradientRequestContext( + session_id=agent_input.session_id, + headers={}, + ) + + if inspect.iscoroutinefunction(self.agent_func): + result = await self.agent_func(gradient_data, gradient_context) + else: + result = self.agent_func(gradient_data, gradient_context) + + return result diff --git a/gradient_adk/a2a/adapters/secondary/event_publisher.py b/gradient_adk/a2a/adapters/secondary/event_publisher.py new file mode 100644 index 0000000..35c74e7 --- /dev/null +++ b/gradient_adk/a2a/adapters/secondary/event_publisher.py @@ -0,0 +1,63 @@ +"""Secondary adapter - publishes events to A2A SDK.""" + +from a2a.server.tasks import TaskUpdater +from a2a.server.events import EventQueue +from a2a.types import TaskState +from a2a.utils import new_agent_text_message + +from gradient_adk.a2a.domain.models import DomainMessage, ValidationError, TransformationError +from gradient_adk.a2a.domain.ports import EventPort + + +class A2AEventPublisher: + """Publishes events to SDK EventQueue. Implements EventPort.""" + + def __init__(self, event_queue: EventQueue, context_id: str, task_id: str): + """ + Args: + event_queue: SDK event queue + context_id: Context identifier + task_id: Task identifier + """ + self.task_updater = TaskUpdater(event_queue, task_id, context_id) + self.context_id = context_id + self.task_id = task_id + + async def publish_completed( + self, + task_id: str, + message: DomainMessage, + ) -> None: + """Publish task completion.""" + sdk_message = new_agent_text_message( + text=message.text, + context_id=self.context_id, + task_id=task_id, + ) + + await self.task_updater.update_status( + TaskState.completed, + message=sdk_message, + ) + + async def publish_failed( + self, + task_id: str, + error: ValidationError | TransformationError, + ) -> None: + """Publish task failure.""" + error_text = f"Agent error: {error.message}" + sdk_message = new_agent_text_message( + text=error_text, + context_id=self.context_id, + task_id=task_id, + ) + + await self.task_updater.update_status( + TaskState.failed, + message=sdk_message, + ) + + async def publish_canceled(self, task_id: str) -> None: + """Publish task cancellation.""" + await self.task_updater.update_status(TaskState.canceled) diff --git a/gradient_adk/a2a/domain/__init__.py b/gradient_adk/a2a/domain/__init__.py new file mode 100644 index 0000000..ea437af --- /dev/null +++ b/gradient_adk/a2a/domain/__init__.py @@ -0,0 +1,25 @@ +"""Domain layer - pure business logic with no I/O dependencies.""" + +from gradient_adk.a2a.domain.models import ( + MessageRole, + DomainMessage, + GradientInput, + GradientOutput, + ValidationError, + TransformationError, + Ok, + Err, + Result, +) + +__all__ = [ + "MessageRole", + "DomainMessage", + "GradientInput", + "GradientOutput", + "ValidationError", + "TransformationError", + "Ok", + "Err", + "Result", +] diff --git a/gradient_adk/a2a/domain/constants.py b/gradient_adk/a2a/domain/constants.py new file mode 100644 index 0000000..4a47988 --- /dev/null +++ b/gradient_adk/a2a/domain/constants.py @@ -0,0 +1,8 @@ +"""Domain constants for A2A service configuration.""" + +# A2A protocol version +A2A_PROTOCOL_VERSION = "0.3.0" + +__all__ = [ + "A2A_PROTOCOL_VERSION", +] diff --git a/gradient_adk/a2a/domain/models.py b/gradient_adk/a2a/domain/models.py new file mode 100644 index 0000000..c3bd7fe --- /dev/null +++ b/gradient_adk/a2a/domain/models.py @@ -0,0 +1,85 @@ +"""Domain models - immutable data structures.""" + +from dataclasses import dataclass +from typing import Generic, TypeVar +from enum import Enum + + +class MessageRole(str, Enum): + """Message roles in conversation.""" + USER = "user" + AGENT = "agent" + + +@dataclass(frozen=True) +class DomainMessage: + """Pure domain representation of an A2A message.""" + text: str + role: MessageRole + has_file_parts: bool = False + has_data_parts: bool = False + + @property + def is_text_only(self) -> bool: + """Check if message contains only text.""" + return not (self.has_file_parts or self.has_data_parts) + + +@dataclass(frozen=True) +class GradientInput: + """Input format expected by Gradient agents.""" + prompt: str + session_id: str | None = None + + +@dataclass(frozen=True) +class GradientOutput: + """Output format returned by Gradient agents.""" + text: str + metadata: dict | None = None + + +@dataclass(frozen=True) +class ValidationError: + """Domain validation error.""" + message: str + code: str = "VALIDATION_ERROR" + + +@dataclass(frozen=True) +class TransformationError: + """Domain transformation error.""" + message: str + code: str = "TRANSFORMATION_ERROR" + + +# Result type for functional error handling +T = TypeVar('T') +E = TypeVar('E') + + +@dataclass(frozen=True) +class Ok(Generic[T]): + """Success result.""" + value: T + + def is_ok(self) -> bool: + return True + + def is_err(self) -> bool: + return False + + +@dataclass(frozen=True) +class Err(Generic[E]): + """Error result.""" + error: E + + def is_ok(self) -> bool: + return False + + def is_err(self) -> bool: + return True + + +Result = Ok[T] | Err[E] diff --git a/gradient_adk/a2a/domain/ports.py b/gradient_adk/a2a/domain/ports.py new file mode 100644 index 0000000..d7ffebd --- /dev/null +++ b/gradient_adk/a2a/domain/ports.py @@ -0,0 +1,42 @@ +"""Domain ports - abstract interfaces.""" + +from typing import Protocol, Any +from gradient_adk.a2a.domain.models import GradientInput, DomainMessage, ValidationError, TransformationError + + +class AgentPort(Protocol): + """Port for executing Gradient agents.""" + + async def execute( + self, + agent_input: GradientInput + ) -> Any: + """ + Execute agent with given input. + Returns: Agent result (any type) + """ + ... + + +class EventPort(Protocol): + """Port for publishing events.""" + + async def publish_completed( + self, + task_id: str, + message: DomainMessage, + ) -> None: + """Publish task completion event.""" + ... + + async def publish_failed( + self, + task_id: str, + error: ValidationError | TransformationError, + ) -> None: + """Publish task failure event.""" + ... + + async def publish_canceled(self, task_id: str) -> None: + """Publish task cancellation.""" + ... diff --git a/gradient_adk/a2a/domain/transformation.py b/gradient_adk/a2a/domain/transformation.py new file mode 100644 index 0000000..24a09d4 --- /dev/null +++ b/gradient_adk/a2a/domain/transformation.py @@ -0,0 +1,92 @@ +"""Domain transformation logic.""" + +from typing import Any + +from gradient_adk.a2a.domain.models import ( + DomainMessage, + GradientInput, + GradientOutput, + TransformationError, + MessageRole, + Result, + Ok, + Err, +) + + +class MessageTransformer: + """Transforms messages between A2A and Gradient formats.""" + + def __init__(self, input_key: str = "prompt"): + """ + Args: + input_key: Key to use in Gradient input dict + """ + self.input_key = input_key + + def to_gradient_input( + self, + message: DomainMessage, + session_id: str | None = None, + ) -> GradientInput: + """Transform domain message to Gradient input format.""" + return GradientInput( + prompt=message.text, + session_id=session_id, + ) + + def to_gradient_input_from_text( + self, + text: str, + session_id: str | None = None, + ) -> GradientInput: + """Create Gradient input from text string directly.""" + return GradientInput( + prompt=text, + session_id=session_id, + ) + + def from_gradient_output( + self, + output: GradientOutput + ) -> DomainMessage: + """Transform Gradient output to domain message.""" + return DomainMessage( + text=output.text, + role=MessageRole.AGENT, + has_file_parts=False, + has_data_parts=False, + ) + + +class OutputExtractor: + """Extracts output from various Gradient agent result formats.""" + + def __init__(self, output_keys: list[str] | None = None): + """ + Args: + output_keys: Keys to try in order + """ + self.output_keys = output_keys or ["output", "response", "result"] + + def extract_text(self, result: Any) -> Result[str, TransformationError]: + """ + Extract text from agent result. + Handles: string, dict, other types. + """ + if isinstance(result, str): + return Ok(result) + + if isinstance(result, dict): + for key in self.output_keys: + if key in result and result[key]: + return Ok(str(result[key])) + return Ok(str(result)) + + try: + return Ok(str(result)) + except Exception as e: + return Err(TransformationError( + message=f"Failed to convert result to text: {str(e)}", + code="TEXT_EXTRACTION_FAILED" + )) diff --git a/gradient_adk/a2a/domain/validation.py b/gradient_adk/a2a/domain/validation.py new file mode 100644 index 0000000..12ca8a1 --- /dev/null +++ b/gradient_adk/a2a/domain/validation.py @@ -0,0 +1,56 @@ +"""Domain validation logic.""" + +from gradient_adk.a2a.domain.models import DomainMessage, ValidationError, Result, Ok, Err + + +class MessageValidator: + """Validates domain messages according to business rules.""" + + def __init__(self): + """No dependencies - pure validation.""" + pass + + def validate_text_only( + self, + message: DomainMessage + ) -> Result[DomainMessage, ValidationError]: + """ + Validate that message contains only text parts. + MVP supports text/plain only. + """ + if not message.is_text_only: + return Err(ValidationError( + message="Only text/plain input is supported in MVP. " + "FilePart and DataPart are not supported.", + code="UNSUPPORTED_CONTENT_TYPE" + )) + + return Ok(message) + + def validate_non_empty( + self, + message: DomainMessage + ) -> Result[DomainMessage, ValidationError]: + """Validate that message text is not empty.""" + if not message.text.strip(): + return Err(ValidationError( + message="Message text cannot be empty", + code="EMPTY_MESSAGE" + )) + + return Ok(message) + + def validate_all( + self, + message: DomainMessage + ) -> Result[DomainMessage, ValidationError]: + """Run all validations. Returns first error or Ok.""" + result = self.validate_text_only(message) + if result.is_err(): + return result + + result = self.validate_non_empty(message) + if result.is_err(): + return result + + return Ok(message) diff --git a/gradient_adk/a2a/infrastructure/__init__.py b/gradient_adk/a2a/infrastructure/__init__.py new file mode 100644 index 0000000..5d23516 --- /dev/null +++ b/gradient_adk/a2a/infrastructure/__init__.py @@ -0,0 +1 @@ +"""Infrastructure layer - setup, configuration, and dependency wiring.""" diff --git a/gradient_adk/a2a/infrastructure/composition.py b/gradient_adk/a2a/infrastructure/composition.py new file mode 100644 index 0000000..1a74de8 --- /dev/null +++ b/gradient_adk/a2a/infrastructure/composition.py @@ -0,0 +1,82 @@ +"""Dependency injection - compose the application.""" + +from typing import Any + +from a2a.server.tasks import InMemoryTaskStore + +from gradient_adk.a2a.domain.validation import MessageValidator +from gradient_adk.a2a.domain.transformation import MessageTransformer, OutputExtractor +from gradient_adk.a2a.adapters.primary.a2a_executor import A2AExecutorAdapter +from gradient_adk.a2a.adapters.secondary.agent_client import GradientAgentClient +from gradient_adk.a2a.adapters.secondary.event_publisher import A2AEventPublisher +from gradient_adk.a2a.infrastructure.validating_handler import ValidatingRequestHandler + + +def compose_executor( + agent_func: Any, + input_key: str = "prompt", + output_keys: list[str] | None = None, +) -> A2AExecutorAdapter: + """ + Compose all dependencies and create executor. + + Args: + agent_func: Gradient agent function + input_key: Key for agent input (default: "prompt") + output_keys: Keys to try for output + + Returns: + Fully composed A2AExecutorAdapter + """ + transformer = MessageTransformer(input_key=input_key) + output_extractor = OutputExtractor(output_keys=output_keys) + + agent_client = GradientAgentClient(agent_func, input_key=input_key) + + def event_publisher_factory(event_queue, context_id, task_id): + return A2AEventPublisher(event_queue, context_id, task_id) + + executor = A2AExecutorAdapter( + transformer=transformer, + output_extractor=output_extractor, + agent_client=agent_client, + event_publisher_factory=event_publisher_factory, + ) + + return executor + + +def compose_request_handler( + agent_func: Any, + input_key: str = "prompt", + output_keys: list[str] | None = None, +) -> ValidatingRequestHandler: + """ + Compose request handler with validation and execution. + + This creates a ValidatingRequestHandler that: + 1. Validates messages using domain validator (before task creation) + 2. Creates tasks for valid messages + 3. Executes via A2AExecutorAdapter + + Args: + agent_func: Gradient agent function + input_key: Key for agent input (default: "prompt") + output_keys: Keys to try for output + + Returns: + Fully composed ValidatingRequestHandler + """ + executor = compose_executor( + agent_func, + input_key=input_key, + output_keys=output_keys, + ) + + validator = MessageValidator() + + return ValidatingRequestHandler( + agent_executor=executor, + task_store=InMemoryTaskStore(), + validator=validator, + ) diff --git a/gradient_adk/a2a/infrastructure/server.py b/gradient_adk/a2a/infrastructure/server.py new file mode 100644 index 0000000..a7653aa --- /dev/null +++ b/gradient_adk/a2a/infrastructure/server.py @@ -0,0 +1,101 @@ +"""FastAPI server setup using A2A SDK.""" + +import os +from typing import Any + +from fastapi import FastAPI + +from a2a.server.apps.jsonrpc import A2AFastAPIApplication +from a2a.types import AgentCard, AgentCapabilities + +from gradient_adk.a2a.infrastructure.composition import compose_request_handler +from gradient_adk.a2a.domain.constants import A2A_PROTOCOL_VERSION + + +def _validate_path(path: str, param_name: str) -> str: + """ + Validate URL path format. + + Args: + path: The path to validate + param_name: Name of the parameter (for error messages) + + Returns: + The validated path + + Raises: + ValueError: If path format is invalid + """ + if not path.startswith("/"): + raise ValueError(f"{param_name} must start with '/'") + if "//" in path: + raise ValueError(f"{param_name} cannot contain double slashes") + return path + + +def create_a2a_server( + agent_func: Any, + rpc_url: str = "/", + agent_card_url: str = "/.well-known/agent-card.json", + base_url: str | None = None, + agent_name: str = "Gradient A2A Agent", + input_key: str = "prompt", + output_keys: list[str] | None = None, +) -> FastAPI: + """ + Create A2A server. + + Args: + agent_func: Gradient agent function + rpc_url: RPC endpoint URL path + agent_card_url: Agent card endpoint URL path + base_url: Base URL for agent (used in AgentCard for discovery). + Reads from A2A_BASE_URL env var if not provided. + Falls back to "http://localhost:8000" for local development. + agent_name: Agent name for AgentCard + input_key: Key for agent input + output_keys: Keys to try for output + + Returns: + FastAPI application + + Raises: + ValueError: If rpc_url or agent_card_url have invalid format + """ + if base_url is None: + base_url = os.getenv("A2A_BASE_URL", "http://localhost:8000") + + rpc_url = _validate_path(rpc_url, "rpc_url") + agent_card_url = _validate_path(agent_card_url, "agent_card_url") + + agent_card = AgentCard( + name=agent_name, + version=A2A_PROTOCOL_VERSION, + description="A2A-enabled Gradient agent", + url=base_url, + capabilities=AgentCapabilities( + streaming=False, + push_notifications=False, + state_transition_history=True, + ), + supports_authenticated_extended_card=False, + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + skills=[], + ) + + request_handler = compose_request_handler( + agent_func, + input_key=input_key, + output_keys=output_keys, + ) + + app_builder = A2AFastAPIApplication( + agent_card=agent_card, + http_handler=request_handler, + ) + + return app_builder.build( + rpc_url=rpc_url, + agent_card_url=agent_card_url + ) diff --git a/gradient_adk/a2a/infrastructure/validating_handler.py b/gradient_adk/a2a/infrastructure/validating_handler.py new file mode 100644 index 0000000..326a79a --- /dev/null +++ b/gradient_adk/a2a/infrastructure/validating_handler.py @@ -0,0 +1,111 @@ +"""Custom request handler with domain validation. + +This handler validates messages BEFORE creating tasks, ensuring that +validation errors are returned as JSON-RPC errors rather than creating +failed tasks. +""" + +from typing import Any + +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.context import ServerCallContext +from a2a.types import Message, Task, MessageSendParams, FilePart, DataPart, InvalidParamsError +from a2a.utils.errors import ServerError + +from gradient_adk.a2a.domain.validation import MessageValidator +from gradient_adk.a2a.domain.models import DomainMessage, MessageRole + + +class ValidatingRequestHandler(DefaultRequestHandler): + """ + Extends SDK's DefaultRequestHandler to validate messages before task creation. + + Architecture: + Request -> ValidatingHandler.validate() -> DefaultRequestHandler.on_message_send() + | | + Domain Validator Create Task -> Execute + + This ensures: + - Validation happens at the right time (before task creation) + - Invalid requests return JSON-RPC errors (not failed tasks) + - Domain validation rules are enforced consistently + - Executor can assume input is valid + """ + + def __init__(self, agent_executor: Any, task_store: Any, validator: MessageValidator): + """ + Initialize validating request handler. + + Args: + agent_executor: The agent executor to use + task_store: The task store for persistence + validator: Domain validator for business rules + """ + super().__init__(agent_executor, task_store) + self.validator = validator + + async def on_message_send( + self, + params: MessageSendParams, + context: ServerCallContext | None = None, + ) -> Message | Task: + """ + Validate message before delegating to parent handler. + + Args: + params: The message send parameters + context: The server call context + + Returns: + Task or Message response + + Raises: + ServerError: If validation fails (returned as JSON-RPC error) + + Error Code Mapping: + Domain error codes are included in the error message in format: + "[DOMAIN_CODE] Error message" + + Domain codes: + - UNSUPPORTED_CONTENT_TYPE: Message contains file/data parts (not text-only) + - EMPTY_MESSAGE: Message text is empty or whitespace-only + + JSON-RPC error code is always -32602 (Invalid params) for validation errors. + """ + domain_message = self._to_domain_message(params.message) + + validation_result = self.validator.validate_all(domain_message) + + if validation_result.is_err(): + error = validation_result.error + error_message = f"[{error.code}] {error.message}" + raise ServerError(InvalidParamsError(message=error_message)) + + return await super().on_message_send(params, context) + + def _to_domain_message(self, message: Message) -> DomainMessage: + """ + Convert SDK Message to DomainMessage for validation. + + Args: + message: SDK message + + Returns: + Domain message with extracted text and part flags + """ + text_parts = [ + part.root.text + for part in message.parts + if hasattr(part.root, 'text') + ] + text = " ".join(text_parts) if text_parts else "" + + has_file = any(isinstance(part.root, FilePart) for part in message.parts) + has_data = any(isinstance(part.root, DataPart) for part in message.parts) + + return DomainMessage( + text=text, + role=MessageRole.USER, + has_file_parts=has_file, + has_data_parts=has_data, + ) diff --git a/integration_tests/test_a2a/conftest.py b/integration_tests/test_a2a/conftest.py new file mode 100644 index 0000000..2731be7 --- /dev/null +++ b/integration_tests/test_a2a/conftest.py @@ -0,0 +1,32 @@ +import asyncio + +import pytest +from fastapi.testclient import TestClient + + +# Mock Gradient agent function +async def mock_gradient_agent(data: dict, context) -> str: + """Mock Gradient agent for testing.""" + return f"Echo: {data.get('prompt', '')}" + + +async def slow_gradient_agent(data: dict, context) -> str: + """Slow agent for cancel-race testing.""" + await asyncio.sleep(5) + return f"Echo: {data.get('prompt', '')}" + + +@pytest.fixture +def a2a_server(): + """Create A2A server for testing.""" + from gradient_adk.a2a import create_a2a_server + app = create_a2a_server(mock_gradient_agent) + return TestClient(app) + + +@pytest.fixture +def slow_a2a_server(): + """Create A2A server with slow agent for race-condition testing.""" + from gradient_adk.a2a import create_a2a_server + app = create_a2a_server(slow_gradient_agent) + return TestClient(app) diff --git a/integration_tests/test_a2a/test_a2a_integration.py b/integration_tests/test_a2a/test_a2a_integration.py new file mode 100644 index 0000000..483fa51 --- /dev/null +++ b/integration_tests/test_a2a/test_a2a_integration.py @@ -0,0 +1,410 @@ +import pytest + + +def test_message_send(a2a_server): + """Test message/send operation bridges to Gradient agent.""" + response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Hello"}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "result" in data + result = data["result"] + + assert isinstance(result, dict) + + if result.get("kind") == "task": + assert "id" in result + assert "status" in result + assert result["status"]["state"] == "completed" + + response_text = None + if result["status"].get("message") and "parts" in result["status"]["message"]: + response_text = result["status"]["message"]["parts"][0]["text"] + elif result.get("history"): + for msg in reversed(result["history"]): + if msg.get("role") == "agent" and "parts" in msg: + response_text = msg["parts"][0]["text"] + break + + assert response_text is not None, "Agent response not found in task" + assert "Echo: Hello" in response_text + elif "parts" in result: + assert "Echo: Hello" in result["parts"][0]["text"] + else: + pytest.fail(f"Unexpected result structure: {result}") + + +def test_tasks_get(a2a_server): + """Test tasks/get operation.""" + send_response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Test"}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + task_id = send_response.json()["result"]["id"] + + response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "2", + "method": "tasks/get", + "params": {"id": task_id}, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["result"]["id"] == task_id + + +def test_agent_card_discovery(a2a_server): + """Test agent card discovery endpoint.""" + response = a2a_server.get("/.well-known/agent-card.json") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Gradient A2A Agent" + assert "capabilities" in data + assert data["capabilities"]["streaming"] is False + assert data["capabilities"]["pushNotifications"] is False + assert data.get("supportsAuthenticatedExtendedCard") is False + + +def test_agent_card_reflects_configured_url(): + """Test that AgentCard URL reflects the base_url parameter.""" + from gradient_adk.a2a import create_a2a_server + from fastapi.testclient import TestClient + + async def mock_agent(data: dict, context) -> str: + return "test" + + app = create_a2a_server(mock_agent, base_url="https://my-agent.example.com") + client = TestClient(app) + + response = client.get("/.well-known/agent-card.json") + assert response.status_code == 200 + card = response.json() + assert card["url"] == "https://my-agent.example.com" + + app_default = create_a2a_server(mock_agent) + client_default = TestClient(app_default) + + response_default = client_default.get("/.well-known/agent-card.json") + assert response_default.status_code == 200 + card_default = response_default.json() + assert card_default["url"] == "http://localhost:8000" + + +def test_unsupported_streaming(a2a_server): + """Test that SDK automatically returns UnsupportedOperationError for streaming.""" + response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/stream", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Hello"}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32603 + assert "not supported" in data["error"]["message"].lower() + + +def test_unsupported_push_notifications(a2a_server): + """Test that SDK automatically returns UnsupportedOperationError for push notifications.""" + response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "tasks/pushNotificationConfig/set", + "params": { + "task_id": "task1", + "pushNotificationConfig": { + "url": "https://example.com", + "token": "token", + }, + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32603 + assert "not supported" in data["error"]["message"].lower() + + +def test_tasks_get_with_history_length(a2a_server): + """Test tasks/get with historyLength parameter (SDK handles automatically).""" + send_response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Test"}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + task_id = send_response.json()["result"]["id"] + + response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "2", + "method": "tasks/get", + "params": {"id": task_id, "historyLength": 1}, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["result"]["id"] == task_id + + +def test_tasks_cancel(a2a_server): + """Test tasks/cancel operation (SDK handles race conditions automatically).""" + send_response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Test"}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + task_id = send_response.json()["result"]["id"] + + response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "2", + "method": "tasks/cancel", + "params": {"id": task_id}, + }, + ) + assert response.status_code == 200 + data = response.json() + if "result" in data: + assert data["result"]["status"]["state"] == "canceled" + elif "error" in data: + assert data["error"]["code"] in [-32002, -32603] + assert "cancel" in data["error"]["message"].lower() or "terminal" in data["error"]["message"].lower() + else: + pytest.fail(f"Unexpected response: {data}") + + +def test_tasks_cancel_already_completed(a2a_server): + """Test that canceling an already-completed task is idempotent (req 3.4).""" + send_response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Done fast"}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + task_id = send_response.json()["result"]["id"] + assert send_response.json()["result"]["status"]["state"] == "completed" + + # Cancel a completed task — should not error + cancel_response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "2", + "method": "tasks/cancel", + "params": {"id": task_id}, + }, + ) + assert cancel_response.status_code == 200 + data = cancel_response.json() + # Either succeeds silently or returns a protocol error — both are valid + if "result" in data: + assert data["result"]["status"]["state"] in ("completed", "canceled") + elif "error" in data: + # Task already in terminal state — protocol-correct rejection + pass + + # Cancel again — idempotent + cancel_response2 = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "3", + "method": "tasks/cancel", + "params": {"id": task_id}, + }, + ) + assert cancel_response2.status_code == 200 + + +def test_tasks_cancel_race_with_slow_agent(slow_a2a_server): + """Test cancel vs completion race with a slow agent (req 7: explicit race test). + + Sends a message to a slow agent (5s sleep), then immediately cancels. + The outcome should be either 'canceled' or 'completed' — both are valid + under the A2A spec's best-effort cancellation semantics. + """ + def send_message(): + return slow_a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Slow task"}], + "message_id": "msg-race", + "kind": "message", + } + }, + }, + ) + + def cancel_task(task_id): + return slow_a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "2", + "method": "tasks/cancel", + "params": {"id": task_id}, + }, + ) + + # Send the message (may block or return immediately depending on SDK behavior) + send_response = send_message() + assert send_response.status_code == 200 + send_data = send_response.json() + assert "result" in send_data + task_id = send_data["result"]["id"] + + # Immediately attempt cancel + cancel_response = cancel_task(task_id) + assert cancel_response.status_code == 200 + cancel_data = cancel_response.json() + + # Verify task final state — either completed or canceled is acceptable + get_response = slow_a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "3", + "method": "tasks/get", + "params": {"id": task_id}, + }, + ) + assert get_response.status_code == 200 + task_state = get_response.json()["result"]["status"]["state"] + assert task_state in ("completed", "canceled", "failed"), ( + f"Task in unexpected state '{task_state}' after cancel race" + ) + + +def test_text_only_validation_rejects_file_part(a2a_server): + """Test that MVP rejects FilePart per spec requirement.""" + response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "file", "file": {"uri": "https://example.com/file.txt"}}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32602 + assert "text/plain" in data["error"]["message"].lower() + + +def test_text_only_validation_rejects_data_part(a2a_server): + """Test that MVP rejects DataPart per spec requirement.""" + response = a2a_server.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "data", "data": {"key": "value"}}], + "message_id": "msg-1", + "kind": "message", + } + }, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32602 + assert "text/plain" in data["error"]["message"].lower() diff --git a/pyproject.toml b/pyproject.toml index f4c82e7..b14c82f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,9 @@ dependencies = [ "gradient>=3.10.1", ] +[project.optional-dependencies] +a2a = ["a2a-sdk[http-server]>=0.3.0"] + [project.scripts] gradient = "gradient_adk.cli:run" diff --git a/tests/test_a2a/__init__.py b/tests/test_a2a/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_a2a/conftest.py b/tests/test_a2a/conftest.py new file mode 100644 index 0000000..5d2059b --- /dev/null +++ b/tests/test_a2a/conftest.py @@ -0,0 +1,32 @@ +"""Fixtures for A2A tests. + +Ensures httpx methods are not monkey-patched by the network interceptor, +which can interfere with FastAPI TestClient. +""" + +import httpx +import pytest + +# Save pristine httpx methods at import time (before any test can patch them) +_pristine_client_send = httpx.Client.send +_pristine_client_request = httpx.Client.request +_pristine_async_send = httpx.AsyncClient.send +_pristine_async_request = httpx.AsyncClient.request + + +@pytest.fixture(autouse=True) +def _restore_httpx(): + """Restore pristine httpx methods around each test. + + The network_interceptor module patches httpx globally for request capture. + This fixture ensures TestClient works correctly regardless of test ordering. + """ + httpx.Client.send = _pristine_client_send + httpx.Client.request = _pristine_client_request + httpx.AsyncClient.send = _pristine_async_send + httpx.AsyncClient.request = _pristine_async_request + yield + httpx.Client.send = _pristine_client_send + httpx.Client.request = _pristine_client_request + httpx.AsyncClient.send = _pristine_async_send + httpx.AsyncClient.request = _pristine_async_request diff --git a/tests/test_a2a/domain/__init__.py b/tests/test_a2a/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_a2a/domain/test_transformation.py b/tests/test_a2a/domain/test_transformation.py new file mode 100644 index 0000000..52ca474 --- /dev/null +++ b/tests/test_a2a/domain/test_transformation.py @@ -0,0 +1,147 @@ +"""Unit tests for domain transformation logic.""" + +import pytest +from gradient_adk.a2a.domain.models import ( + DomainMessage, + MessageRole, + GradientInput, + GradientOutput, +) +from gradient_adk.a2a.domain.transformation import MessageTransformer, OutputExtractor + + +class TestMessageTransformer: + """Tests for MessageTransformer.""" + + def setup_method(self): + """Set up test fixtures.""" + self.transformer = MessageTransformer(input_key="prompt") + + def test_to_gradient_input_creates_correct_format(self): + """Test transformation to Gradient input format.""" + message = DomainMessage( + text="What is 2+2?", + role=MessageRole.USER, + ) + + result = self.transformer.to_gradient_input(message, session_id="session-123") + + assert isinstance(result, GradientInput) + assert result.prompt == "What is 2+2?" + assert result.session_id == "session-123" + + def test_to_gradient_input_without_session_id(self): + """Test transformation without session ID.""" + message = DomainMessage( + text="Hello", + role=MessageRole.USER, + ) + + result = self.transformer.to_gradient_input(message) + + assert isinstance(result, GradientInput) + assert result.prompt == "Hello" + assert result.session_id is None + + def test_from_gradient_output_creates_agent_message(self): + """Test transformation from Gradient output.""" + output = GradientOutput(text="The answer is 4", metadata=None) + + result = self.transformer.from_gradient_output(output) + + assert isinstance(result, DomainMessage) + assert result.text == "The answer is 4" + assert result.role == MessageRole.AGENT + assert result.has_file_parts is False + assert result.has_data_parts is False + + +class TestOutputExtractor: + """Tests for OutputExtractor.""" + + def setup_method(self): + """Set up test fixtures.""" + self.extractor = OutputExtractor(output_keys=["output", "response", "result"]) + + def test_extract_text_from_string_result(self): + """Test extracting text from string result.""" + result = self.extractor.extract_text("Hello, world!") + + assert result.is_ok() + assert result.value == "Hello, world!" + + def test_extract_text_from_dict_with_output_key(self): + """Test extracting text from dict with 'output' key.""" + agent_result = {"output": "Response text"} + + result = self.extractor.extract_text(agent_result) + + assert result.is_ok() + assert result.value == "Response text" + + def test_extract_text_from_dict_with_response_key(self): + """Test extracting text from dict with 'response' key.""" + agent_result = {"response": "Response text"} + + result = self.extractor.extract_text(agent_result) + + assert result.is_ok() + assert result.value == "Response text" + + def test_extract_text_from_dict_with_result_key(self): + """Test extracting text from dict with 'result' key.""" + agent_result = {"result": "Response text"} + + result = self.extractor.extract_text(agent_result) + + assert result.is_ok() + assert result.value == "Response text" + + def test_extract_text_uses_key_priority(self): + """Test that extractor uses first matching key in priority order.""" + agent_result = { + "output": "First choice", + "response": "Second choice", + "result": "Third choice", + } + + result = self.extractor.extract_text(agent_result) + + assert result.is_ok() + assert result.value == "First choice" + + def test_extract_text_from_dict_without_known_keys(self): + """Test extracting from dict without known keys converts whole dict.""" + agent_result = {"foo": "bar", "baz": "qux"} + + result = self.extractor.extract_text(agent_result) + + assert result.is_ok() + assert "foo" in result.value + assert "bar" in result.value + + def test_extract_text_from_integer(self): + """Test extracting text from integer.""" + result = self.extractor.extract_text(42) + + assert result.is_ok() + assert result.value == "42" + + def test_extract_text_from_none(self): + """Test extracting text from None.""" + result = self.extractor.extract_text(None) + + assert result.is_ok() + assert result.value == "None" + + def test_extract_text_skips_empty_values(self): + """Test that extractor skips empty values in dict.""" + agent_result = { + "output": "", + "response": "Valid response", + } + + result = self.extractor.extract_text(agent_result) + + assert result.is_ok() + assert result.value == "Valid response" diff --git a/tests/test_a2a/domain/test_validation.py b/tests/test_a2a/domain/test_validation.py new file mode 100644 index 0000000..13289c9 --- /dev/null +++ b/tests/test_a2a/domain/test_validation.py @@ -0,0 +1,132 @@ +"""Unit tests for domain validation logic.""" + +import pytest +from gradient_adk.a2a.domain.models import DomainMessage, MessageRole +from gradient_adk.a2a.domain.validation import MessageValidator + + +class TestMessageValidator: + """Tests for MessageValidator.""" + + def setup_method(self): + """Set up test fixtures.""" + self.validator = MessageValidator() + + def test_validate_text_only_accepts_text_only_message(self): + """Test that validator accepts text-only messages.""" + message = DomainMessage( + text="Hello, world!", + role=MessageRole.USER, + has_file_parts=False, + has_data_parts=False, + ) + + result = self.validator.validate_text_only(message) + + assert result.is_ok() + assert result.value == message + + def test_validate_text_only_rejects_message_with_file_parts(self): + """Test that validator rejects messages with file parts.""" + message = DomainMessage( + text="Hello", + role=MessageRole.USER, + has_file_parts=True, + ) + + result = self.validator.validate_text_only(message) + + assert result.is_err() + assert result.error.code == "UNSUPPORTED_CONTENT_TYPE" + assert "FilePart" in result.error.message + + def test_validate_text_only_rejects_message_with_data_parts(self): + """Test that validator rejects messages with data parts.""" + message = DomainMessage( + text="Hello", + role=MessageRole.USER, + has_data_parts=True, + ) + + result = self.validator.validate_text_only(message) + + assert result.is_err() + assert result.error.code == "UNSUPPORTED_CONTENT_TYPE" + assert "DataPart" in result.error.message + + def test_validate_non_empty_accepts_non_empty_message(self): + """Test that validator accepts non-empty messages.""" + message = DomainMessage( + text="Hello, world!", + role=MessageRole.USER, + ) + + result = self.validator.validate_non_empty(message) + + assert result.is_ok() + assert result.value == message + + def test_validate_non_empty_rejects_empty_message(self): + """Test that validator rejects empty messages.""" + message = DomainMessage( + text="", + role=MessageRole.USER, + ) + + result = self.validator.validate_non_empty(message) + + assert result.is_err() + assert result.error.code == "EMPTY_MESSAGE" + + def test_validate_non_empty_rejects_whitespace_only_message(self): + """Test that validator rejects whitespace-only messages.""" + message = DomainMessage( + text=" \n\t ", + role=MessageRole.USER, + ) + + result = self.validator.validate_non_empty(message) + + assert result.is_err() + assert result.error.code == "EMPTY_MESSAGE" + + def test_validate_all_accepts_valid_message(self): + """Test that validate_all accepts valid messages.""" + message = DomainMessage( + text="Hello, world!", + role=MessageRole.USER, + has_file_parts=False, + has_data_parts=False, + ) + + result = self.validator.validate_all(message) + + assert result.is_ok() + assert result.value == message + + def test_validate_all_rejects_file_parts_first(self): + """Test that validate_all checks text-only first (before non-empty check).""" + message = DomainMessage( + text="", + role=MessageRole.USER, + has_file_parts=True, + ) + + result = self.validator.validate_all(message) + + assert result.is_err() + assert result.error.code == "UNSUPPORTED_CONTENT_TYPE" + + def test_validate_all_rejects_empty_text_after_text_only(self): + """Test that validate_all checks non-empty after text-only validation.""" + message = DomainMessage( + text="", + role=MessageRole.USER, + has_file_parts=False, + has_data_parts=False, + ) + + result = self.validator.validate_all(message) + + assert result.is_err() + assert result.error.code == "EMPTY_MESSAGE" diff --git a/tests/test_a2a/test_full_flow.py b/tests/test_a2a/test_full_flow.py new file mode 100644 index 0000000..a3cfa5f --- /dev/null +++ b/tests/test_a2a/test_full_flow.py @@ -0,0 +1,231 @@ +"""End-to-end tests for A2A integration.""" + +import pytest +import httpx +from fastapi.testclient import TestClient + +from gradient_adk.a2a.infrastructure.server import create_a2a_server + + +# Test agent function (not a pytest test, just a helper) +def echo_agent(data, context): + """Simple test agent that echoes the prompt.""" + prompt = data.get("prompt", "") + return {"output": f"Echo: {prompt}"} + + +@pytest.fixture +def app(): + """Create test FastAPI app.""" + return create_a2a_server( + agent_func=echo_agent, + base_url="http://test", + agent_name="Test A2A Agent", + ) + + +@pytest.fixture +def client(app): + """Create test client.""" + return TestClient(app) + + +class TestAgentDiscovery: + """Test agent discovery endpoint.""" + + def test_agent_card_endpoint_exists(self, client): + """Test that agent card endpoint is accessible.""" + response = client.get("/.well-known/agent-card.json") + + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + + def test_agent_card_has_correct_structure(self, client): + """Test that agent card has correct structure.""" + response = client.get("/.well-known/agent-card.json") + card = response.json() + + assert "name" in card + assert "version" in card + assert "capabilities" in card + assert "defaultInputModes" in card + assert "defaultOutputModes" in card + + def test_agent_card_capabilities(self, client): + """Test that agent card capabilities are correct for MVP.""" + response = client.get("/.well-known/agent-card.json") + card = response.json() + + assert card["capabilities"]["streaming"] is False + assert card["capabilities"]["pushNotifications"] is False + assert card["capabilities"]["stateTransitionHistory"] is True + + def test_agent_card_input_output_modes(self, client): + """Test that input/output modes are correct.""" + response = client.get("/.well-known/agent-card.json") + card = response.json() + + assert "text/plain" in card["defaultInputModes"] + assert "text/plain" in card["defaultOutputModes"] + + +class TestMessageSend: + """Test message/send operation.""" + + def test_send_valid_text_message(self, client): + """Test sending a valid text message.""" + response = client.post("/", json={ + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "message": { + "messageId": "msg-001", + "role": "user", + "parts": [{"type": "text", "text": "Hello, agent!"}] + } + }, + "id": 1 + }) + + assert response.status_code == 200 + result = response.json() + + assert result["jsonrpc"] == "2.0" + assert result["id"] == 1 + assert "result" in result + + result_data = result["result"] + assert "role" in result_data or ("kind" in result_data and result_data["kind"] == "task") + + def test_send_message_with_empty_text(self, client): + """Test that empty messages are rejected.""" + response = client.post("/", json={ + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "message": { + "messageId": "msg-002", + "role": "user", + "parts": [{"type": "text", "text": ""}] + } + }, + "id": 2 + }) + + assert response.status_code == 200 + result = response.json() + + assert result["jsonrpc"] == "2.0" + assert result["id"] == 2 + + def test_invalid_json_rpc_request(self, client): + """Test that invalid JSON-RPC requests are rejected.""" + response = client.post("/", json={ + "invalid": "request" + }) + + assert response.status_code == 200 + result = response.json() + + assert "error" in result + assert result["error"]["code"] == -32600 + + +class TestTaskOperations: + """Test task operations.""" + + def test_get_nonexistent_task(self, client): + """Test getting a non-existent task.""" + response = client.post("/", json={ + "jsonrpc": "2.0", + "method": "tasks/get", + "params": { + "taskId": "nonexistent-task-id" + }, + "id": 3 + }) + + assert response.status_code == 200 + result = response.json() + + assert "error" in result or ("result" in result and "error" in str(result["result"])) + + def test_cancel_nonexistent_task(self, client): + """Test canceling a non-existent task.""" + response = client.post("/", json={ + "jsonrpc": "2.0", + "method": "tasks/cancel", + "params": { + "taskId": "nonexistent-task-id" + }, + "id": 4 + }) + + assert response.status_code == 200 + result = response.json() + + assert result["jsonrpc"] == "2.0" + + +class TestUnsupportedOperations: + """Test unsupported operations.""" + + def test_streaming_is_unsupported(self, client): + """Test that message/stream returns unsupported operation.""" + response = client.post("/", json={ + "jsonrpc": "2.0", + "method": "message/stream", + "params": { + "message": { + "messageId": "msg-stream-001", + "role": "user", + "parts": [{"type": "text", "text": "Hello"}] + } + }, + "id": 5 + }) + + assert response.status_code == 200 + result = response.json() + + assert "error" in result + assert result["error"]["code"] in [-32004, -32601, -32602, -32603] + + def test_unknown_method(self, client): + """Test that unknown methods return method not found.""" + response = client.post("/", json={ + "jsonrpc": "2.0", + "method": "unknown/method", + "params": {}, + "id": 6 + }) + + assert response.status_code == 200 + result = response.json() + + assert "error" in result + assert result["error"]["code"] == -32601 + + +class TestAsyncFlow: + """Test async message flow.""" + + def test_async_message_send_with_test_client(self, client): + """Test async message send using TestClient (simulates async).""" + response = client.post("/", json={ + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "message": { + "messageId": "msg-async-001", + "role": "user", + "parts": [{"type": "text", "text": "Async hello!"}] + } + }, + "id": 7 + }) + + assert response.status_code == 200 + result = response.json() + assert result["jsonrpc"] == "2.0" + assert "result" in result or "error" in result