55
66import contextlib
77import logging
8- from typing import AsyncIterator , Awaitable , Callable
8+ from typing import Any , AsyncIterator , Awaitable , Callable
99
1010from a2a .server .agent_execution .context import RequestContext
1111from a2a .server .apps import A2AStarletteApplication
@@ -40,48 +40,51 @@ class HeaderCapturingA2aAgentExecutor(A2aAgentExecutor):
4040 """Custom A2A agent executor that captures and stores HTTP headers.
4141
4242 This executor extends the standard A2aAgentExecutor to intercept the request
43- and store all HTTP headers in the ADK session state. This allows MCP tools
44- to access headers via the header_provider hook, using ADK's built-in session
45- management rather than external context variables.
43+ and store a filtered set of HTTP headers in the ADK session state. Only headers
44+ that are configured to be propagated (across all MCP tools) are stored, avoiding
45+ accidental capture of sensitive headers. Each header is stored as a separate flat
46+ string entry (primitive type) to remain compatible with OpenTelemetry.
4647 """
4748
49+ def __init__ (self , propagate_headers : set [str ], ** kwargs : Any ) -> None :
50+ super ().__init__ (** kwargs )
51+ self ._propagate_headers_lower = {h .lower () for h in propagate_headers }
52+
4853 async def _prepare_session (
4954 self ,
5055 context : RequestContext ,
5156 run_request : AgentRunRequest ,
5257 runner : Runner ,
5358 ) -> Session :
54- """Prepare the session and store HTTP headers if present .
59+ """Prepare the session and store filtered HTTP headers as flat primitive keys .
5560
56- This method extends the parent implementation to capture HTTP headers
57- from the request context and store them in the session state using ADK's
58- recommended approach: creating an Event with state_delta and appending it to
59- the session .
61+ Only headers listed in the configured propagate_headers set are stored.
62+ Each header is stored as a separate string value under the key
63+ ``f"{HTTP_HEADERS_SESSION_KEY}.{header_name_lower}"`` to keep session
64+ state values as primitives (required by OpenTelemetry) .
6065
6166 Args:
6267 context: The A2A request context containing the call context with headers
6368 run_request: The agent run request
6469 runner: The ADK runner instance
6570
6671 Returns:
67- The prepared session with HTTP headers stored in its state
72+ The prepared session with filtered HTTP headers stored in its state
6873 """
69- # Call parent to get or create the session
7074 session : Session = await super ()._prepare_session (context , run_request , runner )
7175
72- # Extract HTTP headers from the request context
73- # The call_context.state contains headers from the original HTTP request
74- if context .call_context and "headers" in context .call_context .state :
76+ if self ._propagate_headers_lower and context .call_context and "headers" in context .call_context .state :
7577 headers = context .call_context .state ["headers" ]
76-
77- # Store all headers in session state for per-MCP-server filtering
78- # This allows each MCP server to receive only the headers it's configured to receive
7978 if headers :
80- event = Event (
81- author = "system" , actions = EventActions (state_delta = {HTTP_HEADERS_SESSION_KEY : dict (headers )})
82- )
83- await runner .session_service .append_event (session , event )
84- logger .debug ("Stored HTTP headers in session %s via state_delta" , session .id )
79+ state_delta : dict [str , object ] = {}
80+ for key , value in headers .items ():
81+ if key .lower () in self ._propagate_headers_lower :
82+ state_delta [f"{ HTTP_HEADERS_SESSION_KEY } .{ key .lower ()} " ] = value
83+
84+ if state_delta :
85+ event = Event (author = "system" , actions = EventActions (state_delta = state_delta ))
86+ await runner .session_service .append_event (session , event )
87+ logger .debug ("Stored %d HTTP headers in session %s via state_delta" , len (state_delta ), session .id )
8588
8689 return session
8790
@@ -92,12 +95,17 @@ def filter(self, record: logging.LogRecord) -> bool:
9295 return record .getMessage ().find (AGENT_CARD_WELL_KNOWN_PATH ) == - 1
9396
9497
95- async def create_a2a_app (agent : BaseAgent , rpc_url : str ) -> A2AStarletteApplication :
98+ async def create_a2a_app (
99+ agent : BaseAgent , rpc_url : str , propagate_headers : set [str ] | None = None
100+ ) -> A2AStarletteApplication :
96101 """Create an A2A Starlette application from an ADK agent.
97102
98103 Args:
99104 agent: The ADK agent to convert
100105 rpc_url: The URL where the agent will be available for A2A communication
106+ propagate_headers: Union of all header names that any MCP tool is configured to
107+ propagate. Only these headers will be captured from incoming requests and
108+ stored in the session state.
101109 Returns:
102110 An A2AStarletteApplication instance
103111 """
@@ -122,6 +130,7 @@ async def create_runner() -> Runner:
122130 # Use custom executor that captures HTTP headers and stores in session
123131 agent_executor = HeaderCapturingA2aAgentExecutor (
124132 runner = create_runner ,
133+ propagate_headers = propagate_headers or set (),
125134 )
126135
127136 request_handler = DefaultRequestHandler (agent_executor = agent_executor , task_store = task_store )
@@ -175,14 +184,15 @@ def to_a2a(
175184 """
176185
177186 agent_factory = agent_factory or AgentFactory ()
187+ all_propagate_headers = {h for tool in (tools or []) for h in tool .propagate_headers }
178188
179189 async def a2a_app_creator () -> A2AStarletteApplication :
180190 configured_agent = await agent_factory .load_agent (
181191 agent = agent ,
182192 sub_agents = sub_agents or [],
183193 tools = tools or [],
184194 )
185- return await create_a2a_app (configured_agent , rpc_url )
195+ return await create_a2a_app (configured_agent , rpc_url , propagate_headers = all_propagate_headers )
186196
187197 return to_starlette (a2a_app_creator )
188198
0 commit comments