-
Notifications
You must be signed in to change notification settings - Fork 0
Pass external API tokens to MCP tools via X-External-Token header using ADK session #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
4fcc1e7
aec4d84
1739724
ca6dd0f
279b4ed
23ddb1e
20bc6e8
1d0fec4
0fcd4f5
34e5cc2
71512aa
a6178e2
d2992a2
1fca353
78a2f88
bcde875
f82d56a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,11 +7,13 @@ | |
| import logging | ||
| from typing import AsyncIterator, Awaitable, Callable | ||
|
|
||
| from a2a.server.agent_execution.context import RequestContext | ||
| from a2a.server.apps import A2AStarletteApplication | ||
| from a2a.server.request_handlers import DefaultRequestHandler | ||
| from a2a.server.tasks import InMemoryTaskStore | ||
| from a2a.types import AgentCapabilities, AgentCard | ||
| from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH | ||
| from google.adk.a2a.converters.request_converter import AgentRunRequest | ||
| from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor | ||
| from google.adk.agents import LlmAgent | ||
| from google.adk.agents.base_agent import BaseAgent | ||
|
|
@@ -21,15 +23,75 @@ | |
| from google.adk.memory.in_memory_memory_service import InMemoryMemoryService | ||
| from google.adk.runners import Runner | ||
| from google.adk.sessions.in_memory_session_service import InMemorySessionService | ||
| from google.adk.sessions.session import Session | ||
| from starlette.applications import Starlette | ||
|
|
||
| from .agent import AgentFactory | ||
| from .callback_tracer_plugin import CallbackTracerPlugin | ||
| from .config import McpTool, SubAgent | ||
| from .constants import EXTERNAL_TOKEN_SESSION_KEY | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class TokenCapturingA2aAgentExecutor(A2aAgentExecutor): | ||
| """Custom A2A agent executor that captures and stores the X-External-Token header. | ||
|
|
||
| This executor extends the standard A2aAgentExecutor to intercept the request | ||
| and store the X-External-Token header in the ADK session state. This allows | ||
| MCP tools to access the token via the header_provider hook, using ADK's | ||
| built-in session management rather than external context variables. | ||
| """ | ||
|
|
||
| async def _prepare_session( | ||
| self, | ||
| context: RequestContext, | ||
| run_request: AgentRunRequest, | ||
| runner: Runner, | ||
| ) -> Session: | ||
| """Prepare the session and store the external token if present. | ||
|
|
||
| This method extends the parent implementation to capture the X-External-Token | ||
| header from the request context and store it in the session state. | ||
|
|
||
| Args: | ||
| context: The A2A request context containing the call context with headers | ||
| run_request: The agent run request | ||
| runner: The ADK runner instance | ||
|
|
||
| Returns: | ||
| The prepared session with the external token stored in its state | ||
| """ | ||
| # Call parent to get or create the session | ||
| session: Session = await super()._prepare_session(context, run_request, runner) | ||
|
|
||
| # Extract the X-External-Token header from the request context | ||
| # The call_context.state contains headers from the original HTTP request | ||
| if context.call_context and "headers" in context.call_context.state: | ||
| headers = context.call_context.state["headers"] | ||
| # Headers might be in different cases, check all variations | ||
| external_token = ( | ||
| headers.get("x-external-token") | ||
| or headers.get("X-External-Token") | ||
| or headers.get("X-EXTERNAL-TOKEN") | ||
| ) | ||
|
|
||
| if external_token: | ||
| # Store the token in the session state | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main issue here is, that we rely too much on the session, while we actually want to store the token in the session state. Look for other injection points where we can actually modify the state "normally", as intended by ADK. See: https://google.github.io/adk-docs/sessions/state/#how-state-is-updated-recommended-methods
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in commit 78a2f88. Now using ADK's recommended approach: creating an Event with EventActions containing state_delta, then calling |
||
| # NOTE: InMemorySessionService returns copies of sessions, so we need to | ||
| # update the internal storage directly | ||
| session.state[EXTERNAL_TOKEN_SESSION_KEY] = external_token | ||
|
|
||
| # Update the stored session directly (InMemorySessionService returns copies) | ||
| if hasattr(runner.session_service, "sessions"): | ||
|
g3force marked this conversation as resolved.
Outdated
|
||
| stored_session = runner.session_service.sessions.get(session.app_name, {}).get(session.user_id, {}).get(session.id) | ||
| if stored_session: | ||
| stored_session.state[EXTERNAL_TOKEN_SESSION_KEY] = external_token | ||
| logger.debug("Stored external token in session %s", session.id) | ||
|
|
||
| return session | ||
|
|
||
|
|
||
| class HealthCheckFilter(logging.Filter): | ||
| def filter(self, record: logging.LogRecord) -> bool: | ||
| # Check if the log message contains the well known path of the card, which is used for health checks | ||
|
|
@@ -55,15 +117,16 @@ async def create_runner() -> Runner: | |
| plugins=[CallbackTracerPlugin()], | ||
| ), | ||
| artifact_service=InMemoryArtifactService(), | ||
| session_service=InMemorySessionService(), # type: ignore | ||
| memory_service=InMemoryMemoryService(), # type: ignore | ||
| credential_service=InMemoryCredentialService(), # type: ignore | ||
| session_service=InMemorySessionService(), # type: ignore[no-untyped-call] | ||
| memory_service=InMemoryMemoryService(), # type: ignore[no-untyped-call] | ||
| credential_service=InMemoryCredentialService(), # type: ignore[no-untyped-call] | ||
| ) | ||
|
|
||
| # Create A2A components | ||
| task_store = InMemoryTaskStore() | ||
|
|
||
| agent_executor = A2aAgentExecutor( | ||
| # Use custom executor that captures X-External-Token and stores in session | ||
| agent_executor = TokenCapturingA2aAgentExecutor( | ||
| runner=create_runner, | ||
| ) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| """Constants shared across the agenticlayer package.""" | ||
|
|
||
| # Key used to store the external token in the ADK session state | ||
| EXTERNAL_TOKEN_SESSION_KEY = "__external_token__" # nosec B105 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| """Tests for external token passing to MCP tools via ADK session.""" | ||
|
|
||
| from agenticlayer.agent import _get_mcp_headers_from_session | ||
| from agenticlayer.constants import EXTERNAL_TOKEN_SESSION_KEY | ||
| from google.adk.sessions.session import Session | ||
|
|
||
|
|
||
| def test_header_provider_retrieves_token_from_session() -> None: | ||
| """Test that the header provider function can retrieve token from session state.""" | ||
| # Given: A session with an external token stored | ||
| test_token = "test-api-token-xyz" # nosec B105 | ||
| session = Session( | ||
| id="test-session", | ||
| app_name="test-app", | ||
| user_id="test-user", | ||
| state={EXTERNAL_TOKEN_SESSION_KEY: test_token}, | ||
| events=[], | ||
| last_update_time=0.0, | ||
| ) | ||
|
|
||
| # Create a mock readonly context | ||
| class MockReadonlyContext: | ||
| def __init__(self, session: Session) -> None: | ||
| self.session = session | ||
|
|
||
| readonly_context = MockReadonlyContext(session) | ||
|
|
||
| # When: Calling the header provider function | ||
| headers = _get_mcp_headers_from_session(readonly_context) # type: ignore[arg-type] | ||
|
|
||
| # Then: The headers should include the X-External-Token | ||
| assert headers == {"X-External-Token": test_token} | ||
|
|
||
|
|
||
| def test_header_provider_returns_empty_when_no_token() -> None: | ||
| """Test that the header provider returns empty dict when no token is present.""" | ||
| # Given: A session without an external token | ||
| session = Session( | ||
| id="test-session", | ||
| app_name="test-app", | ||
| user_id="test-user", | ||
| state={}, # No token | ||
| events=[], | ||
| last_update_time=0.0, | ||
| ) | ||
|
|
||
| # Create a mock readonly context | ||
| class MockReadonlyContext: | ||
| def __init__(self, session: Session) -> None: | ||
| self.session = session | ||
|
|
||
| readonly_context = MockReadonlyContext(session) | ||
|
|
||
| # When: Calling the header provider function | ||
| headers = _get_mcp_headers_from_session(readonly_context) # type: ignore[arg-type] | ||
|
|
||
| # Then: The headers should be empty | ||
| assert headers == {} | ||
|
|
||
|
|
||
| def test_header_provider_handles_none_context() -> None: | ||
| """Test that the header provider safely handles None context.""" | ||
| # When: Calling the header provider with None | ||
| headers = _get_mcp_headers_from_session(None) # type: ignore[arg-type] | ||
|
|
||
| # Then: The headers should be empty (no exception) | ||
| assert headers == {} | ||
|
|
Uh oh!
There was an error while loading. Please reload this page.