diff --git a/adk/README.md b/adk/README.md index 2556b23..c404d9b 100644 --- a/adk/README.md +++ b/adk/README.md @@ -9,6 +9,7 @@ SDK for Google ADK that helps to get agents configured in the Agentic Layer quic - Configures A2A protocol for inter-agent communication - Offers parsing methods for sub agents and tools - Set log level via env var `LOGLEVEL` (default: `INFO`) +- Automatically passes external API tokens to MCP tools via the `X-External-Token` header ## Usage @@ -100,3 +101,47 @@ Body logging behavior: **Note**: Starlette body logging is more limited than HTTPX because it must avoid consuming request/response streams. Bodies are only captured when already buffered in the ASGI scope. + +## External API Token Passing + +The SDK supports passing external API tokens from A2A requests to MCP tools. This enables MCP servers to authenticate with external APIs on behalf of users. + +### How It Works + +1. **Token Capture**: When an A2A request includes the `X-External-Token` header, the SDK automatically captures and stores it in the ADK session state +2. **Secure Storage**: The token is stored in ADK's session state (not in memory state accessible to the LLM), ensuring the agent cannot directly access or leak it +3. **Automatic Injection**: When MCP tools are invoked, the SDK uses ADK's `header_provider` hook to retrieve the token from the session and inject it as the `X-External-Token` header in tool requests + +**Current Limitations**: The token is only passed to MCP servers. Propagation to sub-agents is not currently supported due to ADK limitations in passing custom HTTP headers in A2A requests. + +### Usage Example + +Simply include the `X-External-Token` header in your A2A requests: + +```bash +curl -X POST http://localhost:8000/ \ + -H "Content-Type: application/json" \ + -H "X-External-Token: your-api-token-here" \ + -d '{ + "jsonrpc": "2.0", + "id": 1, + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Your message"}], + "messageId": "msg-123", + "contextId": "ctx-123" + } + } + }' +``` + +The SDK will automatically pass `your-api-token-here` to all MCP tool calls and sub-agent requests made during that session. + +### Security Considerations + +- Tokens are stored in ADK session state (separate from memory state that the LLM can access) +- Tokens are not directly accessible to agent code through normal session state queries +- Tokens persist for the session duration and are managed by ADK's session lifecycle +- This is a simple authentication mechanism; for production use, consider implementing more sophisticated authentication and authorization schemes diff --git a/adk/agenticlayer/agent.py b/adk/agenticlayer/agent.py index a446055..2686463 100644 --- a/adk/agenticlayer/agent.py +++ b/adk/agenticlayer/agent.py @@ -9,6 +9,7 @@ from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH from google.adk.agents import BaseAgent, LlmAgent from google.adk.agents.llm_agent import ToolUnion +from google.adk.agents.readonly_context import ReadonlyContext from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.tools.agent_tool import AgentTool from google.adk.tools.mcp_tool import StreamableHTTPConnectionParams @@ -16,10 +17,33 @@ from httpx_retries import Retry, RetryTransport from agenticlayer.config import InteractionType, McpTool, SubAgent +from agenticlayer.constants import EXTERNAL_TOKEN_SESSION_KEY logger = logging.getLogger(__name__) +def _get_mcp_headers_from_session(readonly_context: ReadonlyContext) -> dict[str, str]: + """Header provider function for MCP tools that retrieves token from ADK session. + + This function is called by the ADK when MCP tools are invoked. It reads the + external token from the session state where it was stored during request + processing by TokenCapturingA2aAgentExecutor. + + Args: + readonly_context: The ADK ReadonlyContext providing access to the session + + Returns: + A dictionary of headers to include in MCP tool requests. + If a token is stored in the session, includes it in the headers. + """ + # Access the session state directly from the readonly context + if readonly_context and readonly_context.state: + external_token = readonly_context.state.get(EXTERNAL_TOKEN_SESSION_KEY) + if external_token: + return {"X-External-Token": external_token} + return {} + + class AgentFactory: def __init__( self, @@ -110,6 +134,8 @@ def load_tools(self, mcp_tools: list[McpTool]) -> list[ToolUnion]: url=str(tool.url), timeout=tool.timeout, ), + # Provide header provider to inject session-stored token into tool requests + header_provider=_get_mcp_headers_from_session, ) ) diff --git a/adk/agenticlayer/agent_to_a2a.py b/adk/agenticlayer/agent_to_a2a.py index ee1d29c..8c5ad38 100644 --- a/adk/agenticlayer/agent_to_a2a.py +++ b/adk/agenticlayer/agent_to_a2a.py @@ -7,29 +7,96 @@ import logging from typing import AsyncIterator, Awaitable, Callable +from a2a.server.agent_execution.context import RequestContext from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore from a2a.types import AgentCapabilities, AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH +from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.agents import LlmAgent from google.adk.agents.base_agent import BaseAgent from google.adk.apps.app import App from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session from starlette.applications import Starlette from .agent import AgentFactory from .callback_tracer_plugin import CallbackTracerPlugin from .config import McpTool, SubAgent +from .constants import EXTERNAL_TOKEN_SESSION_KEY logger = logging.getLogger(__name__) +class TokenCapturingA2aAgentExecutor(A2aAgentExecutor): + """Custom A2A agent executor that captures and stores the X-External-Token header. + + This executor extends the standard A2aAgentExecutor to intercept the request + and store the X-External-Token header in the ADK session state. This allows + MCP tools to access the token via the header_provider hook, using ADK's + built-in session management rather than external context variables. + """ + + async def _prepare_session( + self, + context: RequestContext, + run_request: AgentRunRequest, + runner: Runner, + ) -> Session: + """Prepare the session and store the external token if present. + + This method extends the parent implementation to capture the X-External-Token + header from the request context and store it in the session state using ADK's + recommended approach: creating an Event with state_delta and appending it to + the session. + + Args: + context: The A2A request context containing the call context with headers + run_request: The agent run request + runner: The ADK runner instance + + Returns: + The prepared session with the external token stored in its state + """ + # Call parent to get or create the session + session: Session = await super()._prepare_session(context, run_request, runner) + + # Extract the X-External-Token header from the request context + # The call_context.state contains headers from the original HTTP request + if context.call_context and "headers" in context.call_context.state: + headers = context.call_context.state["headers"] + # Headers might be in different cases, check all variations + external_token = ( + headers.get("x-external-token") + or headers.get("X-External-Token") + or headers.get("X-EXTERNAL-TOKEN") + ) + + if external_token: + # Store the token in the session state using ADK's recommended method: + # Create an Event with a state_delta and append it to the session. + # This follows ADK's pattern for updating session state as documented at: + # https://google.github.io/adk-docs/sessions/state/#how-state-is-updated-recommended-methods + event = Event( + author="system", + actions=EventActions( + state_delta={EXTERNAL_TOKEN_SESSION_KEY: external_token} + ) + ) + await runner.session_service.append_event(session, event) + logger.debug("Stored external token in session %s via state_delta", session.id) + + return session + + class HealthCheckFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: # Check if the log message contains the well known path of the card, which is used for health checks @@ -55,15 +122,16 @@ async def create_runner() -> Runner: plugins=[CallbackTracerPlugin()], ), artifact_service=InMemoryArtifactService(), - session_service=InMemorySessionService(), # type: ignore - memory_service=InMemoryMemoryService(), # type: ignore - credential_service=InMemoryCredentialService(), # type: ignore + session_service=InMemorySessionService(), # type: ignore[no-untyped-call] + memory_service=InMemoryMemoryService(), # type: ignore[no-untyped-call] + credential_service=InMemoryCredentialService(), # type: ignore[no-untyped-call] ) # Create A2A components task_store = InMemoryTaskStore() - agent_executor = A2aAgentExecutor( + # Use custom executor that captures X-External-Token and stores in session + agent_executor = TokenCapturingA2aAgentExecutor( runner=create_runner, ) diff --git a/adk/agenticlayer/constants.py b/adk/agenticlayer/constants.py new file mode 100644 index 0000000..cfa8e40 --- /dev/null +++ b/adk/agenticlayer/constants.py @@ -0,0 +1,4 @@ +"""Constants shared across the agenticlayer package.""" + +# Key used to store the external token in the ADK session state +EXTERNAL_TOKEN_SESSION_KEY = "__external_token__" # nosec B105 diff --git a/adk/tests/test_agent_integration.py b/adk/tests/test_agent_integration.py index a0e1174..14da1fe 100644 --- a/adk/tests/test_agent_integration.py +++ b/adk/tests/test_agent_integration.py @@ -10,7 +10,7 @@ from agenticlayer.agent_to_a2a import to_a2a from agenticlayer.config import InteractionType, McpTool, SubAgent from asgi_lifespan import LifespanManager -from fastmcp import FastMCP +from fastmcp import Context, FastMCP from httpx_retries import Retry from pydantic import AnyHttpUrl from starlette.testclient import TestClient @@ -342,3 +342,104 @@ def add(a: int, b: int) -> int: assert history[4]["role"] == "agent" assert history[4]["parts"] == [{"kind": "text", "text": "The calculation result is correct!"}] + + @pytest.mark.asyncio + async def test_external_token_passed_to_mcp_tools( + self, + app_factory: Any, + agent_factory: Any, + llm_controller: LLMMockController, + respx_mock: respx.MockRouter, + ) -> None: + """Test that X-External-Token header is passed from A2A request to MCP tool calls. + + Verifies end-to-end token passing through the agent to MCP servers. + """ + + # Given: Mock LLM to call 'echo' tool + llm_controller.respond_with_tool_call( + pattern="", # Match any message + tool_name="echo", + tool_args={"message": "test"}, + final_message="Echo completed!", + ) + + # Given: MCP server with 'echo' tool that can access headers via Context + mcp = FastMCP("TokenVerifier") + received_headers: list[dict[str, str]] = [] + received_tokens_in_tool: list[str | None] = [] + + @mcp.tool() + def echo(message: str, ctx: Context) -> str: + """Echo a message and verify token is accessible in tool context.""" + # Access headers from the MCP request context + # The Context object provides access to the request_context which includes HTTP headers + if ctx.request_context and hasattr(ctx.request_context, "request"): + # Try to get the token from request headers if available + request = ctx.request_context.request + if request and hasattr(request, "headers"): + token = request.headers.get("x-external-token") or request.headers.get("X-External-Token") + received_tokens_in_tool.append(token) + return f"Echoed: {message}" + + mcp_server_url = "http://test-mcp-token.local" + mcp_app = mcp.http_app(path="/mcp") + + async with LifespanManager(mcp_app) as mcp_manager: + # Create a custom handler that captures headers + async def handler_with_header_capture(request: httpx.Request) -> httpx.Response: + # Capture the headers from the request + received_headers.append(dict(request.headers)) + + # Forward to the MCP app + transport = httpx.ASGITransport(app=mcp_manager.app) + async with httpx.AsyncClient(transport=transport, base_url=mcp_server_url) as client: + return await client.request( + method=request.method, + url=str(request.url), + headers=request.headers, + content=request.content, + ) + + # Route MCP requests through our custom handler + respx_mock.route(host="test-mcp-token.local").mock(side_effect=handler_with_header_capture) + + # When: Create agent with MCP tool and send request with X-External-Token header + test_agent = agent_factory("test_agent") + tools = [McpTool(name="verifier", url=AnyHttpUrl(f"{mcp_server_url}/mcp"), timeout=30)] + external_token = "secret-api-token-12345" # nosec B105 + + async with app_factory(test_agent, tools=tools) as app: + client = TestClient(app) + user_message = "Echo test message" + response = client.post( + "", + json=create_send_message_request(user_message), + headers={"X-External-Token": external_token}, + ) + + # Then: Verify response is successful + assert response.status_code == 200 + result = verify_jsonrpc_response(response.json()) + assert result["status"]["state"] == "completed", "Task should complete successfully" + + # Then: Verify X-External-Token header was passed to MCP server + assert len(received_headers) > 0, "MCP server should have received requests" + + # Find the tool call request (not the initialization requests) + # Header keys might be lowercase + tool_call_headers = [h for h in received_headers if "x-external-token" in h or "X-External-Token" in h] + assert len(tool_call_headers) > 0, ( + f"At least one request should have X-External-Token header. " + f"Received {len(received_headers)} requests total." + ) + + # Verify the token value + for headers in tool_call_headers: + # Header might be lowercase in the dict + token_value = headers.get("X-External-Token") or headers.get("x-external-token") + assert token_value == external_token, ( + f"Expected token '{external_token}', got '{token_value}'" + ) + + diff --git a/adk/tests/test_external_token.py b/adk/tests/test_external_token.py new file mode 100644 index 0000000..96f969f --- /dev/null +++ b/adk/tests/test_external_token.py @@ -0,0 +1,70 @@ +"""Tests for external token passing to MCP tools via ADK session.""" + +from agenticlayer.agent import _get_mcp_headers_from_session +from agenticlayer.constants import EXTERNAL_TOKEN_SESSION_KEY +from google.adk.sessions.session import Session + + +def test_header_provider_retrieves_token_from_session() -> None: + """Test that the header provider function can retrieve token from session state.""" + # Given: A session with an external token stored + test_token = "test-api-token-xyz" # nosec B105 + session = Session( + id="test-session", + app_name="test-app", + user_id="test-user", + state={EXTERNAL_TOKEN_SESSION_KEY: test_token}, + events=[], + last_update_time=0.0, + ) + + # Create a mock readonly context + class MockReadonlyContext: + def __init__(self, session: Session) -> None: + self.session = session + self.state = session.state # Add state property for direct access + + readonly_context = MockReadonlyContext(session) + + # When: Calling the header provider function + headers = _get_mcp_headers_from_session(readonly_context) # type: ignore[arg-type] + + # Then: The headers should include the X-External-Token + assert headers == {"X-External-Token": test_token} + + +def test_header_provider_returns_empty_when_no_token() -> None: + """Test that the header provider returns empty dict when no token is present.""" + # Given: A session without an external token + session = Session( + id="test-session", + app_name="test-app", + user_id="test-user", + state={}, # No token + events=[], + last_update_time=0.0, + ) + + # Create a mock readonly context + class MockReadonlyContext: + def __init__(self, session: Session) -> None: + self.session = session + self.state = session.state # Add state property for direct access + + readonly_context = MockReadonlyContext(session) + + # When: Calling the header provider function + headers = _get_mcp_headers_from_session(readonly_context) # type: ignore[arg-type] + + # Then: The headers should be empty + assert headers == {} + + +def test_header_provider_handles_none_context() -> None: + """Test that the header provider safely handles None context.""" + # When: Calling the header provider with None + headers = _get_mcp_headers_from_session(None) # type: ignore[arg-type] + + # Then: The headers should be empty (no exception) + assert headers == {} +