Skip to content

Commit 89aae43

Browse files
committed
feat: Add only propagated HTTP headers to session state
1 parent 48a45fc commit 89aae43

File tree

4 files changed

+49
-49
lines changed

4 files changed

+49
-49
lines changed

adk/agenticlayer/agent.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,16 @@ def _create_header_provider(propagate_headers: list[str]) -> Callable[[ReadonlyC
3030
the MCP server's configuration. Only headers listed in propagate_headers will
3131
be included in requests to that server.
3232
33-
The matching is case-insensitive: if the configuration specifies 'Authorization'
34-
and the incoming request has 'authorization', they will match. The output header
35-
will use the case specified in the configuration.
33+
Headers are stored in the session state as flat primitive string keys of the form
34+
``f"{HTTP_HEADERS_SESSION_KEY}.{header_name_lower}"``. The provider looks up each
35+
configured header by its lower-cased key and returns it with the casing specified
36+
in the configuration.
3637
3738
Example:
3839
>>> provider = _create_header_provider(['Authorization', 'X-API-Key'])
39-
>>> # If session has: {'authorization': 'Bearer token', 'x-api-key': 'key123'}
40+
>>> # Session state has: {'http_headers.authorization': 'Bearer token', '__http_headers__.x-api-key': 'key123'}
4041
>>> # Output will be: {'Authorization': 'Bearer token', 'X-API-Key': 'key123'}
4142
42-
Note: If multiple headers with different casing match a single configured header
43-
(e.g., both 'authorization' and 'Authorization' in stored headers), only one
44-
will be included. The last match found will be used.
45-
4643
Args:
4744
propagate_headers: List of header names to propagate to this MCP server
4845
@@ -51,26 +48,15 @@ def _create_header_provider(propagate_headers: list[str]) -> Callable[[ReadonlyC
5148
"""
5249

5350
def header_provider(readonly_context: ReadonlyContext) -> dict[str, str]:
54-
"""Header provider that filters headers based on server configuration."""
51+
"""Header provider that reads per-header flat primitive keys from session state."""
5552
if not readonly_context or not readonly_context.state:
5653
return {}
5754

58-
# Get all stored headers from session
59-
all_headers = readonly_context.state.get(HTTP_HEADERS_SESSION_KEY, {})
60-
if not all_headers:
61-
return {}
62-
63-
# Create a lowercase lookup dictionary for O(n+m) complexity instead of O(n*m)
64-
all_headers_lower = {k.lower(): (k, v) for k, v in all_headers.items()}
65-
66-
# Filter to only include configured headers (case-insensitive matching)
6755
result_headers = {}
6856
for header_name in propagate_headers:
69-
# Try to find the header in the stored headers (case-insensitive)
70-
header_lower = header_name.lower()
71-
if header_lower in all_headers_lower:
72-
original_key, value = all_headers_lower[header_lower]
73-
# Use the original case from the configuration
57+
key = f"{HTTP_HEADERS_SESSION_KEY}.{header_name.lower()}"
58+
value = readonly_context.state.get(key)
59+
if value is not None:
7460
result_headers[header_name] = value
7561

7662
return result_headers

adk/agenticlayer/agent_to_a2a.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import contextlib
77
import logging
8-
from typing import AsyncIterator, Awaitable, Callable
8+
from typing import Any, AsyncIterator, Awaitable, Callable
99

1010
from a2a.server.agent_execution.context import RequestContext
1111
from a2a.server.apps import A2AStarletteApplication
@@ -40,48 +40,51 @@ class HeaderCapturingA2aAgentExecutor(A2aAgentExecutor):
4040
"""Custom A2A agent executor that captures and stores HTTP headers.
4141
4242
This executor extends the standard A2aAgentExecutor to intercept the request
43-
and store all HTTP headers in the ADK session state. This allows MCP tools
44-
to access headers via the header_provider hook, using ADK's built-in session
45-
management rather than external context variables.
43+
and store a filtered set of HTTP headers in the ADK session state. Only headers
44+
that are configured to be propagated (across all MCP tools) are stored, avoiding
45+
accidental capture of sensitive headers. Each header is stored as a separate flat
46+
string entry (primitive type) to remain compatible with OpenTelemetry.
4647
"""
4748

49+
def __init__(self, propagate_headers: set[str], **kwargs: Any) -> None:
50+
super().__init__(**kwargs)
51+
self._propagate_headers_lower = {h.lower() for h in propagate_headers}
52+
4853
async def _prepare_session(
4954
self,
5055
context: RequestContext,
5156
run_request: AgentRunRequest,
5257
runner: Runner,
5358
) -> Session:
54-
"""Prepare the session and store HTTP headers if present.
59+
"""Prepare the session and store filtered HTTP headers as flat primitive keys.
5560
56-
This method extends the parent implementation to capture HTTP headers
57-
from the request context and store them in the session state using ADK's
58-
recommended approach: creating an Event with state_delta and appending it to
59-
the session.
61+
Only headers listed in the configured propagate_headers set are stored.
62+
Each header is stored as a separate string value under the key
63+
``f"{HTTP_HEADERS_SESSION_KEY}.{header_name_lower}"`` to keep session
64+
state values as primitives (required by OpenTelemetry).
6065
6166
Args:
6267
context: The A2A request context containing the call context with headers
6368
run_request: The agent run request
6469
runner: The ADK runner instance
6570
6671
Returns:
67-
The prepared session with HTTP headers stored in its state
72+
The prepared session with filtered HTTP headers stored in its state
6873
"""
69-
# Call parent to get or create the session
7074
session: Session = await super()._prepare_session(context, run_request, runner)
7175

72-
# Extract HTTP headers from the request context
73-
# The call_context.state contains headers from the original HTTP request
74-
if context.call_context and "headers" in context.call_context.state:
76+
if self._propagate_headers_lower and context.call_context and "headers" in context.call_context.state:
7577
headers = context.call_context.state["headers"]
76-
77-
# Store all headers in session state for per-MCP-server filtering
78-
# This allows each MCP server to receive only the headers it's configured to receive
7978
if headers:
80-
event = Event(
81-
author="system", actions=EventActions(state_delta={HTTP_HEADERS_SESSION_KEY: dict(headers)})
82-
)
83-
await runner.session_service.append_event(session, event)
84-
logger.debug("Stored HTTP headers in session %s via state_delta", session.id)
79+
state_delta: dict[str, object] = {}
80+
for key, value in headers.items():
81+
if key.lower() in self._propagate_headers_lower:
82+
state_delta[f"{HTTP_HEADERS_SESSION_KEY}.{key.lower()}"] = value
83+
84+
if state_delta:
85+
event = Event(author="system", actions=EventActions(state_delta=state_delta))
86+
await runner.session_service.append_event(session, event)
87+
logger.debug("Stored %d HTTP headers in session %s via state_delta", len(state_delta), session.id)
8588

8689
return session
8790

@@ -92,12 +95,17 @@ def filter(self, record: logging.LogRecord) -> bool:
9295
return record.getMessage().find(AGENT_CARD_WELL_KNOWN_PATH) == -1
9396

9497

95-
async def create_a2a_app(agent: BaseAgent, rpc_url: str) -> A2AStarletteApplication:
98+
async def create_a2a_app(
99+
agent: BaseAgent, rpc_url: str, propagate_headers: set[str] | None = None
100+
) -> A2AStarletteApplication:
96101
"""Create an A2A Starlette application from an ADK agent.
97102
98103
Args:
99104
agent: The ADK agent to convert
100105
rpc_url: The URL where the agent will be available for A2A communication
106+
propagate_headers: Union of all header names that any MCP tool is configured to
107+
propagate. Only these headers will be captured from incoming requests and
108+
stored in the session state.
101109
Returns:
102110
An A2AStarletteApplication instance
103111
"""
@@ -122,6 +130,7 @@ async def create_runner() -> Runner:
122130
# Use custom executor that captures HTTP headers and stores in session
123131
agent_executor = HeaderCapturingA2aAgentExecutor(
124132
runner=create_runner,
133+
propagate_headers=propagate_headers or set(),
125134
)
126135

127136
request_handler = DefaultRequestHandler(agent_executor=agent_executor, task_store=task_store)
@@ -175,14 +184,15 @@ def to_a2a(
175184
"""
176185

177186
agent_factory = agent_factory or AgentFactory()
187+
all_propagate_headers = {h for tool in (tools or []) for h in tool.propagate_headers}
178188

179189
async def a2a_app_creator() -> A2AStarletteApplication:
180190
configured_agent = await agent_factory.load_agent(
181191
agent=agent,
182192
sub_agents=sub_agents or [],
183193
tools=tools or [],
184194
)
185-
return await create_a2a_app(configured_agent, rpc_url)
195+
return await create_a2a_app(configured_agent, rpc_url, propagate_headers=all_propagate_headers)
186196

187197
return to_starlette(a2a_app_creator)
188198

adk/agenticlayer/constants.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Constants shared across the agenticlayer package."""
22

3-
# Key used to store all HTTP headers in the ADK session state
4-
HTTP_HEADERS_SESSION_KEY = "__http_headers__" # nosec B105
3+
# Prefix used to store propagated HTTP headers in ADK session state as flat primitive keys.
4+
# Each header is stored as a separate string entry: f"{HTTP_HEADERS_SESSION_KEY}.{header_name_lower}"
5+
# e.g. "http_headers.authorization" -> "Bearer token"
6+
HTTP_HEADERS_SESSION_KEY = "http_headers"

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ filterwarnings = [
5959
"ignore::DeprecationWarning:litellm.*",
6060
"ignore::DeprecationWarning:a2a.*",
6161
"ignore::UserWarning:agenticlayer.*",
62+
# google-adk uses the deprecated streamablehttp_client from mcp; suppress until upstream fixes it
63+
"ignore:Use `streamable_http_client` instead.:DeprecationWarning",
6264
]
6365

6466

0 commit comments

Comments
 (0)