From 7db2e90d8de2a56aa30976a78710a16b2452b611 Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Tue, 30 Sep 2025 12:33:51 +0200 Subject: [PATCH 1/6] A2A implementation with AgentExecutor pattern - Maps contextID to conversations ID so that the agent has the needed content - Make use of TaskState completed, failed, working and input_required - Add model card configuration option through yaml file - Uses artifacts updates for the streaming and the final chunk --- lightspeed-stack.yaml | 1 + pyproject.toml | 4 + src/app/endpoints/a2a.py | 806 ++++++++++++++++++ src/app/routers.py | 5 + src/models/config.py | 32 + tests/unit/app/test_routers.py | 3 + .../models/config/test_dump_configuration.py | 2 + 7 files changed, 853 insertions(+) create mode 100644 src/app/endpoints/a2a.py diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml index 95ebb7315..ba29f85fa 100644 --- a/lightspeed-stack.yaml +++ b/lightspeed-stack.yaml @@ -2,6 +2,7 @@ name: Lightspeed Core Service (LCS) service: host: 0.0.0.0 port: 8080 + base_url: http://localhost:8080 auth_enabled: false workers: 1 color_log: true diff --git a/pyproject.toml b/pyproject.toml index c30872515..39a94240b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,8 @@ dependencies = [ # Used by JWK token auth handler "aiohttp>=3.12.14", "authlib>=1.6.0", + # Used for A2A protocol support + "a2a-sdk", # OpenAPI exporter "email-validator>=2.2.0", "openai>=1.99.9", @@ -53,6 +55,8 @@ dependencies = [ "psycopg2-binary>=2.9.10", "litellm>=1.75.5.post1", "urllib3==2.6.1", + # Used for agent card configuration + "PyYAML>=6.0.0", ] diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py new file mode 100644 index 000000000..a4df750d0 --- /dev/null +++ b/src/app/endpoints/a2a.py @@ -0,0 +1,806 @@ +"""Handler for A2A (Agent-to-Agent) protocol endpoints.""" + +import asyncio +import json +import logging +import uuid +from datetime import datetime +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Request +from starlette.responses import Response, StreamingResponse + +from a2a.types import ( + AgentCard, + AgentSkill, + AgentProvider, + AgentCapabilities, + Part, + Task, + TaskState, + TextPart, +) +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks.task_updater import TaskUpdater +from a2a.server.apps import A2AStarletteApplication +from a2a.utils import new_agent_text_message, new_task + +from authentication.interface import AuthTuple +from authentication import get_auth_dependency +from authorization.middleware import authorize +from configuration import configuration +from models.config import Action +from models.requests import QueryRequest +from app.endpoints.query import ( + select_model_and_provider_id, + evaluate_model_hints, +) +from app.endpoints.streaming_query import retrieve_response +from client import AsyncLlamaStackClientHolder +from utils.mcp_headers import mcp_headers_dependency +from version import __version__ + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["a2a"]) + +auth_dependency = get_auth_dependency() + + +# ----------------------------- +# Persistent State (multi-turn) +# ----------------------------- +# Keep a single TaskStore instance so tasks persist across requests and +# previous messages remain connected to the current request. +_TASK_STORE = InMemoryTaskStore() + +# Map A2A contextId -> Llama Stack conversationId to preserve history across turns +_CONTEXT_TO_CONVERSATION: dict[str, str] = {} + + +# ----------------------------- +# Agent Executor Implementation +# ----------------------------- +class LightspeedAgentExecutor(AgentExecutor): + """ + Lightspeed Agent Executor for OpenShift Assisted Chat Installer. + + This executor implements the A2A AgentExecutor interface and handles + routing queries to the appropriate LLM backend. + """ + + def __init__( + self, auth_token: str, mcp_headers: dict[str, dict[str, str]] | None = None + ): + """ + Initialize the Lightspeed agent executor. + + Args: + auth_token: Authentication token for the request + mcp_headers: MCP headers for context propagation + """ + self.auth_token = auth_token + self.mcp_headers = mcp_headers or {} + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """ + Execute the agent with the given context and send results to the event queue. + + Args: + context: The request context containing user input and metadata + event_queue: Queue for sending response events + """ + # Get or create task + task = await self._prepare_task(context, event_queue) + + # Process the task with streaming + await self._process_task_streaming( + context, event_queue, task.context_id, task.id + ) + + async def _prepare_task( + self, context: RequestContext, event_queue: EventQueue + ) -> Task: + """ + Get existing task or create a new one. + + Args: + context: The request context + event_queue: Queue for sending events + + Returns: + Task object + """ + task = context.current_task + if not task: + if not context.message: + raise ValueError("No message provided in context") + task = new_task(context.message) + await event_queue.enqueue_event(task) + return task + + async def _process_task_streaming( # pylint: disable=too-many-locals,too-many-branches,too-many-statements + self, + context: RequestContext, + event_queue: EventQueue, + context_id: str, + task_id: str, + ) -> None: + """ + Process the task with streaming updates. + + Args: + context: The request context + event_queue: Queue for sending events + context_id: Context ID for the task + task_id: Task ID + """ + task_updater = TaskUpdater(event_queue, task_id, context_id) + + try: + # Extract user input using SDK utility + user_input = context.get_user_input() + if not user_input: + await task_updater.update_status( + TaskState.input_required, + message=new_agent_text_message( + "I didn't receive any input. " + "How can I help you with OpenShift installation?", + context_id=context_id, + task_id=task_id, + ), + final=True, + ) + return + + preview = user_input[:200] + ("..." if len(user_input) > 200 else "") + logger.info("Processing A2A request: %s", preview) + + # Extract routing metadata from context + metadata = context.message.metadata if context.message else {} + model = metadata.get("model") if metadata else None + provider = metadata.get("provider") if metadata else None + + # Resolve conversation_id from A2A contextId to preserve multi-turn history + a2a_context_id = context_id + conversation_id_hint = _CONTEXT_TO_CONVERSATION.get(a2a_context_id) + logger.info( + "A2A contextId %s maps to conversation_id %s", + a2a_context_id, + conversation_id_hint, + ) + + # Build internal query request with conversation_id for history + query_request = QueryRequest( + query=user_input, + conversation_id=conversation_id_hint, + model=model, + provider=provider, + ) + + # Get LLM client and select model + client = AsyncLlamaStackClientHolder().get_client() + llama_stack_model_id, _model_id, _provider_id = ( + select_model_and_provider_id( + await client.models.list(), + *evaluate_model_hints( + user_conversation=None, query_request=query_request + ), + ) + ) + + # Stream response from LLM with status updates + stream, conversation_id = await retrieve_response( + client, + llama_stack_model_id, + query_request, + self.auth_token, + mcp_headers=self.mcp_headers, + ) + + # Persist conversationId for next turn in same A2A context + if conversation_id: + _CONTEXT_TO_CONVERSATION[a2a_context_id] = conversation_id + logger.info( + "Persisted conversation_id %s for A2A contextId %s", + conversation_id, + a2a_context_id, + ) + + # Stream incremental updates: emit working status with text deltas. + # Terminal conditions: + # - turn_awaiting_input -> TaskState.input_required with accumulated text + # - turn_complete -> TaskState.completed (final), leverage contextId for follow-ups + final_event_sent = False + accumulated_text_chunks: list[str] = [] + streamed_any_delta = False + + artifact_id = str(uuid.uuid4()) + async for chunk in stream: + # Extract text from chunk - llama-stack structure + if hasattr(chunk, "event") and chunk.event is not None: + payload = chunk.event.payload + event_type = payload.event_type + + # Handle turn_awaiting_input - request more input with accumulated text + if event_type == "turn_awaiting_input": + logger.debug("Turn awaiting input") + try: + final_text = ( + "" + if streamed_any_delta + else "".join(accumulated_text_chunks) + ) + await task_updater.update_status( + TaskState.input_required, + message=new_agent_text_message( + final_text, + context_id=context_id, + task_id=task_id, + ), + final=True, + ) + final_event_sent = True + logger.info("Input required for task %s", task_id) + except Exception: # pylint: disable=broad-except + logger.debug( + "Error sending input_required status", exc_info=True + ) + # End the stream for this turn after requesting input + break + + # Handle turn_complete - complete the task for this turn + elif event_type == "turn_complete": + logger.debug("Turn complete event") + try: + final_text = ( + "" + if streamed_any_delta + else "".join(accumulated_text_chunks) + ) + # await task_updater.update_status( + # TaskState.completed, + # message=new_agent_text_message( + # final_text, + # context_id=context_id, + # task_id=task_id, + # ), + # final=True, + # ) + task_metadata = { + "conversation_id": str(conversation_id), + "message_id": str(chunk.event.payload.turn.turn_id), + "sources": None + } + + await task_updater.add_artifact( + parts=[Part(root=TextPart(text=final_text))], + artifact_id=artifact_id, + metadata=task_metadata, + append=streamed_any_delta, + last_chunk=True + ) + await task_updater.complete() + final_event_sent = True + except Exception: # pylint: disable=broad-except + logger.debug( + "Error sending completed on turn_complete", + exc_info=True, + ) + logger.info("Turn completed for task %s", task_id) + # End the stream for this turn + break + + # Handle streaming inference tokens + elif event_type == "step_progress": + if hasattr(payload, "delta") and payload.delta.type == "text": + delta_text = payload.delta.text + if delta_text: + accumulated_text_chunks.append(delta_text) + logger.debug("Step progress, delta test: %s", delta_text) + # await task_updater.update_status( + # TaskState.working, + # message=new_agent_text_message( + # delta_text, + # context_id=context_id, + # task_id=task_id, + # ), + # ) + await task_updater.add_artifact( + parts=[Part(root=TextPart(text=delta_text))], + artifact_id=artifact_id, + metadata=None, + append=streamed_any_delta, + ) + streamed_any_delta = True + + # Ensure exactly one terminal status per turn + if not final_event_sent: + try: + final_text = ( + "" if streamed_any_delta else "".join(accumulated_text_chunks) + ) + # await task_updater.update_status( + # TaskState.completed, + # message=new_agent_text_message( + # final_text, + # context_id=context_id, + # task_id=task_id, + # ), + # final=True, + # ) + await task_updater.add_artifact( + parts=[Part(root=TextPart(text=final_text))], + artifact_id=artifact_id, + metadata=None, + append=streamed_any_delta, + last_chunk=True + ) + await task_updater.complete() + except Exception: # pylint: disable=broad-except + logger.debug( + "Error sending fallback completed status", exc_info=True + ) + + except Exception as exc: # pylint: disable=broad-except + logger.error("Error executing agent: %s", str(exc), exc_info=True) + await task_updater.update_status( + TaskState.failed, + message=new_agent_text_message( + f"Sorry, I encountered an error: {str(exc)}", + context_id=context_id, + task_id=task_id, + ), + final=True, + ) + + async def cancel( + self, + context: RequestContext, # pylint: disable=unused-argument + event_queue: EventQueue, # pylint: disable=unused-argument + ) -> None: + """ + Handle task cancellation. + + Args: + context: The request context + event_queue: Queue for sending cancellation events + + Raises: + NotImplementedError: Task cancellation is not currently supported + """ + logger.info("Cancellation requested but not currently supported") + raise NotImplementedError("Task cancellation not currently supported") + + +# ----------------------------- +# Agent Card Configuration +# ----------------------------- +def get_lightspeed_agent_card() -> AgentCard: + """ + Generate the A2A Agent Card for Lightspeed. + + If agent_card_path is configured, loads the agent card from the YAML file. + Otherwise, uses default hardcoded values. + + Returns: + AgentCard: The agent card describing Lightspeed's capabilities. + """ + # Get base URL from configuration or construct it + service_config = configuration.service_configuration + base_url = service_config.base_url if service_config.base_url is not None else "http://localhost:8080" + + # Check if agent card is configured via file + if ( + configuration.customization is not None + and configuration.customization.agent_card_config is not None + ): + config = configuration.customization.agent_card_config + + # Parse skills from config + skills = [ + AgentSkill( + id=skill.get("id"), + name=skill.get("name"), + description=skill.get("description"), + tags=skill.get("tags", []), + input_modes=skill.get("inputModes", []), + output_modes=skill.get("outputModes", []), + examples=skill.get("examples", []), + ) + for skill in config.get("skills", []) + ] + + # Parse provider from config + provider_config = config.get("provider", {}) + provider = AgentProvider( + organization=provider_config.get("organization", ""), + url=provider_config.get("url", ""), + ) + + # Parse capabilities from config + capabilities_config = config.get("capabilities", {}) + capabilities = AgentCapabilities( + streaming=capabilities_config.get("streaming", True), + push_notifications=capabilities_config.get("pushNotifications", False), + state_transition_history=capabilities_config.get( + "stateTransitionHistory", False + ), + ) + + return AgentCard( + name=config.get("name", "Lightspeed AI Assistant"), + description=config.get("description", ""), + version=__version__, + url=f"{base_url}/a2a", + documentation_url=f"{base_url}/docs", + provider=provider, + skills=skills, + default_input_modes=config.get("defaultInputModes", ["text/plain"]), + default_output_modes=config.get("defaultOutputModes", ["text/plain"]), + capabilities=capabilities, + protocol_version="0.2.1", + security=config.get("security", [{"bearer": []}]), + security_schemes=config.get("security_schemes", {}), + ) + + # Fallback to default hardcoded agent card + logger.info("Using default hardcoded agent card (no agent_card_path configured)") + + # Define Lightspeed's skills for OpenShift cluster installation + skills = [ + AgentSkill( + id="cluster_installation_guidance", + name="Cluster Installation Guidance", + description=( + "Provide guidance and assistance for OpenShift cluster " + "installation using assisted-installer" + ), + tags=["openshift", "installation", "assisted-installer"], + input_modes=["text/plain", "application/json"], + output_modes=["text/plain", "application/json"], + examples=[ + "How do I install OpenShift using assisted-installer?", + "What are the prerequisites for OpenShift installation?", + ], + ), + AgentSkill( + id="cluster_configuration_validation", + name="Cluster Configuration Validation", + description=( + "Validate and provide recommendations for OpenShift " + "cluster configuration parameters" + ), + tags=["openshift", "configuration", "validation"], + input_modes=["application/json", "text/plain"], + output_modes=["application/json", "text/plain"], + examples=[ + "Validate my cluster configuration", + "Check if my OpenShift setup meets requirements", + ], + ), + AgentSkill( + id="installation_troubleshooting", + name="Installation Troubleshooting", + description=( + "Help troubleshoot OpenShift cluster installation issues " + "and provide solutions" + ), + tags=["openshift", "troubleshooting", "support"], + input_modes=["text/plain", "application/json"], + output_modes=["text/plain", "application/json"], + examples=[ + "My cluster installation is failing", + "How do I fix installation errors?", + ], + ), + AgentSkill( + id="cluster_requirements_analysis", + name="Cluster Requirements Analysis", + description=( + "Analyze infrastructure requirements for " + "OpenShift cluster deployment" + ), + tags=["openshift", "requirements", "planning"], + input_modes=["application/json", "text/plain"], + output_modes=["application/json", "text/plain"], + examples=[ + "What hardware do I need for OpenShift?", + "Analyze requirements for a 5-node cluster", + ], + ), + ] + + # Provider information + provider = AgentProvider(organization="Red Hat", url="https://redhat.com") + + # Agent capabilities + capabilities = AgentCapabilities( + streaming=True, push_notifications=False, state_transition_history=False + ) + + return AgentCard( + name="OpenShift Assisted Installer AI Assistant", + description=( + "AI-powered assistant specialized in OpenShift cluster " + "installation, configuration, and troubleshooting using " + "assisted-installer backend" + ), + version=__version__, + url=f"{base_url}/a2a", + documentation_url=f"{base_url}/docs", + provider=provider, + skills=skills, + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + capabilities=capabilities, + protocol_version="0.2.1", + security=[{"bearer": []}], + security_schemes={}, + ) + + +# ----------------------------- +# FastAPI Endpoints +# ----------------------------- +@router.get("/.well-known/agent.json", response_model=AgentCard) +@router.get("/.well-known/agent-card.json", response_model=AgentCard) +async def get_agent_card( # pylint: disable=unused-argument + auth: Annotated[AuthTuple, Depends(auth_dependency)], +) -> AgentCard: + """ + Serve the A2A Agent Card at the well-known location. + + This endpoint provides the agent card that describes Lightspeed's + capabilities according to the A2A protocol specification. + + Returns: + AgentCard: The agent card describing this agent's capabilities. + """ + try: + logger.info("Serving A2A Agent Card") + agent_card = get_lightspeed_agent_card() + logger.info("Agent Card URL: %s", agent_card.url) + logger.info( + "Agent Card capabilities: streaming=%s", agent_card.capabilities.streaming + ) + return agent_card + except Exception as exc: + logger.error("Error serving A2A Agent Card: %s", str(exc)) + raise + + +def _create_a2a_app(auth_token: str, mcp_headers: dict[str, dict[str, str]]) -> Any: + """ + Create an A2A Starlette application instance with auth context. + + Args: + auth_token: Authentication token for the request + mcp_headers: MCP headers for context propagation + + Returns: + A2A Starlette ASGI application + """ + agent_executor = LightspeedAgentExecutor( + auth_token=auth_token, mcp_headers=mcp_headers + ) + + request_handler = DefaultRequestHandler( + agent_executor=agent_executor, + task_store=_TASK_STORE, + ) + + a2a_app = A2AStarletteApplication( + agent_card=get_lightspeed_agent_card(), + http_handler=request_handler, + ) + + return a2a_app.build() + + +@router.api_route("/a2a", methods=["GET", "POST"], response_model=None) +@authorize(Action.A2A_JSONRPC) +async def handle_a2a_jsonrpc( # pylint: disable=too-many-locals,too-many-statements + request: Request, + auth: Annotated[AuthTuple, Depends(auth_dependency)], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> Response | StreamingResponse: + """ + Main A2A JSON-RPC endpoint following the A2A protocol specification. + + This endpoint uses the DefaultRequestHandler from the A2A SDK to handle + all JSON-RPC requests including message/send, message/stream, etc. + + The A2A SDK application is created per-request to include authentication + context while still leveraging FastAPI's authorization middleware. + + Automatically detects streaming requests (message/stream JSON-RPC method) + and returns a StreamingResponse to enable real-time chunk delivery. + + Args: + request: FastAPI request object + auth: Authentication tuple + mcp_headers: MCP headers for context propagation + + Returns: + JSON-RPC response or streaming response + """ + logger.debug("A2A endpoint called: %s %s", request.method, request.url.path) + + # Extract auth token from AuthTuple + # AuthTuple format: (user_id, username, roles, token, ...) + try: + auth_token = auth[3] if len(auth) > 3 else "" + except (IndexError, TypeError): + logger.warning("Failed to extract auth token from auth tuple") + auth_token = "" + + # Create A2A app with auth context + a2a_app = _create_a2a_app(auth_token, mcp_headers) + + # Detect if this is a streaming request by checking the JSON-RPC method + is_streaming_request = False + body = b"" + try: + # Read and parse the request body to check the method + body = await request.body() + logger.debug("A2A request body size: %d bytes", len(body)) + if body: + try: + rpc_request = json.loads(body) + # Check if the method is message/stream + method = rpc_request.get("method", "") + is_streaming_request = method == "message/stream" + logger.info( + "A2A request method: %s, streaming: %s", + method, + is_streaming_request, + ) + except (json.JSONDecodeError, AttributeError) as e: + logger.warning( + "Could not parse A2A request body for method detection: %s", str(e) + ) + except Exception as e: # pylint: disable=broad-except + logger.error("Error detecting streaming request: %s", str(e)) + + # Setup scope for A2A app + scope = request.scope.copy() + scope["path"] = "/" # A2A app expects root path + + # We need to re-provide the body since we already read it + body_sent = False + + async def receive(): + nonlocal body_sent + if not body_sent: + body_sent = True + return {"type": "http.request", "body": body, "more_body": False} + + # After sending body once, delegate to original receive + # This prevents infinite loops - the original receive() will block/disconnect properly + return await request.receive() + + if is_streaming_request: + # Streaming mode: Forward chunks to client as they arrive + logger.info("Handling A2A streaming request") + + # Create queue for passing chunks from ASGI app to response generator + chunk_queue: asyncio.Queue = asyncio.Queue() + + async def streaming_send(message: dict[str, Any]) -> None: + """Send callback that queues chunks for streaming.""" + if message["type"] == "http.response.body": + body_chunk = message.get("body", b"") + if body_chunk: + await chunk_queue.put(body_chunk) + # Signal end of stream if no more body + if not message.get("more_body", False): + logger.debug("Streaming: End of stream signaled") + await chunk_queue.put(None) + + # Run the A2A app in a background task + async def run_a2a_app() -> None: + """Run A2A app and handle any errors.""" + try: + logger.debug("Streaming: Starting A2A app execution") + await a2a_app(scope, receive, streaming_send) + logger.debug("Streaming: A2A app execution completed") + except Exception as exc: # pylint: disable=broad-except + logger.error( + "Error in A2A app during streaming: %s", str(exc), exc_info=True + ) + await chunk_queue.put(None) # Signal end even on error + + # Start the A2A app task + app_task = asyncio.create_task(run_a2a_app()) + + async def response_generator() -> Any: + """Generator that yields chunks from the queue.""" + chunk_count = 0 + try: + while True: + # Get chunk from queue with timeout to prevent hanging + try: + chunk = await asyncio.wait_for(chunk_queue.get(), timeout=300.0) + except asyncio.TimeoutError: + logger.error("Timeout waiting for chunk from A2A app") + break + + if chunk is None: + # End of stream + logger.debug( + "Streaming: Stream ended after %d chunks", chunk_count + ) + break + chunk_count += 1 + logger.debug("Chunk sent to A2A client: %s", str(chunk)) + yield chunk + finally: + # Ensure the app task is cleaned up + if not app_task.done(): + app_task.cancel() + try: + await app_task + except asyncio.CancelledError: + pass + + # Return streaming response immediately + # The status code and headers will be determined by the first chunk + # We can't wait for the response to start because that would cause a deadlock: + # the ASGI app won't send data until the client starts consuming + logger.debug("Streaming: Returning StreamingResponse") + + # Return streaming response with SSE content type for A2A protocol + return StreamingResponse( + response_generator(), + media_type="text/event-stream", + ) + + # Non-streaming mode: Buffer entire response + logger.info("Handling A2A non-streaming request") + + response_started = False + response_body = [] + status_code = 200 + headers = [] + + async def buffering_send(message: dict[str, Any]) -> None: + nonlocal response_started, status_code, headers + if message["type"] == "http.response.start": + response_started = True + status_code = message["status"] + headers = message.get("headers", []) + elif message["type"] == "http.response.body": + response_body.append(message.get("body", b"")) + + await a2a_app(scope, receive, buffering_send) + + # Return the response from A2A app + return Response( + content=b"".join(response_body), + status_code=status_code, + headers=dict((k.decode(), v.decode()) for k, v in headers), + ) + + +@router.get("/a2a/health") +async def a2a_health_check() -> dict[str, str]: + """ + Health check endpoint for A2A service. + + Returns: + Dict with health status information. + """ + return { + "status": "healthy", + "service": "lightspeed-a2a", + "version": __version__, + "a2a_sdk_version": "0.2.1", + "timestamp": datetime.now().isoformat(), + } diff --git a/src/app/routers.py b/src/app/routers.py index ae9cf51ce..77ebde10a 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -20,6 +20,8 @@ tools, # V2 endpoints for Response API support query_v2, + # A2A (Agent-to-Agent) protocol support + a2a, ) @@ -53,3 +55,6 @@ def include_routers(app: FastAPI) -> None: app.include_router(health.router) app.include_router(authorized.router) app.include_router(metrics.router) + + # A2A (Agent-to-Agent) protocol endpoints + app.include_router(a2a.router) diff --git a/src/models/config.py b/src/models/config.py index 8958f18f1..32b357752 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -7,6 +7,7 @@ from enum import Enum from functools import cached_property import re +import yaml import jsonpath_ng from jsonpath_ng.exceptions import JSONPathError @@ -362,6 +363,12 @@ class ServiceConfiguration(ConfigurationBase): description="Service port", ) + base_url: Optional[str] = Field( + None, + title="Base URL", + description="Externally reachable base URL for the service; needed for A2A support.", + ) + auth_enabled: bool = Field( False, title="Authentication enabled", @@ -805,6 +812,12 @@ class Action(str, Enum): # Allow overriding model/provider via request MODEL_OVERRIDE = "model_override" + # A2A (Agent-to-Agent) protocol actions + A2A_AGENT_CARD = "a2a_agent_card" + A2A_TASK_EXECUTION = "a2a_task_execution" + A2A_MESSAGE = "a2a_message" + A2A_JSONRPC = "a2a_jsonrpc" + class AccessRule(ConfigurationBase): """Rule defining what actions a role can perform.""" @@ -1079,6 +1092,8 @@ class Customization(ConfigurationBase): disable_query_system_prompt: bool = False system_prompt_path: Optional[FilePath] = None system_prompt: Optional[str] = None + agent_card_path: Optional[FilePath] = None + agent_card_config: Optional[dict[str, Any]] = None custom_profile: Optional[CustomProfile] = Field(default=None, init=False) @model_validator(mode="after") @@ -1101,6 +1116,23 @@ def check_customization_model(self) -> Self: self.system_prompt = checks.get_attribute_from_file( dict(self), "system_prompt_path" ) + + # Load agent card configuration from YAML file + if self.agent_card_path is not None: + checks.file_check(self.agent_card_path, "agent card") + + try: + with open(self.agent_card_path, "r", encoding="utf-8") as f: + self.agent_card_config = yaml.safe_load(f) + except yaml.YAMLError as e: + raise ValueError( + f"Invalid YAML in agent card file '{self.agent_card_path}': {e}" + ) from e + except OSError as e: + raise ValueError( + f"Unable to read agent card file '{self.agent_card_path}': {e}" + ) from e + return self diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index 1245a07ba..396f1e48c 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -23,6 +23,7 @@ authorized, metrics, tools, + a2a, ) # noqa:E402 @@ -84,6 +85,7 @@ def test_include_routers() -> None: assert conversations_v2.router in app.get_routers() assert conversations_v3.router in app.get_routers() assert metrics.router in app.get_routers() + assert a2a.router in app.get_routers() def test_check_prefixes() -> None: @@ -112,3 +114,4 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(conversations_v2.router) == "/v2" assert app.get_router_prefix(conversations_v3.router) == "/v1" assert app.get_router_prefix(metrics.router) == "" + assert app.get_router_prefix(a2a.router) == "" diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 38177a8a7..e6ce9bced 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -102,6 +102,7 @@ def test_dump_configuration(tmp_path: Path) -> None: "service": { "host": "localhost", "port": 8080, + "base_url": None, "auth_enabled": False, "workers": 1, "color_log": True, @@ -402,6 +403,7 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None: "service": { "host": "localhost", "port": 8080, + "base_url": None, "auth_enabled": False, "workers": 1, "color_log": True, From ceda9f1cd34082f4cb3a56fff1d60a7b299b321f Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Wed, 26 Nov 2025 19:09:11 +0100 Subject: [PATCH 2/6] Add documentation about A2A addition --- docs/a2a_protocol.md | 614 ++++++++++++++++++++++++++++++++++++ src/app/endpoints/README.md | 8 + 2 files changed, 622 insertions(+) create mode 100644 docs/a2a_protocol.md diff --git a/docs/a2a_protocol.md b/docs/a2a_protocol.md new file mode 100644 index 000000000..8f1ba7a34 --- /dev/null +++ b/docs/a2a_protocol.md @@ -0,0 +1,614 @@ +# A2A (Agent-to-Agent) Protocol Integration + +This document describes the A2A (Agent-to-Agent) protocol implementation in Lightspeed Core Stack, which enables standardized communication between AI agents. + +## Overview + +The A2A protocol is an open standard for agent-to-agent communication that allows different AI agents to discover, communicate, and collaborate with each other. Lightspeed Core Stack implements the A2A protocol to expose its AI capabilities to other agents and systems. + +### Key Concepts + +- **Agent Card**: A JSON document that describes an agent's capabilities, skills, and how to interact with it +- **Task**: A unit of work that an agent can execute, with states like `submitted`, `working`, `completed`, `failed`, `input_required` +- **Message**: Communication between agents containing text or other content parts +- **Artifact**: Output produced by an agent during task execution + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ A2A Client │ +│ (A2A Inspector, Other Agents) │ +└─────────────────────────┬───────────────────────────────────────┘ + │ JSON-RPC over HTTP + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ FastAPI Application │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ A2A Endpoints │ │ +│ │ /.well-known/agent.json - Agent Card Discovery │ │ +│ │ /a2a - JSON-RPC Handler │ │ +│ │ /a2a/health - Health Check │ │ +│ └──────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ LightspeedAgentExecutor │ │ +│ │ - Handles task execution │ │ +│ │ - Converts Llama Stack events to A2A events │ │ +│ │ - Manages multi-turn conversations │ │ +│ └──────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ Llama Stack Client │ │ +│ │ - Agent API (streaming turns) │ │ +│ │ - Tools, Shields, RAG integration │ │ +│ └──────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Endpoints + +### Agent Card Discovery + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/.well-known/agent.json` | GET | Returns the agent card (standard A2A discovery path) | +| `/.well-known/agent-card.json` | GET | Returns the agent card (alternate path) | + +### A2A JSON-RPC + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/a2a` | POST | Main JSON-RPC endpoint for A2A protocol | +| `/a2a` | GET | Agent card retrieval via GET | +| `/a2a/health` | GET | Health check endpoint | + +### Responses API Variant (Optional) + +If you want to use the Responses API backend instead of the Agent API: + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/responses/.well-known/agent.json` | GET | Agent card for Responses API backend | +| `/responses/.well-known/agent-card.json` | GET | Agent card (alternate path) | +| `/responses/a2a` | POST | JSON-RPC endpoint using Responses API | +| `/responses/a2a/health` | GET | Health check endpoint | + +## Configuration + +### Agent Card Configuration + +The agent card is configured via the `customization.agent_card_config` section in your configuration file: + +```yaml +customization: + agent_card_config: + name: "My AI Assistant" + description: "An AI assistant for helping with various tasks" + provider: + organization: "My Organization" + url: "https://myorg.example.com" + skills: + - id: "general-qa" + name: "General Q&A" + description: "Answer general questions about various topics" + tags: ["qa", "general"] + inputModes: ["text/plain"] + outputModes: ["text/plain"] + examples: + - "What is the capital of France?" + - "Explain how photosynthesis works" + - id: "code-assistance" + name: "Code Assistance" + description: "Help with coding questions and debugging" + tags: ["coding", "development"] + inputModes: ["text/plain"] + outputModes: ["text/plain"] + capabilities: + streaming: true + pushNotifications: false + stateTransitionHistory: false + defaultInputModes: ["text/plain"] + defaultOutputModes: ["text/plain"] + security: + - bearer: [] + security_schemes: + bearer: + type: http + scheme: bearer +``` + +### Service Base URL + +The agent card URL is constructed from the service configuration: + +```yaml +service: + base_url: "https://my-lightspeed-service.example.com" +``` + +If `base_url` is not set, it defaults to `http://localhost:8080`. Note that the actual port depends on your service configuration (e.g., `8090` if configured differently). + +### Authentication + +A2A endpoints require authentication. Configure authentication as described in [auth.md](auth.md): + +```yaml +authentication: + module: jwk # or k8s, noop + jwk_config: + url: "https://auth.example.com/.well-known/jwks.json" +``` + +### Authorization + +The A2A endpoint uses the `A2A_JSONRPC` action. Configure access rules: + +```yaml +authorization: + access_rules: + - role: "user" + actions: + - A2A_JSONRPC +``` + +## Agent Card Structure + +The agent card describes the agent's capabilities: + +```json +{ + "name": "Lightspeed AI Assistant", + "description": "AI assistant for OpenShift and Kubernetes", + "version": "1.0.0", + "url": "https://example.com/a2a", + "documentation_url": "https://example.com/docs", + "protocol_version": "0.2.1", + "provider": { + "organization": "Red Hat", + "url": "https://redhat.com" + }, + "skills": [ + { + "id": "openshift-qa", + "name": "OpenShift Q&A", + "description": "Answer questions about OpenShift", + "tags": ["openshift", "kubernetes"], + "input_modes": ["text/plain"], + "output_modes": ["text/plain"] + } + ], + "capabilities": { + "streaming": true, + "push_notifications": false, + "state_transition_history": false + }, + "default_input_modes": ["text/plain"], + "default_output_modes": ["text/plain"], + "security": [{"bearer": []}], + "security_schemes": { + "bearer": { + "type": "http", + "scheme": "bearer" + } + } +} +``` + +## How the Executor Works + +### LightspeedAgentExecutor + +The `LightspeedAgentExecutor` class implements the A2A `AgentExecutor` interface: + +1. **Receives A2A Request**: Extracts user input from the A2A message +2. **Creates Query Request**: Builds an internal `QueryRequest` with conversation context +3. **Calls Llama Stack**: Uses the Agent API to get streaming responses +4. **Converts Events**: Transforms Llama Stack streaming chunks to A2A events +5. **Manages State**: Tracks task state and publishes status updates + +### Event Flow + +``` +A2A Request + │ + ▼ +┌─────────────────────┐ +│ Extract User Input │ +└─────────────────────┘ + │ + ▼ +┌─────────────────────┐ +│ Create/Resume Task │──► TaskSubmittedEvent +└─────────────────────┘ + │ + ▼ +┌─────────────────────┐ +│ Call Llama Stack │──► TaskStatusUpdateEvent (working) +│ Agent API │ +└─────────────────────┘ + │ + ▼ +┌─────────────────────┐ +│ Stream Response │──► TaskStatusUpdateEvent (working, with deltas) +│ Chunks │──► TaskStatusUpdateEvent (tool calls) +└─────────────────────┘ + │ + ▼ +┌─────────────────────┐ +│ Turn Complete │──► TaskArtifactUpdateEvent (final content) +└─────────────────────┘ + │ + ▼ +┌─────────────────────┐ +│ Finalize Task │──► TaskStatusUpdateEvent (completed/failed) +└─────────────────────┘ +``` + +### Task States + +| State | Description | +|-------|-------------| +| `submitted` | Task has been received and queued | +| `working` | Task is being processed | +| `completed` | Task finished successfully | +| `failed` | Task failed with an error | +| `input_required` | Agent needs additional input from the user | +| `auth_required` | Authentication is required to continue | + +### Multi-Turn Conversations + +The A2A implementation supports multi-turn conversations: + +1. Each A2A `contextId` maps to a Llama Stack `conversation_id` +2. The mapping is stored in memory (`_CONTEXT_TO_CONVERSATION`) +3. Subsequent messages with the same `contextId` continue the conversation +4. Conversation history is preserved across turns + +## Testing with A2A Inspector + +[A2A Inspector](https://github.com/a2aproject/a2a-inspector) is a tool for inspecting, debugging, and validating A2A agents. + +### Prerequisites + +1. Start your Lightspeed service: + ```bash + uv run python -m runners.uvicorn + ``` + +2. Ensure the service is accessible (e.g., `http://localhost:8090`) + +### Installing A2A Inspector + +**Requirements:** Python 3.10+, uv, Node.js, and npm + +1. **Clone the repository**: + ```bash + git clone https://github.com/a2aproject/a2a-inspector.git + cd a2a-inspector + ``` + +2. **Install dependencies**: + ```bash + # Python dependencies + uv sync + + # Node.js dependencies + cd frontend + npm install + cd .. + ``` + +3. **Run the inspector**: + + **Option A - Local Development:** + ```bash + chmod +x scripts/run.sh # First time only + bash scripts/run.sh + ``` + Access at: `http://127.0.0.1:5001` + + **Option B - Docker:** + ```bash + docker build -t a2a-inspector . + docker run -d -p 8080:8080 a2a-inspector + ``` + Access at: `http://127.0.0.1:8080` + +### Using A2A Inspector + +1. **Connect to Agent**: + - Open the inspector UI in your browser + - Enter the agent card URL: `http://localhost:/.well-known/agent.json` (e.g., `http://localhost:8090/.well-known/agent.json`) + - If authentication is required, configure the bearer token + +2. **Discover Agent**: + - The inspector will fetch and display the agent card + - You'll see the agent's skills and capabilities + +3. **Send Messages**: + - Use the message input to send queries + - For streaming, select "Stream" mode + - Watch real-time status updates and responses + +### Example: Testing with curl + +> **Note:** The examples below use port `8090`. Adjust to match your configured service port. + +#### 1. Fetch Agent Card + +```bash +curl -H "Authorization: Bearer $TOKEN" \ + http://localhost:8090/.well-known/agent.json +``` + +#### 2. Send a Message (Non-Streaming) + +```bash +curl -X POST http://localhost:8090/a2a \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "messageId": "msg-001", + "role": "user", + "parts": [ + {"type": "text", "text": "What is Kubernetes?"} + ] + } + } + }' +``` + +#### 3. Stream a Message + +```bash +curl -X POST http://localhost:8090/a2a \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -H "Accept: text/event-stream" \ + -d '{ + "jsonrpc": "2.0", + "id": "1", + "method": "message/stream", + "params": { + "message": { + "messageId": "msg-001", + "role": "user", + "parts": [ + {"type": "text", "text": "Explain pods in Kubernetes"} + ] + } + } + }' +``` + +#### 4. Continue a Conversation + +Use the `contextId` from a previous response: + +```bash +curl -X POST http://localhost:8090/a2a \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": "2", + "method": "message/send", + "params": { + "message": { + "messageId": "msg-002", + "role": "user", + "parts": [ + {"type": "text", "text": "How do I create one?"} + ] + }, + "contextId": "previous-context-id-here" + } + }' +``` + +### Example: Python Client + +```python +import httpx +import json + +BASE_URL = "http://localhost:8090" +TOKEN = "your-bearer-token" + +headers = { + "Authorization": f"Bearer {TOKEN}", + "Content-Type": "application/json", +} + +# Fetch agent card +response = httpx.get( + f"{BASE_URL}/.well-known/agent.json", + headers=headers +) +agent_card = response.json() +print(f"Agent: {agent_card['name']}") + +# Send a message +payload = { + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "messageId": "msg-001", + "role": "user", + "parts": [{"type": "text", "text": "Hello, what can you do?"}] + } + } +} + +response = httpx.post( + f"{BASE_URL}/a2a", + headers=headers, + json=payload +) +result = response.json() +print(json.dumps(result, indent=2)) +``` + +### Example: Streaming with Python + +```python +import httpx +import json + +BASE_URL = "http://localhost:8090" +TOKEN = "your-bearer-token" + +headers = { + "Authorization": f"Bearer {TOKEN}", + "Content-Type": "application/json", + "Accept": "text/event-stream", +} + +payload = { + "jsonrpc": "2.0", + "id": "1", + "method": "message/stream", + "params": { + "message": { + "messageId": "msg-001", + "role": "user", + "parts": [{"type": "text", "text": "Explain Kubernetes architecture"}] + } + } +} + +with httpx.stream( + "POST", + f"{BASE_URL}/a2a", + headers=headers, + json=payload, + timeout=300.0 +) as response: + for line in response.iter_lines(): + if line.startswith("data:"): + data = json.loads(line[5:]) + result = data.get("result", {}) + event_kind = result.get("kind") + if event_kind == "status-update": + status = result.get("status", {}) + state = status.get("state") + message = status.get("message", {}) + text = "" + for part in message.get("parts", []): + if part.get("kind") == "text": + text += part.get("text", "") + if text: + print(text, end="", flush=True) + elif event_kind == "artifact-update": + artifact = result.get("artifact", {}) + for part in artifact.get("parts", []): + if part.get("kind") == "text": + print(part.get("text", "")) +``` + +## Status Update Handling + +### How Status Updates Work + +During task execution, the agent sends status updates via `TaskStatusUpdateEvent`: + +1. **Initial Status**: When a task starts, a `working` status is sent with metadata (model, conversation_id) + +2. **Text Deltas**: As the LLM generates text, each token/chunk is sent as a `working` status with the delta text in the message + +3. **Tool Calls**: When the agent calls tools (RAG, MCP servers), status updates indicate the tool being called + +4. **Final Status**: When complete, a `completed` or `failed` status is sent + +### TaskResultAggregator + +The `TaskResultAggregator` class tracks the overall task state: + +- Collects status updates during streaming +- Determines the final task state based on priority: + 1. `failed` (highest priority) + 2. `auth_required` + 3. `input_required` + 4. `working` (default during processing) +- Ensures intermediate updates show `working` state to prevent premature client termination + +### Example Status Update Flow + +Each SSE event is wrapped in a JSON-RPC response with `id`, `jsonrpc`, and `result` fields. The `result.kind` field indicates the event type: + +```json +// 1. Task submitted (kind: "task") +{"id":"1","jsonrpc":"2.0","result":{"contextId":"ctx-1","id":"task-1","kind":"task","status":{"state":"submitted"}}} + +// 2. Working with metadata (kind: "status-update") +{"id":"1","jsonrpc":"2.0","result":{"contextId":"ctx-1","kind":"status-update","metadata":{"model":"llama3.1"},"status":{"state":"working"},"taskId":"task-1"}} + +// 3. Tool call notification +{"id":"1","jsonrpc":"2.0","result":{"contextId":"ctx-1","kind":"status-update","status":{"message":{"kind":"message","messageId":"msg-1","parts":[{"kind":"text","text":"Calling tool: my_tool"}],"role":"agent"},"state":"working"},"taskId":"task-1"}} + +// 4. Text streaming (multiple events with text chunks) +{"id":"1","jsonrpc":"2.0","result":{"contextId":"ctx-1","kind":"status-update","status":{"message":{"kind":"message","messageId":"msg-2","parts":[{"kind":"text","text":"Hello"}],"role":"agent"},"state":"working"},"taskId":"task-1"}} + +{"id":"1","jsonrpc":"2.0","result":{"contextId":"ctx-1","kind":"status-update","status":{"message":{"kind":"message","messageId":"msg-3","parts":[{"kind":"text","text":" world!"}],"role":"agent"},"state":"working"},"taskId":"task-1"}} + +// 5. Final artifact (kind: "artifact-update", complete response) +{"id":"1","jsonrpc":"2.0","result":{"artifact":{"artifactId":"art-1","parts":[{"kind":"text","text":"Hello world!"}]},"contextId":"ctx-1","kind":"artifact-update","lastChunk":true,"taskId":"task-1"}} + +// 6. Completion (final: true) +{"id":"1","jsonrpc":"2.0","result":{"contextId":"ctx-1","final":true,"kind":"status-update","status":{"state":"completed"},"taskId":"task-1"}} +``` + +## Troubleshooting + +### Common Issues + +1. **Agent Card Not Found (404)** + - Ensure `agent_card_config` is configured in your YAML + - Check that the service is running and accessible + +2. **Authentication Failed (401)** + - Verify your bearer token is valid + - Check authentication configuration + +3. **Authorization Failed (403)** + - Ensure your role has `A2A_JSONRPC` action permission + - Check authorization rules in configuration + +4. **Connection Timeout** + - Streaming responses have a 300-second timeout + - Check network connectivity to Llama Stack + +5. **No Response from Agent** + - Verify Llama Stack is running and accessible + - Check logs for errors in the executor + +### Debug Logging + +Enable debug logging to see detailed A2A processing: + +```yaml +service: + color_log: true +``` + +Check logs for entries from `app.endpoints.handlers` logger. + +## Protocol Version + +This implementation supports A2A protocol version **0.2.1**. + +## References + +- [A2A Protocol Specification](https://github.com/google/A2A) +- [Llama Stack Documentation](https://llama-stack.readthedocs.io/) +- [FastAPI Documentation](https://fastapi.tiangolo.com/) diff --git a/src/app/endpoints/README.md b/src/app/endpoints/README.md index cbbad003f..e11ead10f 100644 --- a/src/app/endpoints/README.md +++ b/src/app/endpoints/README.md @@ -3,6 +3,10 @@ ## [__init__.py](__init__.py) Implementation of all endpoints. +## [a2a.py](a2a.py) +Handler for A2A (Agent-to-Agent) protocol endpoints using Agent API. +See [A2A Protocol Documentation](../../../docs/a2a_protocol.md) for details. + ## [authorized.py](authorized.py) Handler for REST API call to authorized endpoint. @@ -54,6 +58,10 @@ Handler for REST API call to provide answer to streaming query. ## [streaming_query_v2.py](streaming_query_v2.py) Streaming query handler using Responses API (v2). +## [responses_a2a.py](responses_a2a.py) +Handler for A2A (Agent-to-Agent) protocol endpoints using Responses API. +See [A2A Protocol Documentation](../../../docs/a2a_protocol.md) for details. + ## [tools.py](tools.py) Handler for REST API call to list available tools from MCP servers. From 69f67851343283cfcc8eff630609bf14d518938d Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Tue, 25 Nov 2025 22:39:04 +0100 Subject: [PATCH 3/6] Ensure A2A implementation uses responses API instead of deprecated Agent API This patch ensures that A2A wrapper only supports responses API, given Agent API is deprecated. In addition, it reorganized the A2A implementation to match the ADK implementation with regards to events updates. Now using StatusUpdates instead of ArtifactsUpdates for the intermediate streamming --- docs/a2a_protocol.md | 35 +- pyproject.toml | 2 +- src/app/endpoints/README.md | 7 +- src/app/endpoints/a2a.py | 790 +++++++++--------- src/app/routers.py | 2 +- tests/unit/app/endpoints/test_a2a.py | 659 +++++++++++++++ .../models/config/test_dump_configuration.py | 1 + uv.lock | 48 ++ 8 files changed, 1126 insertions(+), 418 deletions(-) create mode 100644 tests/unit/app/endpoints/test_a2a.py diff --git a/docs/a2a_protocol.md b/docs/a2a_protocol.md index 8f1ba7a34..c45689ff9 100644 --- a/docs/a2a_protocol.md +++ b/docs/a2a_protocol.md @@ -33,16 +33,16 @@ The A2A protocol is an open standard for agent-to-agent communication that allow │ │ │ │ ▼ │ │ ┌──────────────────────────────────────────────────────────┐ │ -│ │ LightspeedAgentExecutor │ │ +│ │ A2AAgentExecutor │ │ │ │ - Handles task execution │ │ -│ │ - Converts Llama Stack events to A2A events │ │ +│ │ - Converts Responses API events to A2A events │ │ │ │ - Manages multi-turn conversations │ │ │ └──────────────────────────────────────────────────────────┘ │ │ │ │ │ ▼ │ │ ┌──────────────────────────────────────────────────────────┐ │ │ │ Llama Stack Client │ │ -│ │ - Agent API (streaming turns) │ │ +│ │ - Responses API (streaming responses) │ │ │ │ - Tools, Shields, RAG integration │ │ │ └──────────────────────────────────────────────────────────┘ │ └─────────────────────────────────────────────────────────────────┘ @@ -65,17 +65,6 @@ The A2A protocol is an open standard for agent-to-agent communication that allow | `/a2a` | GET | Agent card retrieval via GET | | `/a2a/health` | GET | Health check endpoint | -### Responses API Variant (Optional) - -If you want to use the Responses API backend instead of the Agent API: - -| Endpoint | Method | Description | -|----------|--------|-------------| -| `/responses/.well-known/agent.json` | GET | Agent card for Responses API backend | -| `/responses/.well-known/agent-card.json` | GET | Agent card (alternate path) | -| `/responses/a2a` | POST | JSON-RPC endpoint using Responses API | -| `/responses/a2a/health` | GET | Health check endpoint | - ## Configuration ### Agent Card Configuration @@ -199,14 +188,14 @@ The agent card describes the agent's capabilities: ## How the Executor Works -### LightspeedAgentExecutor +### A2AAgentExecutor -The `LightspeedAgentExecutor` class implements the A2A `AgentExecutor` interface: +The `A2AAgentExecutor` class implements the A2A `AgentExecutor` interface: 1. **Receives A2A Request**: Extracts user input from the A2A message 2. **Creates Query Request**: Builds an internal `QueryRequest` with conversation context -3. **Calls Llama Stack**: Uses the Agent API to get streaming responses -4. **Converts Events**: Transforms Llama Stack streaming chunks to A2A events +3. **Calls Llama Stack**: Uses the Responses API to get streaming responses +4. **Converts Events**: Transforms Responses API streaming chunks to A2A events 5. **Manages State**: Tracks task state and publishes status updates ### Event Flow @@ -227,7 +216,7 @@ A2A Request ▼ ┌─────────────────────┐ │ Call Llama Stack │──► TaskStatusUpdateEvent (working) -│ Agent API │ +│ Responses API │ └─────────────────────┘ │ ▼ @@ -238,7 +227,7 @@ A2A Request │ ▼ ┌─────────────────────┐ -│ Turn Complete │──► TaskArtifactUpdateEvent (final content) +│ Response Complete │──► TaskArtifactUpdateEvent (final content) └─────────────────────┘ │ ▼ @@ -404,16 +393,18 @@ curl -X POST http://localhost:8090/a2a \ "params": { "message": { "messageId": "msg-002", + "contextId": "previous-context-id-here", "role": "user", "parts": [ {"type": "text", "text": "How do I create one?"} ] - }, - "contextId": "previous-context-id-here" + } } }' ``` +> **Important:** The `contextId` must be placed inside the `message` object, not at the `params` level. This is required by the A2A protocol specification for the server to correctly identify and continue the conversation. + ### Example: Python Client ```python diff --git a/pyproject.toml b/pyproject.toml index 39a94240b..e313c5db3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ "aiohttp>=3.12.14", "authlib>=1.6.0", # Used for A2A protocol support - "a2a-sdk", + "a2a-sdk>=0.3.4,<0.4.0", # OpenAPI exporter "email-validator>=2.2.0", "openai>=1.99.9", diff --git a/src/app/endpoints/README.md b/src/app/endpoints/README.md index e11ead10f..46e429b29 100644 --- a/src/app/endpoints/README.md +++ b/src/app/endpoints/README.md @@ -4,7 +4,7 @@ Implementation of all endpoints. ## [a2a.py](a2a.py) -Handler for A2A (Agent-to-Agent) protocol endpoints using Agent API. +Handler for A2A (Agent-to-Agent) protocol endpoints using Responses API. See [A2A Protocol Documentation](../../../docs/a2a_protocol.md) for details. ## [authorized.py](authorized.py) @@ -58,10 +58,5 @@ Handler for REST API call to provide answer to streaming query. ## [streaming_query_v2.py](streaming_query_v2.py) Streaming query handler using Responses API (v2). -## [responses_a2a.py](responses_a2a.py) -Handler for A2A (Agent-to-Agent) protocol endpoints using Responses API. -See [A2A Protocol Documentation](../../../docs/a2a_protocol.md) for details. - ## [tools.py](tools.py) Handler for REST API call to list available tools from MCP servers. - diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py index a4df750d0..65a4834ea 100644 --- a/src/app/endpoints/a2a.py +++ b/src/app/endpoints/a2a.py @@ -1,13 +1,16 @@ -"""Handler for A2A (Agent-to-Agent) protocol endpoints.""" +"""Handler for A2A (Agent-to-Agent) protocol endpoints using Responses API.""" import asyncio import json import logging import uuid -from datetime import datetime -from typing import Annotated, Any +from datetime import datetime, timezone +from typing import Annotated, Any, AsyncIterator, MutableMapping -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseObjectStream, +) from starlette.responses import Response, StreamingResponse from a2a.types import ( @@ -15,9 +18,13 @@ AgentSkill, AgentProvider, AgentCapabilities, + Artifact, + Message, Part, - Task, + TaskArtifactUpdateEvent, TaskState, + TaskStatus, + TaskStatusUpdateEvent, TextPart, ) from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -38,9 +45,10 @@ select_model_and_provider_id, evaluate_model_hints, ) -from app.endpoints.streaming_query import retrieve_response +from app.endpoints.streaming_query_v2 import retrieve_response from client import AsyncLlamaStackClientHolder from utils.mcp_headers import mcp_headers_dependency +from utils.responses import extract_text_from_response_output_item from version import __version__ logger = logging.getLogger("app.endpoints.handlers") @@ -60,313 +68,406 @@ _CONTEXT_TO_CONVERSATION: dict[str, str] = {} +def _convert_responses_content_to_a2a_parts(output: list[Any]) -> list[Part]: + """Convert Responses API output to A2A Parts. + + Args: + output: List of Responses API output items + + Returns: + List of A2A Part objects + """ + parts: list[Part] = [] + + for output_item in output: + text = extract_text_from_response_output_item(output_item) + if text: + parts.append(Part(root=TextPart(text=text))) + + return parts + + +class TaskResultAggregator: + """Aggregates the task status updates and provides the final task state.""" + + def __init__(self) -> None: + """Initialize the task result aggregator with default state.""" + self._task_state: TaskState = TaskState.working + self._task_status_message: Message | None = None + + def process_event( + self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Any + ) -> None: + """ + Process an event from the agent run and detect signals about the task status. + + Priority of task state (highest to lowest): + - failed + - auth_required + - input_required + - working + + Args: + event: The event to process + """ + if isinstance(event, TaskStatusUpdateEvent): + if event.status.state == TaskState.failed: + self._task_state = TaskState.failed + self._task_status_message = event.status.message + elif ( + event.status.state == TaskState.auth_required + and self._task_state != TaskState.failed + ): + self._task_state = TaskState.auth_required + self._task_status_message = event.status.message + elif ( + event.status.state == TaskState.input_required + and self._task_state not in (TaskState.failed, TaskState.auth_required) + ): + self._task_state = TaskState.input_required + self._task_status_message = event.status.message + elif self._task_state == TaskState.working: + # Keep tracking the working message/status + self._task_status_message = event.status.message + + # Ensure the stream always sees "working" state for intermediate updates + # unless it's already terminal in the event flow (which we control via + # generator). This prevents premature terminationby clients listening to the stream. + if not event.final: + event.status.state = TaskState.working + + @property + def task_state(self) -> TaskState: + """Return the current task state.""" + return self._task_state + + @property + def task_status_message(self) -> Message | None: + """Return the current task status message.""" + return self._task_status_message + + # ----------------------------- # Agent Executor Implementation # ----------------------------- -class LightspeedAgentExecutor(AgentExecutor): - """ - Lightspeed Agent Executor for OpenShift Assisted Chat Installer. +class A2AAgentExecutor(AgentExecutor): + """Agent Executor for A2A using Llama Stack Responses API. This executor implements the A2A AgentExecutor interface and handles - routing queries to the appropriate LLM backend. + routing queries to the LLM backend using the Responses API. """ def __init__( self, auth_token: str, mcp_headers: dict[str, dict[str, str]] | None = None ): - """ - Initialize the Lightspeed agent executor. + """Initialize the A2A agent executor. Args: auth_token: Authentication token for the request mcp_headers: MCP headers for context propagation """ - self.auth_token = auth_token - self.mcp_headers = mcp_headers or {} + self.auth_token: str = auth_token + self.mcp_headers: dict[str, dict[str, str]] = mcp_headers or {} async def execute( self, context: RequestContext, event_queue: EventQueue, ) -> None: - """ - Execute the agent with the given context and send results to the event queue. + """Execute the agent with the given context and send results to the event queue. Args: context: The request context containing user input and metadata event_queue: Queue for sending response events """ - # Get or create task - task = await self._prepare_task(context, event_queue) - - # Process the task with streaming - await self._process_task_streaming( - context, event_queue, task.context_id, task.id - ) - - async def _prepare_task( - self, context: RequestContext, event_queue: EventQueue - ) -> Task: - """ - Get existing task or create a new one. - - Args: - context: The request context - event_queue: Queue for sending events - - Returns: - Task object - """ - task = context.current_task - if not task: - if not context.message: - raise ValueError("No message provided in context") + if not context.message: + raise ValueError("A2A request must have a message") + + task_id = context.task_id or "" + context_id = context.context_id or "" + # for new task, create a task submitted event + if not context.current_task: + # Set context_id on message so new_task preserves it + if context_id and context.message: + logger.debug( + "Setting context_id %s on message for A2A contextId %s", + context_id, + context.message.message_id, + ) + context.message.context_id = context_id task = new_task(context.message) await event_queue.enqueue_event(task) - return task + task_id = task.id + context_id = task.context_id + task_updater = TaskUpdater(event_queue, task_id, context_id) - async def _process_task_streaming( # pylint: disable=too-many-locals,too-many-branches,too-many-statements + # Process the task with streaming + try: + await self._process_task_streaming( + context, task_updater, task_id, context_id + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Error handling A2A request: %s", e, exc_info=True) + try: + await task_updater.update_status( + TaskState.failed, + message=new_agent_text_message(str(e)), + final=True, + ) + except Exception as enqueue_error: # pylint: disable=broad-exception-caught + logger.error( + "Failed to publish failure event: %s", enqueue_error, exc_info=True + ) + + async def _process_task_streaming( # pylint: disable=too-many-locals self, context: RequestContext, - event_queue: EventQueue, - context_id: str, + task_updater: TaskUpdater, task_id: str, + context_id: str, ) -> None: - """ - Process the task with streaming updates. + """Process the task with streaming updates using Responses API. Args: context: The request context - event_queue: Queue for sending events - context_id: Context ID for the task - task_id: Task ID + task_updater: Task updater for sending events + task_id: The task ID to use for this execution + context_id: The context ID to use for this execution """ - task_updater = TaskUpdater(event_queue, task_id, context_id) + if not task_id or not context_id: + raise ValueError("Task ID and Context ID are required") - try: - # Extract user input using SDK utility - user_input = context.get_user_input() - if not user_input: - await task_updater.update_status( - TaskState.input_required, - message=new_agent_text_message( - "I didn't receive any input. " - "How can I help you with OpenShift installation?", - context_id=context_id, - task_id=task_id, - ), - final=True, - ) - return + # Extract user input using SDK utility + user_input = context.get_user_input() + if not user_input: + await task_updater.update_status( + TaskState.input_required, + message=new_agent_text_message( + "No input received. Please provide your input.", + context_id=context_id, + task_id=task_id, + ), + final=True, + ) + return + + preview = user_input[:200] + ("..." if len(user_input) > 200 else "") + logger.info("Processing A2A request: %s", preview) + + # Extract routing metadata from context + metadata = context.message.metadata if context.message else {} + model = metadata.get("model") if metadata else None + provider = metadata.get("provider") if metadata else None - preview = user_input[:200] + ("..." if len(user_input) > 200 else "") - logger.info("Processing A2A request: %s", preview) + # Resolve conversation_id from A2A contextId for multi-turn + a2a_context_id = context_id + conversation_id = _CONTEXT_TO_CONVERSATION.get(a2a_context_id) + logger.info( + "A2A contextId %s maps to conversation_id %s", + a2a_context_id, + conversation_id, + ) - # Extract routing metadata from context - metadata = context.message.metadata if context.message else {} - model = metadata.get("model") if metadata else None - provider = metadata.get("provider") if metadata else None + # Build internal query request (conversation_id may be None for first turn) + query_request = QueryRequest( + query=user_input, + conversation_id=conversation_id, + model=model, + provider=provider, + system_prompt=None, + attachments=None, + no_tools=False, + generate_topic_summary=True, + media_type=None, + ) - # Resolve conversation_id from A2A contextId to preserve multi-turn history - a2a_context_id = context_id - conversation_id_hint = _CONTEXT_TO_CONVERSATION.get(a2a_context_id) + # Get LLM client and select model + client = AsyncLlamaStackClientHolder().get_client() + llama_stack_model_id, _model_id, _provider_id = select_model_and_provider_id( + await client.models.list(), + *evaluate_model_hints(user_conversation=None, query_request=query_request), + ) + + # Stream response from LLM using the Responses API + stream, conversation_id = await retrieve_response( + client, + llama_stack_model_id, + query_request, + self.auth_token, + mcp_headers=self.mcp_headers, + ) + + # Persist conversation_id for next turn in same A2A context + if conversation_id: + _CONTEXT_TO_CONVERSATION[a2a_context_id] = conversation_id logger.info( - "A2A contextId %s maps to conversation_id %s", + "Persisted conversation_id %s for A2A contextId %s", + conversation_id, a2a_context_id, - conversation_id_hint, ) - # Build internal query request with conversation_id for history - query_request = QueryRequest( - query=user_input, - conversation_id=conversation_id_hint, - model=model, - provider=provider, + # Initialize result aggregator + aggregator = TaskResultAggregator() + event_queue = task_updater.event_queue + + # Emit working status with metadata before processing stream + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context_id, + final=False, + metadata={ + "model": llama_stack_model_id, + "conversation_id": conversation_id, + }, ) + ) - # Get LLM client and select model - client = AsyncLlamaStackClientHolder().get_client() - llama_stack_model_id, _model_id, _provider_id = ( - select_model_and_provider_id( - await client.models.list(), - *evaluate_model_hints( - user_conversation=None, query_request=query_request - ), - ) - ) + # Process stream using generator and aggregator pattern + async for a2a_event in self._convert_stream_to_events( + stream, task_id, context_id, conversation_id + ): + aggregator.process_event(a2a_event) + await event_queue.enqueue_event(a2a_event) - # Stream response from LLM with status updates - stream, conversation_id = await retrieve_response( - client, - llama_stack_model_id, - query_request, - self.auth_token, - mcp_headers=self.mcp_headers, + # Publish the final task result event + if aggregator.task_state == TaskState.working: + await task_updater.update_status( + TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + final=True, + ) + else: + await task_updater.update_status( + aggregator.task_state, + message=aggregator.task_status_message, + timestamp=datetime.now(timezone.utc).isoformat(), + final=True, ) - # Persist conversationId for next turn in same A2A context - if conversation_id: - _CONTEXT_TO_CONVERSATION[a2a_context_id] = conversation_id - logger.info( - "Persisted conversation_id %s for A2A contextId %s", - conversation_id, - a2a_context_id, - ) + async def _convert_stream_to_events( # pylint: disable=too-many-branches,too-many-locals + self, + stream: AsyncIterator[OpenAIResponseObjectStream], + task_id: str, + context_id: str, + conversation_id: str | None, + ) -> AsyncIterator[Any]: + """Convert Responses API stream chunks to A2A events. - # Stream incremental updates: emit working status with text deltas. - # Terminal conditions: - # - turn_awaiting_input -> TaskState.input_required with accumulated text - # - turn_complete -> TaskState.completed (final), leverage contextId for follow-ups - final_event_sent = False - accumulated_text_chunks: list[str] = [] - streamed_any_delta = False - - artifact_id = str(uuid.uuid4()) - async for chunk in stream: - # Extract text from chunk - llama-stack structure - if hasattr(chunk, "event") and chunk.event is not None: - payload = chunk.event.payload - event_type = payload.event_type - - # Handle turn_awaiting_input - request more input with accumulated text - if event_type == "turn_awaiting_input": - logger.debug("Turn awaiting input") - try: - final_text = ( - "" - if streamed_any_delta - else "".join(accumulated_text_chunks) - ) - await task_updater.update_status( - TaskState.input_required, - message=new_agent_text_message( - final_text, - context_id=context_id, - task_id=task_id, - ), - final=True, - ) - final_event_sent = True - logger.info("Input required for task %s", task_id) - except Exception: # pylint: disable=broad-except - logger.debug( - "Error sending input_required status", exc_info=True - ) - # End the stream for this turn after requesting input - break - - # Handle turn_complete - complete the task for this turn - elif event_type == "turn_complete": - logger.debug("Turn complete event") - try: - final_text = ( - "" - if streamed_any_delta - else "".join(accumulated_text_chunks) - ) - # await task_updater.update_status( - # TaskState.completed, - # message=new_agent_text_message( - # final_text, - # context_id=context_id, - # task_id=task_id, - # ), - # final=True, - # ) - task_metadata = { - "conversation_id": str(conversation_id), - "message_id": str(chunk.event.payload.turn.turn_id), - "sources": None - } - - await task_updater.add_artifact( - parts=[Part(root=TextPart(text=final_text))], - artifact_id=artifact_id, - metadata=task_metadata, - append=streamed_any_delta, - last_chunk=True - ) - await task_updater.complete() - final_event_sent = True - except Exception: # pylint: disable=broad-except - logger.debug( - "Error sending completed on turn_complete", - exc_info=True, - ) - logger.info("Turn completed for task %s", task_id) - # End the stream for this turn - break + Args: + stream: The Responses API response stream + task_id: The task ID for this execution + context_id: The context ID for this execution + conversation_id: The conversation ID for this A2A context - # Handle streaming inference tokens - elif event_type == "step_progress": - if hasattr(payload, "delta") and payload.delta.type == "text": - delta_text = payload.delta.text - if delta_text: - accumulated_text_chunks.append(delta_text) - logger.debug("Step progress, delta test: %s", delta_text) - # await task_updater.update_status( - # TaskState.working, - # message=new_agent_text_message( - # delta_text, - # context_id=context_id, - # task_id=task_id, - # ), - # ) - await task_updater.add_artifact( - parts=[Part(root=TextPart(text=delta_text))], - artifact_id=artifact_id, - metadata=None, - append=streamed_any_delta, - ) - streamed_any_delta = True - - # Ensure exactly one terminal status per turn - if not final_event_sent: - try: - final_text = ( - "" if streamed_any_delta else "".join(accumulated_text_chunks) - ) - # await task_updater.update_status( - # TaskState.completed, - # message=new_agent_text_message( - # final_text, - # context_id=context_id, - # task_id=task_id, - # ), - # final=True, - # ) - await task_updater.add_artifact( - parts=[Part(root=TextPart(text=final_text))], - artifact_id=artifact_id, - metadata=None, - append=streamed_any_delta, - last_chunk=True - ) - await task_updater.complete() - except Exception: # pylint: disable=broad-except - logger.debug( - "Error sending fallback completed status", exc_info=True + Yields: + A2A events (TaskStatusUpdateEvent or TaskArtifactUpdateEvent) + """ + if not task_id or not context_id: + raise ValueError("Task ID and Context ID are required") + + artifact_id = str(uuid.uuid4()) + text_parts: list[str] = [] + + async for chunk in stream: + event_type = getattr(chunk, "type", None) + + # Skip response.created - conversation is already created + if event_type == "response.created": + continue + + # Text streaming - emit as working status with text delta + if event_type == "response.output_text.delta": + delta = getattr(chunk, "delta", "") + if delta: + text_parts.append(delta) + yield TaskStatusUpdateEvent( + task_id=task_id, + status=TaskStatus( + state=TaskState.working, + message=new_agent_text_message( + delta, + context_id=context_id, + task_id=task_id, + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context_id, + final=False, ) - except Exception as exc: # pylint: disable=broad-except - logger.error("Error executing agent: %s", str(exc), exc_info=True) - await task_updater.update_status( - TaskState.failed, - message=new_agent_text_message( - f"Sorry, I encountered an error: {str(exc)}", + # Tool call events + elif event_type == "response.function_call_arguments.done": + item_id = getattr(chunk, "item_id", "") + yield TaskStatusUpdateEvent( + task_id=task_id, + status=TaskStatus( + state=TaskState.working, + message=new_agent_text_message( + f"Tool call: {item_id}", + context_id=context_id, + task_id=task_id, + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), context_id=context_id, + final=False, + ) + + # MCP call completion + elif event_type == "response.mcp_call.arguments.done": + item_id = getattr(chunk, "item_id", "") + yield TaskStatusUpdateEvent( task_id=task_id, - ), - final=True, - ) + status=TaskStatus( + state=TaskState.working, + message=new_agent_text_message( + f"MCP call: {item_id}", + context_id=context_id, + task_id=task_id, + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context_id, + final=False, + ) + + # Response completed - emit final artifact + elif event_type == "response.completed": + response_obj = getattr(chunk, "response", None) + final_text = "".join(text_parts) + + if response_obj: + output = getattr(response_obj, "output", []) + a2a_parts = _convert_responses_content_to_a2a_parts(output) + if not a2a_parts and final_text: + a2a_parts = [Part(root=TextPart(text=final_text))] + else: + a2a_parts = ( + [Part(root=TextPart(text=final_text))] if final_text else [] + ) + + yield TaskArtifactUpdateEvent( + task_id=task_id, + last_chunk=True, + context_id=context_id, + artifact=Artifact( + artifact_id=artifact_id, + parts=a2a_parts, + metadata={"conversation_id": str(conversation_id or "")}, + ), + ) async def cancel( self, context: RequestContext, # pylint: disable=unused-argument event_queue: EventQueue, # pylint: disable=unused-argument ) -> None: - """ - Handle task cancellation. + """Handle task cancellation. Args: context: The request context @@ -394,155 +495,71 @@ def get_lightspeed_agent_card() -> AgentCard: """ # Get base URL from configuration or construct it service_config = configuration.service_configuration - base_url = service_config.base_url if service_config.base_url is not None else "http://localhost:8080" - - # Check if agent card is configured via file - if ( - configuration.customization is not None - and configuration.customization.agent_card_config is not None - ): - config = configuration.customization.agent_card_config - - # Parse skills from config - skills = [ - AgentSkill( - id=skill.get("id"), - name=skill.get("name"), - description=skill.get("description"), - tags=skill.get("tags", []), - input_modes=skill.get("inputModes", []), - output_modes=skill.get("outputModes", []), - examples=skill.get("examples", []), - ) - for skill in config.get("skills", []) - ] - - # Parse provider from config - provider_config = config.get("provider", {}) - provider = AgentProvider( - organization=provider_config.get("organization", ""), - url=provider_config.get("url", ""), - ) + base_url = ( + service_config.base_url + if service_config.base_url is not None + else "http://localhost:8080" + ) - # Parse capabilities from config - capabilities_config = config.get("capabilities", {}) - capabilities = AgentCapabilities( - streaming=capabilities_config.get("streaming", True), - push_notifications=capabilities_config.get("pushNotifications", False), - state_transition_history=capabilities_config.get( - "stateTransitionHistory", False - ), + if not configuration.customization: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Customization configuration not found", ) - return AgentCard( - name=config.get("name", "Lightspeed AI Assistant"), - description=config.get("description", ""), - version=__version__, - url=f"{base_url}/a2a", - documentation_url=f"{base_url}/docs", - provider=provider, - skills=skills, - default_input_modes=config.get("defaultInputModes", ["text/plain"]), - default_output_modes=config.get("defaultOutputModes", ["text/plain"]), - capabilities=capabilities, - protocol_version="0.2.1", - security=config.get("security", [{"bearer": []}]), - security_schemes=config.get("security_schemes", {}), + if not configuration.customization.agent_card_config: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Agent card configuration not found", ) - # Fallback to default hardcoded agent card - logger.info("Using default hardcoded agent card (no agent_card_path configured)") + config = configuration.customization.agent_card_config - # Define Lightspeed's skills for OpenShift cluster installation + # Parse skills from config skills = [ AgentSkill( - id="cluster_installation_guidance", - name="Cluster Installation Guidance", - description=( - "Provide guidance and assistance for OpenShift cluster " - "installation using assisted-installer" - ), - tags=["openshift", "installation", "assisted-installer"], - input_modes=["text/plain", "application/json"], - output_modes=["text/plain", "application/json"], - examples=[ - "How do I install OpenShift using assisted-installer?", - "What are the prerequisites for OpenShift installation?", - ], - ), - AgentSkill( - id="cluster_configuration_validation", - name="Cluster Configuration Validation", - description=( - "Validate and provide recommendations for OpenShift " - "cluster configuration parameters" - ), - tags=["openshift", "configuration", "validation"], - input_modes=["application/json", "text/plain"], - output_modes=["application/json", "text/plain"], - examples=[ - "Validate my cluster configuration", - "Check if my OpenShift setup meets requirements", - ], - ), - AgentSkill( - id="installation_troubleshooting", - name="Installation Troubleshooting", - description=( - "Help troubleshoot OpenShift cluster installation issues " - "and provide solutions" - ), - tags=["openshift", "troubleshooting", "support"], - input_modes=["text/plain", "application/json"], - output_modes=["text/plain", "application/json"], - examples=[ - "My cluster installation is failing", - "How do I fix installation errors?", - ], - ), - AgentSkill( - id="cluster_requirements_analysis", - name="Cluster Requirements Analysis", - description=( - "Analyze infrastructure requirements for " - "OpenShift cluster deployment" - ), - tags=["openshift", "requirements", "planning"], - input_modes=["application/json", "text/plain"], - output_modes=["application/json", "text/plain"], - examples=[ - "What hardware do I need for OpenShift?", - "Analyze requirements for a 5-node cluster", - ], - ), + id=skill.get("id"), + name=skill.get("name"), + description=skill.get("description"), + tags=skill.get("tags", []), + input_modes=skill.get("inputModes", []), + output_modes=skill.get("outputModes", []), + examples=skill.get("examples", []), + ) + for skill in config.get("skills", []) ] - # Provider information - provider = AgentProvider(organization="Red Hat", url="https://redhat.com") + # Parse provider from config + provider_config = config.get("provider", {}) + provider = AgentProvider( + organization=provider_config.get("organization", ""), + url=provider_config.get("url", ""), + ) - # Agent capabilities + # Parse capabilities from config + capabilities_config = config.get("capabilities", {}) capabilities = AgentCapabilities( - streaming=True, push_notifications=False, state_transition_history=False + streaming=capabilities_config.get("streaming", True), + push_notifications=capabilities_config.get("pushNotifications", False), + state_transition_history=capabilities_config.get( + "stateTransitionHistory", False + ), ) return AgentCard( - name="OpenShift Assisted Installer AI Assistant", - description=( - "AI-powered assistant specialized in OpenShift cluster " - "installation, configuration, and troubleshooting using " - "assisted-installer backend" - ), + name=config.get("name", "Lightspeed AI Assistant"), + description=config.get("description", ""), version=__version__, url=f"{base_url}/a2a", documentation_url=f"{base_url}/docs", provider=provider, skills=skills, - default_input_modes=["text/plain"], - default_output_modes=["text/plain"], + default_input_modes=config.get("defaultInputModes", ["text/plain"]), + default_output_modes=config.get("defaultOutputModes", ["text/plain"]), capabilities=capabilities, protocol_version="0.2.1", - security=[{"bearer": []}], - security_schemes={}, + security=config.get("security", [{"bearer": []}]), + security_schemes=config.get("security_schemes", {}), ) @@ -577,8 +594,7 @@ async def get_agent_card( # pylint: disable=unused-argument def _create_a2a_app(auth_token: str, mcp_headers: dict[str, dict[str, str]]) -> Any: - """ - Create an A2A Starlette application instance with auth context. + """Create an A2A Starlette application instance with auth context. Args: auth_token: Authentication token for the request @@ -587,9 +603,7 @@ def _create_a2a_app(auth_token: str, mcp_headers: dict[str, dict[str, str]]) -> Returns: A2A Starlette ASGI application """ - agent_executor = LightspeedAgentExecutor( - auth_token=auth_token, mcp_headers=mcp_headers - ) + agent_executor = A2AAgentExecutor(auth_token=auth_token, mcp_headers=mcp_headers) request_handler = DefaultRequestHandler( agent_executor=agent_executor, @@ -612,7 +626,7 @@ async def handle_a2a_jsonrpc( # pylint: disable=too-many-locals,too-many-statem mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), ) -> Response | StreamingResponse: """ - Main A2A JSON-RPC endpoint following the A2A protocol specification. + Handle A2A JSON-RPC requests following the A2A protocol specification. This endpoint uses the DefaultRequestHandler from the A2A SDK to handle all JSON-RPC requests including message/send, message/stream, etc. @@ -670,13 +684,13 @@ async def handle_a2a_jsonrpc( # pylint: disable=too-many-locals,too-many-statem logger.error("Error detecting streaming request: %s", str(e)) # Setup scope for A2A app - scope = request.scope.copy() + scope = dict(request.scope) scope["path"] = "/" # A2A app expects root path # We need to re-provide the body since we already read it body_sent = False - async def receive(): + async def receive() -> MutableMapping[str, Any]: nonlocal body_sent if not body_sent: body_sent = True @@ -720,8 +734,8 @@ async def run_a2a_app() -> None: # Start the A2A app task app_task = asyncio.create_task(run_a2a_app()) - async def response_generator() -> Any: - """Generator that yields chunks from the queue.""" + async def response_generator() -> AsyncIterator[bytes]: + """Generate chunks from the queue for streaming response.""" chunk_count = 0 try: while True: diff --git a/src/app/routers.py b/src/app/routers.py index 77ebde10a..4a3f06782 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -56,5 +56,5 @@ def include_routers(app: FastAPI) -> None: app.include_router(authorized.router) app.include_router(metrics.router) - # A2A (Agent-to-Agent) protocol endpoints + # A2A (Agent-to-Agent) protocol endpoint app.include_router(a2a.router) diff --git a/tests/unit/app/endpoints/test_a2a.py b/tests/unit/app/endpoints/test_a2a.py new file mode 100644 index 000000000..1e3199782 --- /dev/null +++ b/tests/unit/app/endpoints/test_a2a.py @@ -0,0 +1,659 @@ +"""Unit tests for the A2A (Agent-to-Agent) protocol endpoints.""" + +# pylint: disable=redefined-outer-name +# pylint: disable=protected-access + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import HTTPException, Request +from pytest_mock import MockerFixture + +from a2a.types import ( + AgentCard, + Artifact, + Part, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, +) +from a2a.server.agent_execution import RequestContext +from a2a.server.events import EventQueue +from a2a.utils import new_agent_text_message + +from app.endpoints.a2a import ( + _convert_responses_content_to_a2a_parts, + get_lightspeed_agent_card, + A2AAgentExecutor, + TaskResultAggregator, + _CONTEXT_TO_CONVERSATION, + _TASK_STORE, + a2a_health_check, + get_agent_card, +) +from configuration import AppConfig +from models.config import Action + + +# User ID must be proper UUID +MOCK_AUTH = ( + "00000001-0001-0001-0001-000000000001", + "mock_username", + False, + "mock_token", +) + + +@pytest.fixture +def dummy_request() -> Request: + """Dummy request fixture for testing.""" + req = Request( + scope={ + "type": "http", + } + ) + req.state.authorized_actions = set(Action) + return req + + +@pytest.fixture(name="setup_configuration") +def setup_configuration_fixture(mocker: MockerFixture) -> AppConfig: + """Set up configuration for tests.""" + config_dict: dict[Any, Any] = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "base_url": "http://localhost:8080", + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {}, + "mcp_servers": [], + "customization": { + "agent_card_config": { + "name": "Test Agent", + "description": "A test agent", + "provider": { + "organization": "Test Org", + "url": "https://test.org", + }, + "skills": [ + { + "id": "test-skill", + "name": "Test Skill", + "description": "A test skill", + "tags": ["test"], + "inputModes": ["text/plain"], + "outputModes": ["text/plain"], + } + ], + "capabilities": { + "streaming": True, + "pushNotifications": False, + "stateTransitionHistory": False, + }, + } + }, + "authentication": {"module": "noop"}, + "authorization": {"access_rules": []}, + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + mocker.patch("app.endpoints.a2a.configuration", cfg) + return cfg + + +@pytest.fixture(name="setup_minimal_configuration") +def setup_minimal_configuration_fixture(mocker: MockerFixture) -> AppConfig: + """Set up minimal configuration without agent_card_config.""" + config_dict: dict[Any, Any] = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {}, + "mcp_servers": [], + "customization": {}, # Empty customization, no agent_card_config + "authentication": {"module": "noop"}, + "authorization": {"access_rules": []}, + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + mocker.patch("app.endpoints.a2a.configuration", cfg) + return cfg + + +# ----------------------------- +# Tests for _convert_responses_content_to_a2a_parts +# ----------------------------- +class TestConvertResponsesContentToA2AParts: + """Tests for the content conversion function.""" + + def test_convert_empty_output(self, mocker: MockerFixture) -> None: + """Test converting empty output returns empty list.""" + mocker.patch( + "app.endpoints.a2a.extract_text_from_response_output_item", + return_value=None, + ) + result = _convert_responses_content_to_a2a_parts([]) + assert not result + + def test_convert_single_output_item(self, mocker: MockerFixture) -> None: + """Test converting single output item with text.""" + mocker.patch( + "app.endpoints.a2a.extract_text_from_response_output_item", + return_value="Hello, world!", + ) + mock_output_item = MagicMock() + result = _convert_responses_content_to_a2a_parts([mock_output_item]) + assert len(result) == 1 + assert result[0].root.text == "Hello, world!" + + def test_convert_multiple_output_items(self, mocker: MockerFixture) -> None: + """Test converting multiple output items.""" + extract_mock = mocker.patch( + "app.endpoints.a2a.extract_text_from_response_output_item", + ) + extract_mock.side_effect = ["First", "Second"] + + mock_item1 = MagicMock() + mock_item2 = MagicMock() + + result = _convert_responses_content_to_a2a_parts([mock_item1, mock_item2]) + assert len(result) == 2 + assert result[0].root.text == "First" + assert result[1].root.text == "Second" + + def test_convert_output_items_with_none_text(self, mocker: MockerFixture) -> None: + """Test that output items with no text are filtered out.""" + extract_mock = mocker.patch( + "app.endpoints.a2a.extract_text_from_response_output_item", + ) + extract_mock.side_effect = ["Valid text", None, "Another valid"] + + mock_items = [MagicMock(), MagicMock(), MagicMock()] + + result = _convert_responses_content_to_a2a_parts(mock_items) + assert len(result) == 2 + assert result[0].root.text == "Valid text" + assert result[1].root.text == "Another valid" + + +# ----------------------------- +# Tests for TaskResultAggregator +# ----------------------------- +class TestTaskResultAggregator: + """Tests for the TaskResultAggregator class.""" + + def test_initial_state_is_working(self) -> None: + """Test that initial state is working.""" + aggregator = TaskResultAggregator() + assert aggregator.task_state == TaskState.working + assert aggregator.task_status_message is None + + def test_process_working_event(self) -> None: + """Test processing a working status event.""" + aggregator = TaskResultAggregator() + message = new_agent_text_message("Processing...") + event = TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.working, message=message), + final=False, + ) + + aggregator.process_event(event) + + assert aggregator.task_state == TaskState.working + assert aggregator.task_status_message == message + + def test_process_failed_event_takes_priority(self) -> None: + """Test that failed state takes priority.""" + aggregator = TaskResultAggregator() + + # First, set to input_required + event1 = TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.input_required), + final=False, + ) + aggregator.process_event(event1) + + # Then set to failed + failed_message = new_agent_text_message("Error occurred") + event2 = TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.failed, message=failed_message), + final=True, + ) + aggregator.process_event(event2) + + assert aggregator.task_state == TaskState.failed + assert aggregator.task_status_message == failed_message + + def test_process_auth_required_event(self) -> None: + """Test processing auth_required status event.""" + aggregator = TaskResultAggregator() + + event = TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.auth_required), + final=False, + ) + aggregator.process_event(event) + + assert aggregator.task_state == TaskState.auth_required + + def test_process_input_required_event(self) -> None: + """Test processing input_required status event.""" + aggregator = TaskResultAggregator() + + event = TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.input_required), + final=False, + ) + aggregator.process_event(event) + + assert aggregator.task_state == TaskState.input_required + + def test_failed_cannot_be_overridden(self) -> None: + """Test that failed state cannot be overridden by other states.""" + aggregator = TaskResultAggregator() + + # Set to failed first + event1 = TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.failed), + final=False, + ) + aggregator.process_event(event1) + + # Try to set to working + event2 = TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.working), + final=False, + ) + aggregator.process_event(event2) + + # Failed should still be the state + assert aggregator.task_state == TaskState.failed + + def test_non_final_events_show_working(self) -> None: + """Test that non-final events are set to working state.""" + aggregator = TaskResultAggregator() + + event = TaskStatusUpdateEvent( + task_id="task-1", + context_id="ctx-1", + status=TaskStatus(state=TaskState.input_required), + final=False, + ) + aggregator.process_event(event) + + # The event's state should be changed to working for streaming + assert event.status.state == TaskState.working + + def test_ignores_non_status_events(self) -> None: + """Test that non-status events are ignored.""" + aggregator = TaskResultAggregator() + + # Process an artifact event + artifact_event = TaskArtifactUpdateEvent( + task_id="task-1", + context_id="ctx-1", + artifact=Artifact( + artifact_id="art-1", + parts=[Part(root=TextPart(text="Result"))], + ), + last_chunk=True, + ) + aggregator.process_event(artifact_event) + + # State should still be working + assert aggregator.task_state == TaskState.working + + +# ----------------------------- +# Tests for get_lightspeed_agent_card +# ----------------------------- +class TestGetLightspeedAgentCard: + """Tests for the agent card generation.""" + + def test_get_agent_card_with_config( + self, setup_configuration: AppConfig # pylint: disable=unused-argument + ) -> None: + """Test getting agent card with full configuration.""" + agent_card = get_lightspeed_agent_card() + + assert agent_card.name == "Test Agent" + assert agent_card.description == "A test agent" + assert agent_card.url == "http://localhost:8080/a2a" + assert agent_card.protocol_version == "0.2.1" + + # Check provider + assert agent_card.provider is not None + assert agent_card.provider.organization == "Test Org" + + # Check skills + assert len(agent_card.skills) == 1 + assert agent_card.skills[0].id == "test-skill" + assert agent_card.skills[0].name == "Test Skill" + + # Check capabilities + assert agent_card.capabilities is not None + assert agent_card.capabilities.streaming is True + + def test_get_agent_card_without_config_raises_error( + self, + setup_minimal_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test that getting agent card without config raises HTTPException.""" + with pytest.raises(HTTPException) as exc_info: + get_lightspeed_agent_card() + assert exc_info.value.status_code == 500 + assert "Agent card configuration not found" in exc_info.value.detail + + +# ----------------------------- +# Tests for A2AAgentExecutor +# ----------------------------- +class TestA2AAgentExecutor: + """Tests for the A2AAgentExecutor class.""" + + def test_executor_initialization(self) -> None: + """Test executor initialization.""" + executor = A2AAgentExecutor( + auth_token="test-token", + mcp_headers={"server1": {"header1": "value1"}}, + ) + + assert executor.auth_token == "test-token" + assert executor.mcp_headers == {"server1": {"header1": "value1"}} + + def test_executor_initialization_default_mcp_headers(self) -> None: + """Test executor initialization with default mcp_headers.""" + executor = A2AAgentExecutor(auth_token="test-token") + + assert executor.auth_token == "test-token" + assert executor.mcp_headers == {} + + @pytest.mark.asyncio + async def test_execute_without_message_raises_error(self) -> None: + """Test that execute raises error when message is missing.""" + executor = A2AAgentExecutor(auth_token="test-token") + + context = MagicMock(spec=RequestContext) + context.message = None + + event_queue = AsyncMock(spec=EventQueue) + + with pytest.raises(ValueError, match="A2A request must have a message"): + await executor.execute(context, event_queue) + + @pytest.mark.asyncio + async def test_execute_creates_new_task( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test that execute creates a new task when current_task is None.""" + executor = A2AAgentExecutor(auth_token="test-token") + + # Mock the context with a mock message + mock_message = MagicMock() + mock_message.role = "user" + mock_message.parts = [Part(root=TextPart(text="Hello"))] + mock_message.metadata = {} + + context = MagicMock(spec=RequestContext) + context.message = mock_message + context.current_task = None + context.task_id = None + context.context_id = None + context.get_user_input.return_value = "Hello" + + # Mock event queue + event_queue = AsyncMock(spec=EventQueue) + + # Mock new_task to return a mock Task + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mocker.patch("app.endpoints.a2a.new_task", return_value=mock_task) + + # Mock the streaming process to avoid actual LLM calls + mocker.patch.object( + executor, + "_process_task_streaming", + new_callable=AsyncMock, + ) + + await executor.execute(context, event_queue) + + # Verify a task was created and enqueued + assert event_queue.enqueue_event.called + + @pytest.mark.asyncio + async def test_execute_passes_task_ids_to_streaming( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test that execute passes computed task_id and context_id to _process_task_streaming. + + This test verifies the fix for the issue where task_id and context_id + were computed locally in execute() but not stored in the context object, + causing _process_task_streaming to fail when trying to read them from context. + """ + executor = A2AAgentExecutor(auth_token="test-token") + + # Mock the context with empty task_id and context_id (first-turn scenario) + mock_message = MagicMock() + mock_message.role = "user" + mock_message.parts = [Part(root=TextPart(text="Hello"))] + mock_message.metadata = {} + + context = MagicMock(spec=RequestContext) + context.message = mock_message + context.current_task = None + context.task_id = None # Empty in context object + context.context_id = None # Empty in context object + context.get_user_input.return_value = "Hello" + + # Mock event queue + event_queue = AsyncMock(spec=EventQueue) + + # Mock new_task to return a task with specific IDs + mock_task = MagicMock() + mock_task.id = "computed-task-id-123" + mock_task.context_id = "computed-context-id-456" + mocker.patch("app.endpoints.a2a.new_task", return_value=mock_task) + + # Mock the streaming process + mock_process_streaming = mocker.patch.object( + executor, + "_process_task_streaming", + new_callable=AsyncMock, + ) + + await executor.execute(context, event_queue) + + # Verify _process_task_streaming was called with the computed IDs + # NOT the None values from context + mock_process_streaming.assert_called_once() + call_args = mock_process_streaming.call_args + + # Check positional arguments: context, task_updater, task_id, context_id + assert call_args[0][0] == context # First arg is context + # Third and fourth args should be the computed IDs + assert call_args[0][2] == "computed-task-id-123" # task_id + assert call_args[0][3] == "computed-context-id-456" # context_id + + @pytest.mark.asyncio + async def test_execute_handles_errors_gracefully( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test that execute handles errors and sends failure event.""" + executor = A2AAgentExecutor(auth_token="test-token") + + # Mock the context with a mock message + mock_message = MagicMock() + mock_message.role = "user" + mock_message.parts = [Part(root=TextPart(text="Hello"))] + mock_message.metadata = {} + + context = MagicMock(spec=RequestContext) + context.message = mock_message + context.current_task = MagicMock() + context.task_id = "task-123" + context.context_id = "ctx-456" + context.get_user_input.return_value = "Hello" + + # Mock event queue + event_queue = AsyncMock(spec=EventQueue) + + # Mock the streaming process to raise an error + mocker.patch.object( + executor, + "_process_task_streaming", + side_effect=Exception("Test error"), + ) + + await executor.execute(context, event_queue) + + # Verify failure event was enqueued + calls = event_queue.enqueue_event.call_args_list + # Find the failure status update + failure_sent = False + for call in calls: + event = call[0][0] + if isinstance(event, TaskStatusUpdateEvent): + if event.status.state == TaskState.failed: + failure_sent = True + break + assert failure_sent + + @pytest.mark.asyncio + async def test_process_task_streaming_no_input( + self, + mocker: MockerFixture, # pylint: disable=unused-argument + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test _process_task_streaming when no input is provided.""" + executor = A2AAgentExecutor(auth_token="test-token") + + # Mock the context with no input + mock_message = MagicMock() + mock_message.role = "user" + mock_message.parts = [] + mock_message.metadata = {} + + context = MagicMock(spec=RequestContext) + context.task_id = "task-123" + context.context_id = "ctx-456" + context.message = mock_message + context.get_user_input.return_value = "" + + # Mock event queue + event_queue = AsyncMock(spec=EventQueue) + + # Create task updater mock + task_updater = MagicMock() + task_updater.update_status = AsyncMock() + task_updater.event_queue = event_queue + + await executor._process_task_streaming( + context, task_updater, context.task_id, context.context_id + ) + + # Verify input_required status was sent + task_updater.update_status.assert_called_once() + call_args = task_updater.update_status.call_args + assert call_args[0][0] == TaskState.input_required + + @pytest.mark.asyncio + async def test_cancel_raises_not_implemented(self) -> None: + """Test that cancel raises NotImplementedError.""" + executor = A2AAgentExecutor(auth_token="test-token") + + context = MagicMock(spec=RequestContext) + event_queue = AsyncMock(spec=EventQueue) + + with pytest.raises(NotImplementedError): + await executor.cancel(context, event_queue) + + +# ----------------------------- +# Tests for context to conversation mapping +# ----------------------------- +class TestContextToConversationMapping: + """Tests for the context to conversation ID mapping.""" + + def test_context_to_conversation_is_dict(self) -> None: + """Test that _CONTEXT_TO_CONVERSATION is a dict.""" + assert isinstance(_CONTEXT_TO_CONVERSATION, dict) + + def test_task_store_exists(self) -> None: + """Test that _TASK_STORE exists.""" + assert _TASK_STORE is not None + + +# ----------------------------- +# Integration-style tests for endpoint handlers +# ----------------------------- +class TestA2AEndpointHandlers: + """Tests for A2A endpoint handler functions.""" + + @pytest.mark.asyncio + async def test_a2a_health_check(self) -> None: + """Test the health check endpoint.""" + result = await a2a_health_check() + + assert result["status"] == "healthy" + assert result["service"] == "lightspeed-a2a" + assert "version" in result + assert "a2a_sdk_version" in result + assert "timestamp" in result + + @pytest.mark.asyncio + async def test_get_agent_card_endpoint( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test the agent card endpoint.""" + # Mock authorization + mocker.patch( + "app.endpoints.a2a.authorize", + lambda action: lambda f: f, + ) + + result = await get_agent_card(auth=MOCK_AUTH) + + assert isinstance(result, AgentCard) + assert result.name == "Test Agent" + assert result.url == "http://localhost:8080/a2a" diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index e6ce9bced..53a7d8b69 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -589,6 +589,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None: "service": { "host": "localhost", "port": 8080, + "base_url": None, "auth_enabled": False, "workers": 1, "color_log": True, diff --git a/uv.lock b/uv.lock index 73ceafbd2..5d6701912 100644 --- a/uv.lock +++ b/uv.lock @@ -8,6 +8,22 @@ resolution-markers = [ "python_full_version < '3.13' and sys_platform == 'darwin'", ] +[[package]] +name = "a2a-sdk" +version = "0.3.19" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "protobuf" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/74/db61ee9d2663b291a7eec03bbc7685bec72b1ceb113001350766c03f20de/a2a_sdk-0.3.19.tar.gz", hash = "sha256:ecf526d1d7781228d8680292f913bad1099ba3335a7f0ea6811543c2bd3e601d", size = 229184, upload-time = "2025-11-25T13:48:05.185Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cd/cd/14c1242d171b9739770be35223f1cbc1fb0244ebea2c704f8ae0d9e6abf7/a2a_sdk-0.3.19-py3-none-any.whl", hash = "sha256:314123f84524259313ec0cd9826a34bae5de769dea44b8eb9a0eca79b8935772", size = 141519, upload-time = "2025-11-25T13:48:02.622Z" }, +] + [[package]] name = "accelerate" version = "1.12.0" @@ -859,6 +875,22 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "google-api-core" +version = "2.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/da/83d7043169ac2c8c7469f0e375610d78ae2160134bf1b80634c482fa079c/google_api_core-2.28.1.tar.gz", hash = "sha256:2b405df02d68e68ce0fbc138559e6036559e685159d148ae5861013dc201baf8", size = 176759, upload-time = "2025-10-28T21:34:51.529Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/d4/90197b416cb61cefd316964fd9e7bd8324bcbafabf40eef14a9f20b81974/google_api_core-2.28.1-py3-none-any.whl", hash = "sha256:4021b0f8ceb77a6fb4de6fde4502cecab45062e66ff4f2895169e0b35bc9466c", size = 173706, upload-time = "2025-10-28T21:34:50.151Z" }, +] + [[package]] name = "google-auth" version = "2.43.0" @@ -1346,6 +1378,7 @@ wheels = [ name = "lightspeed-stack" source = { editable = "." } dependencies = [ + { name = "a2a-sdk" }, { name = "aiohttp" }, { name = "authlib" }, { name = "cachetools" }, @@ -1359,6 +1392,7 @@ dependencies = [ { name = "openai" }, { name = "prometheus-client" }, { name = "psycopg2-binary" }, + { name = "pyyaml" }, { name = "rich" }, { name = "semver" }, { name = "sqlalchemy" }, @@ -1428,6 +1462,7 @@ llslibdev = [ [package.metadata] requires-dist = [ + { name = "a2a-sdk", specifier = ">=0.3.4,<0.4.0" }, { name = "aiohttp", specifier = ">=3.12.14" }, { name = "authlib", specifier = ">=1.6.0" }, { name = "cachetools", specifier = ">=6.1.0" }, @@ -1441,6 +1476,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.99.9" }, { name = "prometheus-client", specifier = ">=0.22.1" }, { name = "psycopg2-binary", specifier = ">=2.9.10" }, + { name = "pyyaml", specifier = ">=6.0.0" }, { name = "rich", specifier = ">=14.0.0" }, { name = "semver", specifier = "<4.0.0" }, { name = "sqlalchemy", specifier = ">=2.0.42" }, @@ -2435,6 +2471,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/5a/bc7b4a4ef808fa59a816c17b20c4bef6884daebbdf627ff2a161da67da19/propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237", size = 13305, upload-time = "2025-10-08T19:49:00.792Z" }, ] +[[package]] +name = "proto-plus" +version = "1.26.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142, upload-time = "2025-03-10T15:54:38.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163, upload-time = "2025-03-10T15:54:37.335Z" }, +] + [[package]] name = "protobuf" version = "6.33.2" From 61a4ac7f3bfd0ec4afebc80ad8b82f2b139e126d Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Wed, 10 Dec 2025 11:06:05 +0100 Subject: [PATCH 4/6] Add support for restart and multiworker deployment Add options for task and context IDs to not only be stored in memory --- CLAUDE.md | 7 + docs/a2a_protocol.md | 122 +++++++++++- examples/lightspeed-stack-a2a-state-pg.yaml | 30 +++ .../lightspeed-stack-a2a-state-sqlite.yaml | 24 +++ pyproject.toml | 3 + src/a2a_storage/__init__.py | 23 +++ src/a2a_storage/context_store.py | 57 ++++++ src/a2a_storage/in_memory_context_store.py | 92 +++++++++ src/a2a_storage/postgres_context_store.py | 142 ++++++++++++++ src/a2a_storage/sqlite_context_store.py | 143 ++++++++++++++ src/a2a_storage/storage_factory.py | 184 ++++++++++++++++++ src/app/endpoints/a2a.py | 52 +++-- src/configuration.py | 8 + src/models/config.py | 64 ++++++ tests/unit/a2a_storage/__init__.py | 1 + .../test_in_memory_context_store.py | 95 +++++++++ .../a2a_storage/test_sqlite_context_store.py | 147 ++++++++++++++ .../unit/a2a_storage/test_storage_factory.py | 172 ++++++++++++++++ tests/unit/app/endpoints/test_a2a.py | 47 ++++- tests/unit/app/test_routers.py | 4 +- .../config/test_a2a_state_configuration.py | 103 ++++++++++ .../models/config/test_dump_configuration.py | 12 ++ uv.lock | 4 + 23 files changed, 1511 insertions(+), 25 deletions(-) create mode 100644 examples/lightspeed-stack-a2a-state-pg.yaml create mode 100644 examples/lightspeed-stack-a2a-state-sqlite.yaml create mode 100644 src/a2a_storage/__init__.py create mode 100644 src/a2a_storage/context_store.py create mode 100644 src/a2a_storage/in_memory_context_store.py create mode 100644 src/a2a_storage/postgres_context_store.py create mode 100644 src/a2a_storage/sqlite_context_store.py create mode 100644 src/a2a_storage/storage_factory.py create mode 100644 tests/unit/a2a_storage/__init__.py create mode 100644 tests/unit/a2a_storage/test_in_memory_context_store.py create mode 100644 tests/unit/a2a_storage/test_sqlite_context_store.py create mode 100644 tests/unit/a2a_storage/test_storage_factory.py create mode 100644 tests/unit/models/config/test_a2a_state_configuration.py diff --git a/CLAUDE.md b/CLAUDE.md index 1b5cf436f..d2e28e348 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -18,8 +18,15 @@ src/ ├── app/ # FastAPI application │ ├── endpoints/ # REST API endpoints │ └── main.py # Application entry point +├── a2a_storage/ # A2A protocol persistent storage +│ ├── context_store.py # Abstract base class for context stores +│ ├── in_memory_context_store.py # In-memory implementation +│ ├── sqlite_context_store.py # SQLite implementation +│ ├── postgres_context_store.py # PostgreSQL implementation +│ └── storage_factory.py # Factory for creating stores ├── auth/ # Authentication modules (k8s, jwk, noop) ├── authorization/ # Authorization middleware & resolvers +├── cache/ # Conversation cache implementations ├── models/ # Pydantic models │ ├── config.py # Configuration classes │ ├── requests.py # Request models diff --git a/docs/a2a_protocol.md b/docs/a2a_protocol.md index c45689ff9..5ac36d130 100644 --- a/docs/a2a_protocol.md +++ b/docs/a2a_protocol.md @@ -15,7 +15,7 @@ The A2A protocol is an open standard for agent-to-agent communication that allow ## Architecture -``` +```text ┌─────────────────────────────────────────────────────────────────┐ │ A2A Client │ │ (A2A Inspector, Other Agents) │ @@ -69,7 +69,59 @@ The A2A protocol is an open standard for agent-to-agent communication that allow ### Agent Card Configuration -The agent card is configured via the `customization.agent_card_config` section in your configuration file: +The agent card can be configured in two ways: + +#### Option 1: External YAML File (Recommended) + +Reference an external agent card configuration file using `customization.agent_card_path`: + +```yaml +customization: + agent_card_path: agent_card.yaml +``` + +Create a separate `agent_card.yaml` file with the agent card configuration: + +```yaml +# agent_card.yaml +name: "Lightspeed AI Assistant" +description: "An AI assistant for OpenShift and Kubernetes" +provider: + organization: "Red Hat" + url: "https://redhat.com" +skills: + - id: "openshift-qa" + name: "OpenShift Q&A" + description: "Answer questions about OpenShift and Kubernetes" + tags: ["openshift", "kubernetes", "containers"] + inputModes: ["text/plain"] + outputModes: ["text/plain"] + examples: + - "How do I create a deployment in OpenShift?" + - "What is a pod in Kubernetes?" + - id: "troubleshooting" + name: "Troubleshooting" + description: "Help diagnose and fix issues with OpenShift clusters" + tags: ["troubleshooting", "debugging", "support"] + inputModes: ["text/plain"] + outputModes: ["text/plain"] +capabilities: + streaming: true + pushNotifications: false + stateTransitionHistory: false +defaultInputModes: ["text/plain"] +defaultOutputModes: ["text/plain"] +security: + - bearer: [] +security_schemes: + bearer: + type: http + scheme: bearer +``` + +#### Option 2: Inline Configuration + +Alternatively, configure the agent card directly in the main configuration file via `customization.agent_card_config`: ```yaml customization: @@ -143,6 +195,66 @@ authorization: - A2A_JSONRPC ``` +### Persistent State Storage (Multi-Worker Deployments) + +By default, A2A state (task store and context-to-conversation mappings) is stored in memory. This works well for single-worker deployments but causes issues in multi-worker deployments where: + +- Subsequent requests may hit different workers +- Task state and conversation history are lost between workers +- State is lost on service restarts + +For production multi-worker deployments, configure persistent storage using the `a2a_state` section: + +#### In-Memory Storage (Default) + +```yaml +a2a_state: {} +``` + +This is the default. Suitable for single-worker deployments or development. + +#### SQLite Storage + +```yaml +a2a_state: + sqlite: + db_path: "/var/lib/lightspeed/a2a_state.db" +``` + +SQLite is suitable for: +- Single-worker deployments that need persistence across restarts +- Multi-worker deployments with a shared filesystem (e.g., NFS, EFS) + +#### PostgreSQL Storage + +```yaml +a2a_state: + postgres: + host: "postgres.example.com" + port: 5432 + db: "lightspeed" + user: "lightspeed" + password: "secret" + ssl_mode: "require" +``` + +PostgreSQL is recommended for: +- Multi-worker deployments with multiple replicas +- High-availability production deployments +- Scenarios requiring horizontal scaling + +#### What Gets Persisted + +The A2A state storage persists: + +1. **Task Store**: All A2A task objects, enabling task state queries and resumption +2. **Context-to-Conversation Mappings**: Maps A2A `contextId` to Llama Stack `conversation_id` for multi-turn conversations + +This ensures that: +- Multi-turn conversations work correctly across workers +- Task state is queryable regardless of which worker handles the request +- Service restarts don't lose conversation context + ## Agent Card Structure The agent card describes the agent's capabilities: @@ -200,7 +312,7 @@ The `A2AAgentExecutor` class implements the A2A `AgentExecutor` interface: ### Event Flow -``` +```text A2A Request │ ▼ @@ -252,10 +364,12 @@ A2A Request The A2A implementation supports multi-turn conversations: 1. Each A2A `contextId` maps to a Llama Stack `conversation_id` -2. The mapping is stored in memory (`_CONTEXT_TO_CONVERSATION`) +2. The mapping is stored in the configured A2A context store (memory, SQLite, or PostgreSQL) 3. Subsequent messages with the same `contextId` continue the conversation 4. Conversation history is preserved across turns +For multi-worker deployments, configure persistent storage (see [Persistent State Storage](#persistent-state-storage-multi-worker-deployments)) to ensure context mappings are shared across all workers. + ## Testing with A2A Inspector [A2A Inspector](https://github.com/a2aproject/a2a-inspector) is a tool for inspecting, debugging, and validating A2A agents. diff --git a/examples/lightspeed-stack-a2a-state-pg.yaml b/examples/lightspeed-stack-a2a-state-pg.yaml new file mode 100644 index 000000000..0c003a670 --- /dev/null +++ b/examples/lightspeed-stack-a2a-state-pg.yaml @@ -0,0 +1,30 @@ +name: Lightspeed Core Service (LCS) +service: + host: localhost + port: 8080 + base_url: "https://lightspeed.example.com" + auth_enabled: false + workers: 4 + color_log: true + access_log: true +llama_stack: + use_as_library_client: true + library_client_config_path: run.yaml +user_data_collection: + feedback_enabled: true + feedback_storage: "/tmp/data/feedback" + transcripts_enabled: true + transcripts_storage: "/tmp/data/transcripts" +authentication: + module: "noop" +customization: + agent_card_path: agent_card.yaml +a2a_state: + postgres: + host: 127.0.0.1 + port: 5432 + db: lightspeed + user: lightspeed + password: secret + namespace: a2a + ssl_mode: disable diff --git a/examples/lightspeed-stack-a2a-state-sqlite.yaml b/examples/lightspeed-stack-a2a-state-sqlite.yaml new file mode 100644 index 000000000..fd800f966 --- /dev/null +++ b/examples/lightspeed-stack-a2a-state-sqlite.yaml @@ -0,0 +1,24 @@ +name: Lightspeed Core Service (LCS) +service: + host: localhost + port: 8080 + base_url: "http://localhost:8080" + auth_enabled: false + workers: 1 + color_log: true + access_log: true +llama_stack: + use_as_library_client: true + library_client_config_path: run.yaml +user_data_collection: + feedback_enabled: true + feedback_storage: "/tmp/data/feedback" + transcripts_enabled: true + transcripts_storage: "/tmp/data/transcripts" +authentication: + module: "noop" +customization: + agent_card_path: agent_card.yaml +a2a_state: + sqlite: + db_path: /tmp/data/a2a_state.sqlite diff --git a/pyproject.toml b/pyproject.toml index e313c5db3..870364500 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ dependencies = [ "openai>=1.99.9", # Used by database interface "sqlalchemy>=2.0.42", + # Async database drivers for A2A persistent storage + "aiosqlite>=0.21.0", + "asyncpg>=0.31.0", # Used by Llama Stack version checker "semver<4.0.0", # Used by authorization resolvers diff --git a/src/a2a_storage/__init__.py b/src/a2a_storage/__init__.py new file mode 100644 index 000000000..2707019cb --- /dev/null +++ b/src/a2a_storage/__init__.py @@ -0,0 +1,23 @@ +"""A2A protocol persistent storage components. + +This module provides storage backends for A2A protocol state, including: +- Task storage (using A2A SDK's TaskStore interface) +- Context-to-conversation mapping storage + +For multi-worker deployments, use SQLite or PostgreSQL backends to ensure +state is shared across all workers. +""" + +from a2a_storage.context_store import A2AContextStore +from a2a_storage.in_memory_context_store import InMemoryA2AContextStore +from a2a_storage.sqlite_context_store import SQLiteA2AContextStore +from a2a_storage.postgres_context_store import PostgresA2AContextStore +from a2a_storage.storage_factory import A2AStorageFactory + +__all__ = [ + "A2AContextStore", + "InMemoryA2AContextStore", + "SQLiteA2AContextStore", + "PostgresA2AContextStore", + "A2AStorageFactory", +] diff --git a/src/a2a_storage/context_store.py b/src/a2a_storage/context_store.py new file mode 100644 index 000000000..2a14deddc --- /dev/null +++ b/src/a2a_storage/context_store.py @@ -0,0 +1,57 @@ +"""Abstract base class for A2A context-to-conversation mapping storage.""" + +from abc import ABC, abstractmethod + + +class A2AContextStore(ABC): + """Abstract base class for storing A2A context-to-conversation mappings. + + This store maps A2A context IDs to Llama Stack conversation IDs to + preserve multi-turn conversation history across requests. + + For multi-worker deployments, implementations should use persistent + storage (SQLite or PostgreSQL) to share state across workers. + """ + + @abstractmethod + async def get(self, context_id: str) -> str | None: + """Retrieve the conversation ID for an A2A context. + + Args: + context_id: The A2A context ID. + + Returns: + The Llama Stack conversation ID, or None if not found. + """ + + @abstractmethod + async def set(self, context_id: str, conversation_id: str) -> None: + """Store a context-to-conversation mapping. + + Args: + context_id: The A2A context ID. + conversation_id: The Llama Stack conversation ID. + """ + + @abstractmethod + async def delete(self, context_id: str) -> None: + """Delete a context-to-conversation mapping. + + Args: + context_id: The A2A context ID to delete. + """ + + @abstractmethod + async def initialize(self) -> None: + """Initialize the store (create tables, etc.). + + This method should be called before using the store. + """ + + @abstractmethod + def ready(self) -> bool: + """Check if the store is ready for use. + + Returns: + True if the store is initialized and ready, False otherwise. + """ diff --git a/src/a2a_storage/in_memory_context_store.py b/src/a2a_storage/in_memory_context_store.py new file mode 100644 index 000000000..0c0c33173 --- /dev/null +++ b/src/a2a_storage/in_memory_context_store.py @@ -0,0 +1,92 @@ +"""In-memory implementation of A2A context store.""" + +import asyncio +import logging + +from a2a_storage.context_store import A2AContextStore + +logger = logging.getLogger(__name__) + + +class InMemoryA2AContextStore(A2AContextStore): + """In-memory implementation of A2A context-to-conversation store. + + Stores context mappings in a dictionary in memory. Data is lost when the + server process stops. This implementation is suitable for single-worker + deployments or development/testing. + + For multi-worker deployments, use SQLiteA2AContextStore or + PostgresA2AContextStore instead. + """ + + def __init__(self) -> None: + """Initialize the in-memory context store.""" + logger.debug("Initializing InMemoryA2AContextStore") + self._contexts: dict[str, str] = {} + self._lock = asyncio.Lock() + self._initialized = True + + async def get(self, context_id: str) -> str | None: + """Retrieve the conversation ID for an A2A context. + + Args: + context_id: The A2A context ID. + + Returns: + The Llama Stack conversation ID, or None if not found. + """ + async with self._lock: + conversation_id = self._contexts.get(context_id) + if conversation_id: + logger.debug( + "Context %s maps to conversation %s", context_id, conversation_id + ) + else: + logger.debug("Context %s not found in store", context_id) + return conversation_id + + async def set(self, context_id: str, conversation_id: str) -> None: + """Store a context-to-conversation mapping. + + Args: + context_id: The A2A context ID. + conversation_id: The Llama Stack conversation ID. + """ + async with self._lock: + self._contexts[context_id] = conversation_id + logger.debug( + "Stored mapping: context %s -> conversation %s", + context_id, + conversation_id, + ) + + async def delete(self, context_id: str) -> None: + """Delete a context-to-conversation mapping. + + Args: + context_id: The A2A context ID to delete. + """ + async with self._lock: + if context_id in self._contexts: + del self._contexts[context_id] + logger.debug("Deleted context mapping for %s", context_id) + else: + logger.debug( + "Attempted to delete non-existent context mapping: %s", context_id + ) + + async def initialize(self) -> None: + """Initialize the store. + + For in-memory store, this is a no-op as initialization happens + in __init__. + """ + logger.debug("InMemoryA2AContextStore initialized") + + def ready(self) -> bool: + """Check if the store is ready for use. + + Returns: + True, as in-memory store is always ready after construction. + """ + return self._initialized diff --git a/src/a2a_storage/postgres_context_store.py b/src/a2a_storage/postgres_context_store.py new file mode 100644 index 000000000..b9ead766c --- /dev/null +++ b/src/a2a_storage/postgres_context_store.py @@ -0,0 +1,142 @@ +"""PostgreSQL implementation of A2A context store.""" + +import logging + +from sqlalchemy import Column, String, Table, MetaData, select, delete +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker + +from a2a_storage.context_store import A2AContextStore + +logger = logging.getLogger(__name__) + +# Define the table metadata +metadata = MetaData() + +a2a_context_table = Table( + "a2a_contexts", + metadata, + Column("context_id", String, primary_key=True), + Column("conversation_id", String, nullable=False), +) + + +class PostgresA2AContextStore(A2AContextStore): + """PostgreSQL implementation of A2A context-to-conversation store. + + Stores context mappings in a PostgreSQL database for persistence across + restarts and sharing across workers in multi-worker deployments. + + The store creates a table 'a2a_contexts' with the following schema: + context_id (VARCHAR, PRIMARY KEY): The A2A context ID + conversation_id (VARCHAR, NOT NULL): The Llama Stack conversation ID + """ + + def __init__( + self, + engine: AsyncEngine, + create_table: bool = True, + ) -> None: + """Initialize the PostgreSQL context store. + + Args: + engine: SQLAlchemy async engine connected to the PostgreSQL database. + create_table: If True, create the table on initialization. + """ + logger.debug("Initializing PostgresA2AContextStore") + self._engine = engine + self._session_maker = async_sessionmaker(engine, expire_on_commit=False) + self._create_table = create_table + self._initialized = False + + async def initialize(self) -> None: + """Initialize the store and create tables if needed.""" + if self._initialized: + return + + logger.debug("Initializing PostgreSQL A2A context store schema") + if self._create_table: + async with self._engine.begin() as conn: + await conn.run_sync(metadata.create_all) + self._initialized = True + logger.info("PostgresA2AContextStore initialized successfully") + + async def _ensure_initialized(self) -> None: + """Ensure the store is initialized before use.""" + if not self._initialized: + await self.initialize() + + async def get(self, context_id: str) -> str | None: + """Retrieve the conversation ID for an A2A context. + + Args: + context_id: The A2A context ID. + + Returns: + The Llama Stack conversation ID, or None if not found. + """ + await self._ensure_initialized() + + async with self._session_maker() as session: + stmt = select(a2a_context_table.c.conversation_id).where( + a2a_context_table.c.context_id == context_id + ) + result = await session.execute(stmt) + row = result.scalar_one_or_none() + + if row: + logger.debug("Context %s maps to conversation %s", context_id, row) + return row + logger.debug("Context %s not found in store", context_id) + return None + + async def set(self, context_id: str, conversation_id: str) -> None: + """Store a context-to-conversation mapping. + + Uses PostgreSQL's INSERT ... ON CONFLICT for upsert behavior. + + Args: + context_id: The A2A context ID. + conversation_id: The Llama Stack conversation ID. + """ + await self._ensure_initialized() + + async with self._session_maker.begin() as session: + # Use PostgreSQL's INSERT ... ON CONFLICT (upsert) + stmt = pg_insert(a2a_context_table).values( + context_id=context_id, + conversation_id=conversation_id, + ) + stmt = stmt.on_conflict_do_update( + index_elements=["context_id"], + set_={"conversation_id": conversation_id}, + ) + await session.execute(stmt) + logger.debug( + "Stored mapping: context %s -> conversation %s", + context_id, + conversation_id, + ) + + async def delete(self, context_id: str) -> None: + """Delete a context-to-conversation mapping. + + Args: + context_id: The A2A context ID to delete. + """ + await self._ensure_initialized() + + async with self._session_maker.begin() as session: + stmt = delete(a2a_context_table).where( + a2a_context_table.c.context_id == context_id + ) + await session.execute(stmt) + logger.debug("Deleted context mapping for %s", context_id) + + def ready(self) -> bool: + """Check if the store is ready for use. + + Returns: + True if the store is initialized, False otherwise. + """ + return self._initialized diff --git a/src/a2a_storage/sqlite_context_store.py b/src/a2a_storage/sqlite_context_store.py new file mode 100644 index 000000000..edf201e5d --- /dev/null +++ b/src/a2a_storage/sqlite_context_store.py @@ -0,0 +1,143 @@ +"""SQLite implementation of A2A context store.""" + +import logging + +from sqlalchemy import Column, String, Table, MetaData, select, delete +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker + +from a2a_storage.context_store import A2AContextStore + +logger = logging.getLogger(__name__) + +# Define the table metadata +metadata = MetaData() + +a2a_context_table = Table( + "a2a_contexts", + metadata, + Column("context_id", String, primary_key=True), + Column("conversation_id", String, nullable=False), +) + + +class SQLiteA2AContextStore(A2AContextStore): + """SQLite implementation of A2A context-to-conversation store. + + Stores context mappings in a SQLite database for persistence across + restarts and sharing across workers (when using a shared database file). + + The store creates a table 'a2a_contexts' with the following schema: + context_id (TEXT, PRIMARY KEY): The A2A context ID + conversation_id (TEXT, NOT NULL): The Llama Stack conversation ID + """ + + def __init__( + self, + engine: AsyncEngine, + create_table: bool = True, + ) -> None: + """Initialize the SQLite context store. + + Args: + engine: SQLAlchemy async engine connected to the SQLite database. + create_table: If True, create the table on initialization. + """ + logger.debug("Initializing SQLiteA2AContextStore") + self._engine = engine + self._session_maker = async_sessionmaker(engine, expire_on_commit=False) + self._create_table = create_table + self._initialized = False + + async def initialize(self) -> None: + """Initialize the store and create tables if needed.""" + if self._initialized: + return + + logger.debug("Initializing SQLite A2A context store schema") + if self._create_table: + async with self._engine.begin() as conn: + await conn.run_sync(metadata.create_all) + self._initialized = True + logger.info("SQLiteA2AContextStore initialized successfully") + + async def _ensure_initialized(self) -> None: + """Ensure the store is initialized before use.""" + if not self._initialized: + await self.initialize() + + async def get(self, context_id: str) -> str | None: + """Retrieve the conversation ID for an A2A context. + + Args: + context_id: The A2A context ID. + + Returns: + The Llama Stack conversation ID, or None if not found. + """ + await self._ensure_initialized() + + async with self._session_maker() as session: + stmt = select(a2a_context_table.c.conversation_id).where( + a2a_context_table.c.context_id == context_id + ) + result = await session.execute(stmt) + row = result.scalar_one_or_none() + + if row: + logger.debug("Context %s maps to conversation %s", context_id, row) + return row + logger.debug("Context %s not found in store", context_id) + return None + + async def set(self, context_id: str, conversation_id: str) -> None: + """Store a context-to-conversation mapping. + + Uses delete-then-insert to handle both new and existing mappings. + + Args: + context_id: The A2A context ID. + conversation_id: The Llama Stack conversation ID. + """ + await self._ensure_initialized() + + async with self._session_maker.begin() as session: + # Upsert by deleting existing row and inserting new values + await session.execute( + a2a_context_table.delete().where( + a2a_context_table.c.context_id == context_id + ) + ) + await session.execute( + a2a_context_table.insert().values( + context_id=context_id, + conversation_id=conversation_id, + ) + ) + logger.debug( + "Stored mapping: context %s -> conversation %s", + context_id, + conversation_id, + ) + + async def delete(self, context_id: str) -> None: + """Delete a context-to-conversation mapping. + + Args: + context_id: The A2A context ID to delete. + """ + await self._ensure_initialized() + + async with self._session_maker.begin() as session: + stmt = delete(a2a_context_table).where( + a2a_context_table.c.context_id == context_id + ) + await session.execute(stmt) + logger.debug("Deleted context mapping for %s", context_id) + + def ready(self) -> bool: + """Check if the store is ready for use. + + Returns: + True if the store is initialized, False otherwise. + """ + return self._initialized diff --git a/src/a2a_storage/storage_factory.py b/src/a2a_storage/storage_factory.py new file mode 100644 index 000000000..59e91ec3f --- /dev/null +++ b/src/a2a_storage/storage_factory.py @@ -0,0 +1,184 @@ +"""Factory for creating A2A storage backends.""" + +import logging +from urllib.parse import quote_plus + +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine + +from a2a.server.tasks import TaskStore, InMemoryTaskStore, DatabaseTaskStore + +from a2a_storage.context_store import A2AContextStore +from a2a_storage.in_memory_context_store import InMemoryA2AContextStore +from a2a_storage.sqlite_context_store import SQLiteA2AContextStore +from a2a_storage.postgres_context_store import PostgresA2AContextStore +from models.config import A2AStateConfiguration + +logger = logging.getLogger(__name__) + + +class A2AStorageFactory: + """Factory for creating A2A storage backends. + + Creates appropriate TaskStore and A2AContextStore implementations based + on the A2A state configuration. For multi-worker deployments, this factory + creates database-backed stores that share state across workers. + """ + + _engine: AsyncEngine | None = None + _task_store: TaskStore | None = None + _context_store: A2AContextStore | None = None + + @classmethod + async def create_task_store(cls, config: A2AStateConfiguration) -> TaskStore: + """Create a TaskStore based on configuration. + + Args: + config: A2A state configuration. + + Returns: + TaskStore implementation (InMemoryTaskStore or DatabaseTaskStore). + """ + if cls._task_store is not None: + return cls._task_store + + match config.storage_type: + case "memory": + logger.info("Creating in-memory A2A task store") + cls._task_store = InMemoryTaskStore() + case "sqlite": + if config.sqlite is None: + raise ValueError("SQLite configuration required") + logger.info( + "Creating SQLite A2A task store at %s", config.sqlite.db_path + ) + engine = await cls._get_or_create_engine(config) + cls._task_store = DatabaseTaskStore( + engine, create_table=True, table_name="a2a_tasks" + ) + await cls._task_store.initialize() + case "postgres": + if config.postgres is None: + raise ValueError("PostgreSQL configuration required") + logger.info( + "Creating PostgreSQL A2A task store at %s:%s", + config.postgres.host, + config.postgres.port, + ) + engine = await cls._get_or_create_engine(config) + cls._task_store = DatabaseTaskStore( + engine, create_table=True, table_name="a2a_tasks" + ) + await cls._task_store.initialize() + case _: + raise ValueError(f"Unknown A2A state type: {config.storage_type}") + + return cls._task_store + + @classmethod + async def create_context_store( + cls, config: A2AStateConfiguration + ) -> A2AContextStore: + """Create an A2AContextStore based on configuration. + + Args: + config: A2A state configuration. + + Returns: + A2AContextStore implementation. + """ + if cls._context_store is not None: + return cls._context_store + + match config.storage_type: + case "memory": + logger.info("Creating in-memory A2A context store") + cls._context_store = InMemoryA2AContextStore() + case "sqlite": + if config.sqlite is None: + raise ValueError("SQLite configuration required") + logger.info( + "Creating SQLite A2A context store at %s", config.sqlite.db_path + ) + engine = await cls._get_or_create_engine(config) + cls._context_store = SQLiteA2AContextStore(engine, create_table=True) + await cls._context_store.initialize() + case "postgres": + if config.postgres is None: + raise ValueError("PostgreSQL configuration required") + logger.info( + "Creating PostgreSQL A2A context store at %s:%s", + config.postgres.host, + config.postgres.port, + ) + engine = await cls._get_or_create_engine(config) + cls._context_store = PostgresA2AContextStore(engine, create_table=True) + await cls._context_store.initialize() + case _: + raise ValueError(f"Unknown A2A state type: {config.storage_type}") + + return cls._context_store + + @classmethod + async def _get_or_create_engine(cls, config: A2AStateConfiguration) -> AsyncEngine: + """Get or create the SQLAlchemy async engine. + + The engine is reused for both task and context stores to share + the connection pool. + + Args: + config: A2A state configuration. + + Returns: + SQLAlchemy AsyncEngine. + """ + if cls._engine is not None: + return cls._engine + + match config.storage_type: + case "sqlite": + if config.sqlite is None: + raise ValueError("SQLite configuration required") + connection_string = f"sqlite+aiosqlite:///{config.sqlite.db_path}" + cls._engine = create_async_engine( + connection_string, + echo=False, + ) + case "postgres": + if config.postgres is None: + raise ValueError("PostgreSQL configuration required") + pg = config.postgres + password = ( + quote_plus(pg.password.get_secret_value()) if pg.password else "" + ) + connection_string = ( + f"postgresql+asyncpg://{pg.user}:{password}" + f"@{pg.host}:{pg.port}/{pg.db}" + ) + cls._engine = create_async_engine( + connection_string, + echo=False, + ) + case _: + raise ValueError( + f"Cannot create engine for storage type: {config.storage_type}" + ) + + logger.info("Created async database engine for A2A storage") + return cls._engine + + @classmethod + async def cleanup(cls) -> None: + """Clean up resources (close database connections).""" + if cls._engine is not None: + await cls._engine.dispose() + cls._engine = None + logger.info("Closed A2A storage database engine") + cls._task_store = None + cls._context_store = None + + @classmethod + def reset(cls) -> None: + """Reset factory state (for testing purposes).""" + cls._engine = None + cls._task_store = None + cls._context_store = None diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py index 65a4834ea..d7baaf51a 100644 --- a/src/app/endpoints/a2a.py +++ b/src/app/endpoints/a2a.py @@ -30,7 +30,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.request_handlers import DefaultRequestHandler -from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks import TaskStore from a2a.server.tasks.task_updater import TaskUpdater from a2a.server.apps import A2AStarletteApplication from a2a.utils import new_agent_text_message, new_task @@ -39,6 +39,7 @@ from authentication import get_auth_dependency from authorization.middleware import authorize from configuration import configuration +from a2a_storage import A2AStorageFactory, A2AContextStore from models.config import Action from models.requests import QueryRequest from app.endpoints.query import ( @@ -60,12 +61,37 @@ # ----------------------------- # Persistent State (multi-turn) # ----------------------------- -# Keep a single TaskStore instance so tasks persist across requests and -# previous messages remain connected to the current request. -_TASK_STORE = InMemoryTaskStore() +# Task store and context store are created lazily based on configuration. +# For multi-worker deployments, configure 'a2a_state' with 'sqlite' or 'postgres' +# to share state across workers. +_TASK_STORE: TaskStore | None = None +_CONTEXT_STORE: A2AContextStore | None = None -# Map A2A contextId -> Llama Stack conversationId to preserve history across turns -_CONTEXT_TO_CONVERSATION: dict[str, str] = {} + +async def _get_task_store() -> TaskStore: + """Get the A2A task store, creating it if necessary. + + Returns: + TaskStore instance based on configuration. + """ + global _TASK_STORE # pylint: disable=global-statement + if _TASK_STORE is None: + _TASK_STORE = await A2AStorageFactory.create_task_store(configuration.a2a_state) + return _TASK_STORE + + +async def _get_context_store() -> A2AContextStore: + """Get the A2A context store, creating it if necessary. + + Returns: + A2AContextStore instance based on configuration. + """ + global _CONTEXT_STORE # pylint: disable=global-statement + if _CONTEXT_STORE is None: + _CONTEXT_STORE = await A2AStorageFactory.create_context_store( + configuration.a2a_state + ) + return _CONTEXT_STORE def _convert_responses_content_to_a2a_parts(output: list[Any]) -> list[Part]: @@ -261,7 +287,8 @@ async def _process_task_streaming( # pylint: disable=too-many-locals # Resolve conversation_id from A2A contextId for multi-turn a2a_context_id = context_id - conversation_id = _CONTEXT_TO_CONVERSATION.get(a2a_context_id) + context_store = await _get_context_store() + conversation_id = await context_store.get(a2a_context_id) logger.info( "A2A contextId %s maps to conversation_id %s", a2a_context_id, @@ -299,7 +326,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals # Persist conversation_id for next turn in same A2A context if conversation_id: - _CONTEXT_TO_CONVERSATION[a2a_context_id] = conversation_id + await context_store.set(a2a_context_id, conversation_id) logger.info( "Persisted conversation_id %s for A2A contextId %s", conversation_id, @@ -593,7 +620,9 @@ async def get_agent_card( # pylint: disable=unused-argument raise -def _create_a2a_app(auth_token: str, mcp_headers: dict[str, dict[str, str]]) -> Any: +async def _create_a2a_app( + auth_token: str, mcp_headers: dict[str, dict[str, str]] +) -> Any: """Create an A2A Starlette application instance with auth context. Args: @@ -604,10 +633,11 @@ def _create_a2a_app(auth_token: str, mcp_headers: dict[str, dict[str, str]]) -> A2A Starlette ASGI application """ agent_executor = A2AAgentExecutor(auth_token=auth_token, mcp_headers=mcp_headers) + task_store = await _get_task_store() request_handler = DefaultRequestHandler( agent_executor=agent_executor, - task_store=_TASK_STORE, + task_store=task_store, ) a2a_app = A2AStarletteApplication( @@ -656,7 +686,7 @@ async def handle_a2a_jsonrpc( # pylint: disable=too-many-locals,too-many-statem auth_token = "" # Create A2A app with auth context - a2a_app = _create_a2a_app(auth_token, mcp_headers) + a2a_app = await _create_a2a_app(auth_token, mcp_headers) # Detect if this is a streaming request by checking the JSON-RPC method is_streaming_request = False diff --git a/src/configuration.py b/src/configuration.py index 9d698317f..ef9ff915c 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -9,6 +9,7 @@ import yaml from models.config import ( + A2AStateConfiguration, AuthorizationConfiguration, Configuration, Customization, @@ -158,6 +159,13 @@ def quota_handlers_configuration(self) -> QuotaHandlersConfiguration: raise LogicError("logic error: configuration is not loaded") return self._configuration.quota_handlers + @property + def a2a_state(self) -> "A2AStateConfiguration": + """Return A2A state configuration.""" + if self._configuration is None: + raise LogicError("logic error: configuration is not loaded") + return self._configuration.a2a_state + @property def conversation_cache(self) -> Cache: """Return the conversation cache.""" diff --git a/src/models/config.py b/src/models/config.py index 32b357752..87ad24389 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1247,6 +1247,64 @@ def check_cache_configuration(self) -> Self: return self +class A2AStateConfiguration(ConfigurationBase): + """A2A protocol persistent state configuration. + + Configures how A2A task state and context-to-conversation mappings are + stored. For multi-worker deployments, use SQLite or PostgreSQL to ensure + state is shared across all workers. + + If no configuration is provided, in-memory storage is used (default). + This is suitable for single-worker deployments but state will be lost + on restarts and not shared across workers. + + Attributes: + sqlite: SQLite database configuration for A2A state storage. + postgres: PostgreSQL database configuration for A2A state storage. + """ + + sqlite: Optional[SQLiteDatabaseConfiguration] = Field( + default=None, + title="SQLite configuration", + description="SQLite database configuration for A2A state storage.", + ) + postgres: Optional[PostgreSQLDatabaseConfiguration] = Field( + default=None, + title="PostgreSQL configuration", + description="PostgreSQL database configuration for A2A state storage.", + ) + + @model_validator(mode="after") + def check_a2a_state_configuration(self) -> Self: + """Validate A2A state configuration - only one type can be configured.""" + total_configured = sum([self.sqlite is not None, self.postgres is not None]) + + if total_configured > 1: + raise ValueError("Only one A2A state storage configuration can be provided") + + return self + + @property + def storage_type(self) -> Literal["memory", "sqlite", "postgres"]: + """Return the configured storage type.""" + if self.sqlite is not None: + return "sqlite" + if self.postgres is not None: + return "postgres" + return "memory" + + @property + def config( + self, + ) -> SQLiteDatabaseConfiguration | PostgreSQLDatabaseConfiguration | None: + """Return the active storage configuration, or None for memory storage.""" + if self.sqlite is not None: + return self.sqlite + if self.postgres is not None: + return self.postgres + return None + + class ByokRag(ConfigurationBase): """BYOK (Bring Your Own Knowledge) RAG configuration.""" @@ -1488,6 +1546,12 @@ class Configuration(ConfigurationBase): "reconfigure Llama Stack through its run.yaml configuration file", ) + a2a_state: A2AStateConfiguration = Field( + default_factory=A2AStateConfiguration, + title="A2A state configuration", + description="Configuration for A2A protocol persistent state storage.", + ) + quota_handlers: QuotaHandlersConfiguration = Field( default_factory=lambda: QuotaHandlersConfiguration( sqlite=None, postgres=None, enable_token_history=False diff --git a/tests/unit/a2a_storage/__init__.py b/tests/unit/a2a_storage/__init__.py new file mode 100644 index 000000000..f714787f8 --- /dev/null +++ b/tests/unit/a2a_storage/__init__.py @@ -0,0 +1 @@ +"""Unit tests for A2A storage module.""" diff --git a/tests/unit/a2a_storage/test_in_memory_context_store.py b/tests/unit/a2a_storage/test_in_memory_context_store.py new file mode 100644 index 000000000..8da998b3d --- /dev/null +++ b/tests/unit/a2a_storage/test_in_memory_context_store.py @@ -0,0 +1,95 @@ +"""Unit tests for InMemoryA2AContextStore.""" + +import pytest + +from a2a_storage.in_memory_context_store import InMemoryA2AContextStore + + +class TestInMemoryA2AContextStore: + """Tests for InMemoryA2AContextStore.""" + + @pytest.fixture + def store(self) -> InMemoryA2AContextStore: + """Create a fresh in-memory context store for each test.""" + return InMemoryA2AContextStore() + + @pytest.mark.asyncio + async def test_initialization(self, store: InMemoryA2AContextStore) -> None: + """Test store initialization.""" + assert store.ready() is True + + @pytest.mark.asyncio + async def test_get_nonexistent_key(self, store: InMemoryA2AContextStore) -> None: + """Test getting a key that doesn't exist returns None.""" + result = await store.get("nonexistent-context-id") + assert result is None + + @pytest.mark.asyncio + async def test_set_and_get(self, store: InMemoryA2AContextStore) -> None: + """Test setting and getting a context mapping.""" + context_id = "ctx-123" + conversation_id = "conv-456" + + await store.set(context_id, conversation_id) + result = await store.get(context_id) + + assert result == conversation_id + + @pytest.mark.asyncio + async def test_set_overwrites_existing( + self, store: InMemoryA2AContextStore + ) -> None: + """Test that set overwrites an existing mapping.""" + context_id = "ctx-123" + conversation_id_1 = "conv-456" + conversation_id_2 = "conv-789" + + await store.set(context_id, conversation_id_1) + await store.set(context_id, conversation_id_2) + result = await store.get(context_id) + + assert result == conversation_id_2 + + @pytest.mark.asyncio + async def test_delete_existing_key(self, store: InMemoryA2AContextStore) -> None: + """Test deleting an existing key.""" + context_id = "ctx-123" + conversation_id = "conv-456" + + await store.set(context_id, conversation_id) + await store.delete(context_id) + result = await store.get(context_id) + + assert result is None + + @pytest.mark.asyncio + async def test_delete_nonexistent_key(self, store: InMemoryA2AContextStore) -> None: + """Test deleting a key that doesn't exist (should not raise).""" + # Should not raise any exception + await store.delete("nonexistent-context-id") + + @pytest.mark.asyncio + async def test_multiple_contexts(self, store: InMemoryA2AContextStore) -> None: + """Test storing multiple context mappings.""" + mappings = { + "ctx-1": "conv-1", + "ctx-2": "conv-2", + "ctx-3": "conv-3", + } + + for ctx_id, conv_id in mappings.items(): + await store.set(ctx_id, conv_id) + + for ctx_id, expected_conv_id in mappings.items(): + result = await store.get(ctx_id) + assert result == expected_conv_id + + @pytest.mark.asyncio + async def test_initialize_is_noop(self, store: InMemoryA2AContextStore) -> None: + """Test that initialize is a no-op for in-memory store.""" + # Store should already be ready + assert store.ready() is True + + # Initialize should not change anything + await store.initialize() + assert store.ready() is True diff --git a/tests/unit/a2a_storage/test_sqlite_context_store.py b/tests/unit/a2a_storage/test_sqlite_context_store.py new file mode 100644 index 000000000..c2f7baa16 --- /dev/null +++ b/tests/unit/a2a_storage/test_sqlite_context_store.py @@ -0,0 +1,147 @@ +"""Unit tests for SQLiteA2AContextStore.""" + +from pathlib import Path + +import pytest +from sqlalchemy.ext.asyncio import create_async_engine + +from a2a_storage.sqlite_context_store import SQLiteA2AContextStore + + +class TestSQLiteA2AContextStore: + """Tests for SQLiteA2AContextStore.""" + + @pytest.fixture + async def store(self, tmp_path: Path) -> SQLiteA2AContextStore: + """Create a fresh SQLite context store for each test.""" + db_path = tmp_path / "test_a2a_context.db" + engine = create_async_engine( + f"sqlite+aiosqlite:///{db_path}", + echo=False, + ) + context_store = SQLiteA2AContextStore(engine, create_table=True) + await context_store.initialize() + return context_store + + @pytest.mark.asyncio + async def test_initialization(self, store: SQLiteA2AContextStore) -> None: + """Test store initialization.""" + assert store.ready() is True + + @pytest.mark.asyncio + async def test_not_ready_before_initialize(self, tmp_path: Path) -> None: + """Test store is not ready before initialization.""" + db_path = tmp_path / "test_a2a_context_uninit.db" + engine = create_async_engine( + f"sqlite+aiosqlite:///{db_path}", + echo=False, + ) + context_store = SQLiteA2AContextStore(engine, create_table=True) + assert context_store.ready() is False + + @pytest.mark.asyncio + async def test_get_nonexistent_key(self, store: SQLiteA2AContextStore) -> None: + """Test getting a key that doesn't exist returns None.""" + result = await store.get("nonexistent-context-id") + assert result is None + + @pytest.mark.asyncio + async def test_set_and_get(self, store: SQLiteA2AContextStore) -> None: + """Test setting and getting a context mapping.""" + context_id = "ctx-123" + conversation_id = "conv-456" + + await store.set(context_id, conversation_id) + result = await store.get(context_id) + + assert result == conversation_id + + @pytest.mark.asyncio + async def test_set_overwrites_existing(self, store: SQLiteA2AContextStore) -> None: + """Test that set overwrites an existing mapping.""" + context_id = "ctx-123" + conversation_id_1 = "conv-456" + conversation_id_2 = "conv-789" + + await store.set(context_id, conversation_id_1) + await store.set(context_id, conversation_id_2) + result = await store.get(context_id) + + assert result == conversation_id_2 + + @pytest.mark.asyncio + async def test_delete_existing_key(self, store: SQLiteA2AContextStore) -> None: + """Test deleting an existing key.""" + context_id = "ctx-123" + conversation_id = "conv-456" + + await store.set(context_id, conversation_id) + await store.delete(context_id) + result = await store.get(context_id) + + assert result is None + + @pytest.mark.asyncio + async def test_delete_nonexistent_key(self, store: SQLiteA2AContextStore) -> None: + """Test deleting a key that doesn't exist (should not raise).""" + # Should not raise any exception + await store.delete("nonexistent-context-id") + + @pytest.mark.asyncio + async def test_multiple_contexts(self, store: SQLiteA2AContextStore) -> None: + """Test storing multiple context mappings.""" + mappings = { + "ctx-1": "conv-1", + "ctx-2": "conv-2", + "ctx-3": "conv-3", + } + + for ctx_id, conv_id in mappings.items(): + await store.set(ctx_id, conv_id) + + for ctx_id, expected_conv_id in mappings.items(): + result = await store.get(ctx_id) + assert result == expected_conv_id + + @pytest.mark.asyncio + async def test_persistence_across_operations( + self, store: SQLiteA2AContextStore + ) -> None: + """Test that data persists after multiple operations.""" + # Set multiple values + await store.set("ctx-1", "conv-1") + await store.set("ctx-2", "conv-2") + + # Delete one + await store.delete("ctx-1") + + # Update one + await store.set("ctx-2", "conv-2-updated") + + # Add one more + await store.set("ctx-3", "conv-3") + + # Verify state + assert await store.get("ctx-1") is None + assert await store.get("ctx-2") == "conv-2-updated" + assert await store.get("ctx-3") == "conv-3" + + @pytest.mark.asyncio + async def test_auto_initialize_on_operations(self, tmp_path: Path) -> None: + """Test that store auto-initializes on first operation.""" + db_path = tmp_path / "test_auto_init.db" + engine = create_async_engine( + f"sqlite+aiosqlite:///{db_path}", + echo=False, + ) + store = SQLiteA2AContextStore(engine, create_table=True) + + # Don't call initialize(), just use the store + assert store.ready() is False + + # This should auto-initialize + await store.set("ctx-1", "conv-1") + assert store.ready() is True + + result = await store.get("ctx-1") + assert result == "conv-1" diff --git a/tests/unit/a2a_storage/test_storage_factory.py b/tests/unit/a2a_storage/test_storage_factory.py new file mode 100644 index 000000000..6ca54d5ab --- /dev/null +++ b/tests/unit/a2a_storage/test_storage_factory.py @@ -0,0 +1,172 @@ +"""Unit tests for A2AStorageFactory.""" + +# pylint: disable=protected-access + +from pathlib import Path +from typing import Generator +from unittest.mock import PropertyMock + +import pytest +from pytest_mock import MockerFixture + +from a2a.server.tasks import InMemoryTaskStore, DatabaseTaskStore + +from a2a_storage import A2AStorageFactory +from a2a_storage.in_memory_context_store import InMemoryA2AContextStore +from a2a_storage.sqlite_context_store import SQLiteA2AContextStore +from models.config import A2AStateConfiguration, SQLiteDatabaseConfiguration + + +class TestA2AStorageFactory: + """Tests for A2AStorageFactory.""" + + @pytest.fixture(autouse=True) + def reset_factory(self) -> Generator[None, None, None]: + """Reset factory state before each test.""" + A2AStorageFactory.reset() + yield + A2AStorageFactory.reset() + + @pytest.mark.asyncio + async def test_create_memory_task_store(self) -> None: + """Test creating an in-memory task store (default, no config).""" + config = A2AStateConfiguration() + + store = await A2AStorageFactory.create_task_store(config) + + assert isinstance(store, InMemoryTaskStore) + + @pytest.mark.asyncio + async def test_create_memory_context_store(self) -> None: + """Test creating an in-memory context store (default, no config).""" + config = A2AStateConfiguration() + + store = await A2AStorageFactory.create_context_store(config) + + assert isinstance(store, InMemoryA2AContextStore) + assert store.ready() is True + + @pytest.mark.asyncio + async def test_create_sqlite_task_store(self, tmp_path: Path) -> None: + """Test creating a SQLite task store.""" + db_path = tmp_path / "test_task_store.db" + sqlite_config = SQLiteDatabaseConfiguration(db_path=str(db_path)) + config = A2AStateConfiguration(sqlite=sqlite_config) + + store = await A2AStorageFactory.create_task_store(config) + + assert isinstance(store, DatabaseTaskStore) + + @pytest.mark.asyncio + async def test_create_sqlite_context_store(self, tmp_path: Path) -> None: + """Test creating a SQLite context store.""" + db_path = tmp_path / "test_context_store.db" + sqlite_config = SQLiteDatabaseConfiguration(db_path=str(db_path)) + config = A2AStateConfiguration(sqlite=sqlite_config) + + store = await A2AStorageFactory.create_context_store(config) + + assert isinstance(store, SQLiteA2AContextStore) + assert store.ready() is True + + @pytest.mark.asyncio + async def test_factory_reuses_task_store(self) -> None: + """Test that factory reuses the same task store instance.""" + config = A2AStateConfiguration() + + store1 = await A2AStorageFactory.create_task_store(config) + store2 = await A2AStorageFactory.create_task_store(config) + + assert store1 is store2 + + @pytest.mark.asyncio + async def test_factory_reuses_context_store(self) -> None: + """Test that factory reuses the same context store instance.""" + config = A2AStateConfiguration() + + store1 = await A2AStorageFactory.create_context_store(config) + store2 = await A2AStorageFactory.create_context_store(config) + + assert store1 is store2 + + @pytest.mark.asyncio + async def test_cleanup_disposes_state(self) -> None: + """Test that cleanup disposes the stores.""" + config = A2AStateConfiguration() + + await A2AStorageFactory.create_task_store(config) + await A2AStorageFactory.create_context_store(config) + + assert A2AStorageFactory._task_store is not None + assert A2AStorageFactory._context_store is not None + + await A2AStorageFactory.cleanup() + + assert A2AStorageFactory._engine is None + assert A2AStorageFactory._task_store is None + assert A2AStorageFactory._context_store is None + + @pytest.mark.asyncio + async def test_reset_clears_state(self) -> None: + """Test that reset clears all factory state.""" + config = A2AStateConfiguration() + + await A2AStorageFactory.create_task_store(config) + await A2AStorageFactory.create_context_store(config) + + A2AStorageFactory.reset() + + assert A2AStorageFactory._engine is None + assert A2AStorageFactory._task_store is None + assert A2AStorageFactory._context_store is None + + @pytest.mark.asyncio + async def test_invalid_storage_type_raises_error( + self, mocker: MockerFixture + ) -> None: + """Test that an invalid storage type raises ValueError.""" + config = A2AStateConfiguration() + + # Mock the storage_type property to return an invalid value + mocker.patch.object( + A2AStateConfiguration, + "storage_type", + new_callable=PropertyMock, + return_value="invalid", + ) + with pytest.raises(ValueError, match="Unknown A2A state type"): + await A2AStorageFactory.create_task_store(config) + + @pytest.mark.asyncio + async def test_sqlite_storage_type_without_config_raises_error( + self, mocker: MockerFixture + ) -> None: + """Test that SQLite storage type without config raises ValueError.""" + config = A2AStateConfiguration() + + # Mock to simulate misconfiguration + mocker.patch.object( + A2AStateConfiguration, + "storage_type", + new_callable=PropertyMock, + return_value="sqlite", + ) + with pytest.raises(ValueError, match="SQLite configuration required"): + await A2AStorageFactory.create_task_store(config) + + @pytest.mark.asyncio + async def test_postgres_storage_type_without_config_raises_error( + self, mocker: MockerFixture + ) -> None: + """Test that PostgreSQL storage type without config raises ValueError.""" + config = A2AStateConfiguration() + + # Mock to simulate misconfiguration + mocker.patch.object( + A2AStateConfiguration, + "storage_type", + new_callable=PropertyMock, + return_value="postgres", + ) + with pytest.raises(ValueError, match="PostgreSQL configuration required"): + await A2AStorageFactory.create_task_store(config) diff --git a/tests/unit/app/endpoints/test_a2a.py b/tests/unit/app/endpoints/test_a2a.py index 1e3199782..265303074 100644 --- a/tests/unit/app/endpoints/test_a2a.py +++ b/tests/unit/app/endpoints/test_a2a.py @@ -29,8 +29,8 @@ get_lightspeed_agent_card, A2AAgentExecutor, TaskResultAggregator, - _CONTEXT_TO_CONVERSATION, - _TASK_STORE, + _get_task_store, + _get_context_store, a2a_health_check, get_agent_card, ) @@ -104,6 +104,7 @@ def setup_configuration_fixture(mocker: MockerFixture) -> AppConfig: }, "authentication": {"module": "noop"}, "authorization": {"access_rules": []}, + "a2a_state": {}, # Empty = in-memory storage (default) } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -130,6 +131,7 @@ def setup_minimal_configuration_fixture(mocker: MockerFixture) -> AppConfig: "customization": {}, # Empty customization, no agent_card_config "authentication": {"module": "noop"}, "authorization": {"access_rules": []}, + "a2a_state": {}, # Empty = in-memory storage (default) } cfg = AppConfig() cfg.init_from_dict(config_dict) @@ -613,13 +615,42 @@ async def test_cancel_raises_not_implemented(self) -> None: class TestContextToConversationMapping: """Tests for the context to conversation ID mapping.""" - def test_context_to_conversation_is_dict(self) -> None: - """Test that _CONTEXT_TO_CONVERSATION is a dict.""" - assert isinstance(_CONTEXT_TO_CONVERSATION, dict) + @pytest.mark.asyncio + async def test_get_context_store_returns_store( + self, + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test that _get_context_store returns a context store.""" + # pylint: disable=import-outside-toplevel + # Reset module-level state and factory + import app.endpoints.a2a as a2a_module + from a2a_storage import A2AStorageFactory + + a2a_module._context_store = None + a2a_module._task_store = None + A2AStorageFactory.reset() - def test_task_store_exists(self) -> None: - """Test that _TASK_STORE exists.""" - assert _TASK_STORE is not None + store = await _get_context_store() + assert store is not None + assert store.ready() is True + + @pytest.mark.asyncio + async def test_get_task_store_returns_store( + self, + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test that _get_task_store returns a task store.""" + # pylint: disable=import-outside-toplevel + # Reset module-level state and factory + import app.endpoints.a2a as a2a_module + from a2a_storage import A2AStorageFactory + + a2a_module._context_store = None + a2a_module._task_store = None + A2AStorageFactory.reset() + + store = await _get_task_store() + assert store is not None # ----------------------------- diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index 396f1e48c..cea1876ee 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -66,7 +66,7 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 16 + assert len(app.routers) == 17 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() @@ -94,7 +94,7 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 16 + assert len(app.routers) == 17 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" diff --git a/tests/unit/models/config/test_a2a_state_configuration.py b/tests/unit/models/config/test_a2a_state_configuration.py new file mode 100644 index 000000000..3f2148180 --- /dev/null +++ b/tests/unit/models/config/test_a2a_state_configuration.py @@ -0,0 +1,103 @@ +"""Unit tests for A2AStateConfiguration.""" + +# pylint: disable=no-member + +import pytest +from pydantic import ValidationError + +from models.config import ( + A2AStateConfiguration, + SQLiteDatabaseConfiguration, + PostgreSQLDatabaseConfiguration, +) + + +class TestA2AStateConfiguration: + """Tests for A2AStateConfiguration.""" + + def test_default_configuration(self) -> None: + """Test default configuration is memory type (no database configured).""" + config = A2AStateConfiguration() + + assert config.storage_type == "memory" + assert config.sqlite is None + assert config.postgres is None + assert config.config is None + + def test_sqlite_configuration(self, tmp_path: str) -> None: + """Test SQLite configuration.""" + db_path = f"{tmp_path}/test.db" + sqlite_config = SQLiteDatabaseConfiguration(db_path=db_path) + config = A2AStateConfiguration(sqlite=sqlite_config) + + assert config.storage_type == "sqlite" + assert config.sqlite is not None + assert config.sqlite.db_path == db_path + assert config.config == sqlite_config + + def test_postgres_configuration(self) -> None: + """Test PostgreSQL configuration.""" + postgres_config = PostgreSQLDatabaseConfiguration( + host="localhost", + port=5432, + db="a2a_state", + user="lightspeed", + password="secret", + ) + config = A2AStateConfiguration(postgres=postgres_config) + + assert config.storage_type == "postgres" + assert config.postgres is not None + assert config.postgres.host == "localhost" + assert config.postgres.port == 5432 + assert config.postgres.db == "a2a_state" + assert config.config == postgres_config + + def test_postgres_with_all_options(self) -> None: + """Test PostgreSQL configuration with all options.""" + postgres_config = PostgreSQLDatabaseConfiguration( + host="postgres.example.com", + port=5433, + db="lightspeed", + user="admin", + password="secret123", + namespace="a2a", + ssl_mode="require", + ca_cert_path=None, + ) + config = A2AStateConfiguration(postgres=postgres_config) + + assert config.storage_type == "postgres" + assert config.postgres.host == "postgres.example.com" + assert config.postgres.port == 5433 + assert config.postgres.namespace == "a2a" + assert config.postgres.ssl_mode == "require" + + def test_both_sqlite_and_postgres_raises_error(self, tmp_path: str) -> None: + """Test that configuring both SQLite and PostgreSQL raises ValidationError.""" + db_path = f"{tmp_path}/test.db" + sqlite_config = SQLiteDatabaseConfiguration(db_path=db_path) + postgres_config = PostgreSQLDatabaseConfiguration( + host="localhost", + port=5432, + db="test", + user="test", + password="test", + ) + + with pytest.raises(ValidationError) as exc_info: + A2AStateConfiguration( + sqlite=sqlite_config, + postgres=postgres_config, + ) + + errors = exc_info.value.errors() + assert any( + "Only one A2A state storage configuration can be provided" in str(e["msg"]) + for e in errors + ) + + def test_forbids_extra_fields(self) -> None: + """Test that extra fields are forbidden.""" + with pytest.raises(ValidationError): + A2AStateConfiguration(unknown_field="value") # type: ignore diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 53a7d8b69..cc9fc4461 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -187,6 +187,10 @@ def test_dump_configuration(tmp_path: Path) -> None: "scheduler": {"period": 1}, "enable_token_history": False, }, + "a2a_state": { + "sqlite": None, + "postgres": None, + }, } @@ -503,6 +507,10 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None: "scheduler": {"period": 10}, "enable_token_history": True, }, + "a2a_state": { + "sqlite": None, + "postgres": None, + }, } @@ -683,4 +691,8 @@ def test_dump_configuration_byok(tmp_path: Path) -> None: "scheduler": {"period": 1}, "enable_token_history": False, }, + "a2a_state": { + "sqlite": None, + "postgres": None, + }, } diff --git a/uv.lock b/uv.lock index 5d6701912..a8b7402cc 100644 --- a/uv.lock +++ b/uv.lock @@ -1380,6 +1380,8 @@ source = { editable = "." } dependencies = [ { name = "a2a-sdk" }, { name = "aiohttp" }, + { name = "aiosqlite" }, + { name = "asyncpg" }, { name = "authlib" }, { name = "cachetools" }, { name = "email-validator" }, @@ -1464,6 +1466,8 @@ llslibdev = [ requires-dist = [ { name = "a2a-sdk", specifier = ">=0.3.4,<0.4.0" }, { name = "aiohttp", specifier = ">=3.12.14" }, + { name = "aiosqlite", specifier = ">=0.21.0" }, + { name = "asyncpg", specifier = ">=0.31.0" }, { name = "authlib", specifier = ">=1.6.0" }, { name = "cachetools", specifier = ">=6.1.0" }, { name = "email-validator", specifier = ">=2.2.0" }, From 5864c7a3f5092667e6e5224303c8d76529fc45bf Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Wed, 10 Dec 2025 13:19:51 +0100 Subject: [PATCH 5/6] Allow to configure A2A protocol version from the AgentCard --- docs/a2a_protocol.md | 17 +++++++- src/app/endpoints/a2a.py | 6 +-- tests/unit/app/endpoints/test_a2a.py | 59 +++++++++++++++++++++++++++- 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/docs/a2a_protocol.md b/docs/a2a_protocol.md index 5ac36d130..6ef6b6c2f 100644 --- a/docs/a2a_protocol.md +++ b/docs/a2a_protocol.md @@ -86,6 +86,7 @@ Create a separate `agent_card.yaml` file with the agent card configuration: # agent_card.yaml name: "Lightspeed AI Assistant" description: "An AI assistant for OpenShift and Kubernetes" +protocolVersion: "0.3.0" # A2A protocol version (default: "0.3.0") provider: organization: "Red Hat" url: "https://redhat.com" @@ -128,6 +129,7 @@ customization: agent_card_config: name: "My AI Assistant" description: "An AI assistant for helping with various tasks" + protocolVersion: "0.3.0" # A2A protocol version (default: "0.3.0") provider: organization: "My Organization" url: "https://myorg.example.com" @@ -266,7 +268,7 @@ The agent card describes the agent's capabilities: "version": "1.0.0", "url": "https://example.com/a2a", "documentation_url": "https://example.com/docs", - "protocol_version": "0.2.1", + "protocol_version": "0.3.0", "provider": { "organization": "Red Hat", "url": "https://redhat.com" @@ -298,6 +300,8 @@ The agent card describes the agent's capabilities: } ``` +**Note:** The `protocol_version` field can be configured via the `protocolVersion` setting in your agent card configuration (see [Agent Card Configuration](#agent-card-configuration) section above). + ## How the Executor Works ### A2AAgentExecutor @@ -710,7 +714,16 @@ Check logs for entries from `app.endpoints.handlers` logger. ## Protocol Version -This implementation supports A2A protocol version **0.2.1**. +The A2A protocol version can be configured in the agent card configuration file using the `protocolVersion` field. If not specified, it defaults to **0.3.0**. + +To set a specific protocol version, add it to your agent card configuration: + +```yaml +# In agent_card.yaml or customization.agent_card_config +protocolVersion: "0.3.0" +``` + +The protocol version is included in the agent card response and indicates which version of the A2A protocol specification the agent implements. ## References diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py index d7baaf51a..25471f12a 100644 --- a/src/app/endpoints/a2a.py +++ b/src/app/endpoints/a2a.py @@ -584,7 +584,7 @@ def get_lightspeed_agent_card() -> AgentCard: default_input_modes=config.get("defaultInputModes", ["text/plain"]), default_output_modes=config.get("defaultOutputModes", ["text/plain"]), capabilities=capabilities, - protocol_version="0.2.1", + protocol_version=config.get("protocolVersion", "0.3.0"), security=config.get("security", [{"bearer": []}]), security_schemes=config.get("security_schemes", {}), ) @@ -845,6 +845,6 @@ async def a2a_health_check() -> dict[str, str]: "status": "healthy", "service": "lightspeed-a2a", "version": __version__, - "a2a_sdk_version": "0.2.1", - "timestamp": datetime.now().isoformat(), + "a2a_sdk_version": "0.3.4", + "timestamp": datetime.now(timezone.utc).isoformat(), } diff --git a/tests/unit/app/endpoints/test_a2a.py b/tests/unit/app/endpoints/test_a2a.py index 265303074..6e3e13be9 100644 --- a/tests/unit/app/endpoints/test_a2a.py +++ b/tests/unit/app/endpoints/test_a2a.py @@ -352,7 +352,7 @@ def test_get_agent_card_with_config( assert agent_card.name == "Test Agent" assert agent_card.description == "A test agent" assert agent_card.url == "http://localhost:8080/a2a" - assert agent_card.protocol_version == "0.2.1" + assert agent_card.protocol_version == "0.3.0" # Default protocol version # Check provider assert agent_card.provider is not None @@ -367,6 +367,63 @@ def test_get_agent_card_with_config( assert agent_card.capabilities is not None assert agent_card.capabilities.streaming is True + def test_get_agent_card_with_custom_protocol_version( + self, mocker: MockerFixture + ) -> None: + """Test getting agent card with custom protocol version.""" + config_dict: dict[Any, Any] = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "base_url": "http://localhost:8080", + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {}, + "mcp_servers": [], + "customization": { + "agent_card_config": { + "name": "Test Agent", + "description": "A test agent", + "protocolVersion": "0.2.1", # Custom protocol version + "provider": { + "organization": "Test Org", + "url": "https://test.org", + }, + "skills": [ + { + "id": "test-skill", + "name": "Test Skill", + "description": "A test skill", + "tags": ["test"], + "inputModes": ["text/plain"], + "outputModes": ["text/plain"], + } + ], + "capabilities": { + "streaming": True, + "pushNotifications": False, + "stateTransitionHistory": False, + }, + } + }, + "authentication": {"module": "noop"}, + "authorization": {"access_rules": []}, + "a2a_state": {}, + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + mocker.patch("app.endpoints.a2a.configuration", cfg) + + agent_card = get_lightspeed_agent_card() + + assert agent_card.protocol_version == "0.2.1" # Custom version used + def test_get_agent_card_without_config_raises_error( self, setup_minimal_configuration: AppConfig, # pylint: disable=unused-argument From 49851e71d697d72cb8e3fcf155695136d8349d5e Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Wed, 10 Dec 2025 15:27:35 +0100 Subject: [PATCH 6/6] Better handling A2A issues when connecting to LlamaStack --- src/app/endpoints/a2a.py | 50 +++++++--- tests/unit/app/endpoints/test_a2a.py | 141 +++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 12 deletions(-) diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py index 25471f12a..3fd2f8552 100644 --- a/src/app/endpoints/a2a.py +++ b/src/app/endpoints/a2a.py @@ -11,6 +11,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObjectStream, ) +from llama_stack_client import APIConnectionError from starlette.responses import Response, StreamingResponse from a2a.types import ( @@ -310,19 +311,44 @@ async def _process_task_streaming( # pylint: disable=too-many-locals # Get LLM client and select model client = AsyncLlamaStackClientHolder().get_client() - llama_stack_model_id, _model_id, _provider_id = select_model_and_provider_id( - await client.models.list(), - *evaluate_model_hints(user_conversation=None, query_request=query_request), - ) + try: + llama_stack_model_id, _model_id, _provider_id = ( + select_model_and_provider_id( + await client.models.list(), + *evaluate_model_hints( + user_conversation=None, query_request=query_request + ), + ) + ) - # Stream response from LLM using the Responses API - stream, conversation_id = await retrieve_response( - client, - llama_stack_model_id, - query_request, - self.auth_token, - mcp_headers=self.mcp_headers, - ) + # Stream response from LLM using the Responses API + stream, conversation_id = await retrieve_response( + client, + llama_stack_model_id, + query_request, + self.auth_token, + mcp_headers=self.mcp_headers, + ) + except APIConnectionError as e: + error_message = ( + f"Unable to connect to Llama Stack backend service: {str(e)}. " + "The service may be temporarily unavailable. Please try again later." + ) + logger.error( + "APIConnectionError in A2A request: %s", + str(e), + exc_info=True, + ) + await task_updater.update_status( + TaskState.failed, + message=new_agent_text_message( + error_message, + context_id=context_id, + task_id=task_id, + ), + final=True, + ) + return # Persist conversation_id for next turn in same A2A context if conversation_id: diff --git a/tests/unit/app/endpoints/test_a2a.py b/tests/unit/app/endpoints/test_a2a.py index 6e3e13be9..9a7dd67a4 100644 --- a/tests/unit/app/endpoints/test_a2a.py +++ b/tests/unit/app/endpoints/test_a2a.py @@ -6,8 +6,10 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock +import httpx import pytest from fastapi import HTTPException, Request +from llama_stack_client import APIConnectionError from pytest_mock import MockerFixture from a2a.types import ( @@ -654,6 +656,145 @@ async def test_process_task_streaming_no_input( call_args = task_updater.update_status.call_args assert call_args[0][0] == TaskState.input_required + @pytest.mark.asyncio + async def test_process_task_streaming_handles_api_connection_error_on_models_list( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test _process_task_streaming handles APIConnectionError from models.list().""" + executor = A2AAgentExecutor(auth_token="test-token") + + # Mock the context with valid input + mock_message = MagicMock() + mock_message.role = "user" + mock_message.parts = [Part(root=TextPart(text="Hello"))] + mock_message.metadata = {} + + context = MagicMock(spec=RequestContext) + context.task_id = "task-123" + context.context_id = "ctx-456" + context.message = mock_message + context.get_user_input.return_value = "Hello" + + # Mock event queue + event_queue = AsyncMock(spec=EventQueue) + + # Create task updater mock + task_updater = MagicMock() + task_updater.update_status = AsyncMock() + task_updater.event_queue = event_queue + + # Mock the context store + mock_context_store = AsyncMock() + mock_context_store.get.return_value = None + mocker.patch( + "app.endpoints.a2a._get_context_store", return_value=mock_context_store + ) + + # Mock the client to raise APIConnectionError on models.list() + mock_client = AsyncMock() + # Create a mock httpx.Request for APIConnectionError + mock_request = httpx.Request("GET", "http://test-llama-stack/models") + mock_client.models.list.side_effect = APIConnectionError( + message="Connection refused: unable to reach Llama Stack", + request=mock_request, + ) + mocker.patch( + "app.endpoints.a2a.AsyncLlamaStackClientHolder" + ).return_value.get_client.return_value = mock_client + + await executor._process_task_streaming( + context, task_updater, context.task_id, context.context_id + ) + + # Verify failure status was sent + task_updater.update_status.assert_called_once() + call_args = task_updater.update_status.call_args + assert call_args[0][0] == TaskState.failed + assert call_args[1]["final"] is True + # Verify error message contains helpful info + error_message = call_args[1]["message"] + assert "Unable to connect to Llama Stack backend service" in str(error_message) + + @pytest.mark.asyncio + async def test_process_task_streaming_handles_api_connection_error_on_retrieve_response( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, # pylint: disable=unused-argument + ) -> None: + """Test _process_task_streaming handles APIConnectionError from retrieve_response().""" + executor = A2AAgentExecutor(auth_token="test-token") + + # Mock the context with valid input + mock_message = MagicMock() + mock_message.role = "user" + mock_message.parts = [Part(root=TextPart(text="Hello"))] + mock_message.metadata = {} + + context = MagicMock(spec=RequestContext) + context.task_id = "task-123" + context.context_id = "ctx-456" + context.message = mock_message + context.get_user_input.return_value = "Hello" + + # Mock event queue + event_queue = AsyncMock(spec=EventQueue) + + # Create task updater mock + task_updater = MagicMock() + task_updater.update_status = AsyncMock() + task_updater.event_queue = event_queue + + # Mock the context store + mock_context_store = AsyncMock() + mock_context_store.get.return_value = None + mocker.patch( + "app.endpoints.a2a._get_context_store", return_value=mock_context_store + ) + + # Mock the client to succeed on models.list() + mock_client = AsyncMock() + mock_models = MagicMock() + mock_models.models = [] + mock_client.models.list.return_value = mock_models + mocker.patch( + "app.endpoints.a2a.AsyncLlamaStackClientHolder" + ).return_value.get_client.return_value = mock_client + + # Mock select_model_and_provider_id + mocker.patch( + "app.endpoints.a2a.select_model_and_provider_id", + return_value=("model-id", "model-id", "provider-id"), + ) + + # Mock evaluate_model_hints + mocker.patch( + "app.endpoints.a2a.evaluate_model_hints", return_value=(None, None) + ) + + # Mock retrieve_response to raise APIConnectionError + mock_request = httpx.Request("POST", "http://test-llama-stack/responses") + mocker.patch( + "app.endpoints.a2a.retrieve_response", + side_effect=APIConnectionError( + message="Connection timeout during streaming", request=mock_request + ), + ) + + await executor._process_task_streaming( + context, task_updater, context.task_id, context.context_id + ) + + # Verify failure status was sent + task_updater.update_status.assert_called_once() + call_args = task_updater.update_status.call_args + assert call_args[0][0] == TaskState.failed + assert call_args[1]["final"] is True + # Verify error message contains helpful info + error_message = call_args[1]["message"] + assert "Unable to connect to Llama Stack backend service" in str(error_message) + @pytest.mark.asyncio async def test_cancel_raises_not_implemented(self) -> None: """Test that cancel raises NotImplementedError."""