Skip to content

Commit e8c0cff

Browse files
syn-zhuclaude
andcommitted
feat(agentsts): move STS token exchange to before_tool_callback
Move STS token exchange from before_run_callback (eager, once per invocation) to before_tool_callback (lazy, per MCP tool call). This avoids the sync/async problem in header_provider and only performs the exchange when an MCP tool is actually invoked. - before_run_callback now only extracts and stores the subject token - before_tool_callback exchanges on first McpTool call per session - Non-MCP tools (memory, AgentTool, etc.) are skipped - Cached tokens are reused for subsequent MCP calls in same session - Sets the stage for per-audience token exchange Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 97a5014 commit e8c0cff

2 files changed

Lines changed: 145 additions & 33 deletions

File tree

python/packages/agentsts-adk/src/agentsts/adk/_base.py

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from google.adk.sessions.session import Session
1515
from google.adk.tools.base_tool import BaseTool
1616
from google.adk.tools.mcp_tool import MCPTool
17+
from google.adk.tools.mcp_tool.mcp_tool import McpTool
1718
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
1819
from google.adk.tools.tool_context import ToolContext
1920
from typing_extensions import override
@@ -40,7 +41,18 @@ def __init__(
4041

4142

4243
class ADKTokenPropagationPlugin(BasePlugin):
43-
"""Plugin for propagating STS tokens to ADK tools."""
44+
"""Plugin for propagating STS tokens to ADK tools.
45+
46+
Token exchange lifecycle:
47+
1. before_run_callback: extracts the subject token from request headers
48+
and stores it for the duration of the invocation.
49+
2. before_tool_callback: when an MCP tool is about to be called and STS
50+
is configured, exchanges the subject token (async) and caches the
51+
access token so header_provider can read it.
52+
3. header_provider (sync): returns cached access token as Authorization
53+
header -- called by McpToolset/McpTool during MCP session setup.
54+
4. after_run_callback: cleans up all cached state.
55+
"""
4456

4557
def __init__(self, sts_integration: Optional[STSIntegrationBase] = None):
4658
"""Initialize the token propagation plugin.
@@ -51,6 +63,7 @@ def __init__(self, sts_integration: Optional[STSIntegrationBase] = None):
5163
super().__init__("ADKTokenPropagationPlugin")
5264
self.sts_integration = sts_integration
5365
self.token_cache: Dict[str, str] = {}
66+
self._subject_tokens: Dict[str, str] = {}
5467

5568
def add_to_agent(self, agent: BaseAgent):
5669
"""
@@ -70,7 +83,6 @@ def add_to_agent(self, agent: BaseAgent):
7083
logger.debug("Updated tool connection params to include access token from STS server")
7184

7285
def header_provider(self, readonly_context: Optional[ReadonlyContext]) -> Dict[str, str]:
73-
# access save token
7486
access_token = self.token_cache.get(self.cache_key(readonly_context._invocation_context), "")
7587
if not access_token:
7688
return {}
@@ -85,25 +97,58 @@ async def before_run_callback(
8597
*,
8698
invocation_context: InvocationContext,
8799
) -> Optional[dict]:
88-
"""Propagate token to model before execution."""
100+
"""Extract and store the subject token for later exchange."""
89101
headers = invocation_context.session.state.get(HEADERS_KEY, None)
90102
subject_token = _extract_jwt_from_headers(headers)
91103
if not subject_token:
92104
logger.debug("No subject token found in headers for token propagation")
93105
return None
94-
if self.sts_integration:
95-
try:
96-
subject_token = await self.sts_integration.exchange_token(
97-
subject_token=subject_token,
98-
subject_token_type=TokenType.JWT,
99-
actor_token=self.sts_integration._actor_token,
100-
actor_token_type=TokenType.JWT if self.sts_integration._actor_token else None,
101-
)
102-
except Exception as e:
103-
logger.warning(f"STS token exchange failed: {e}")
104-
return None
105-
# no sts, just propagate the subject token upstream
106-
self.token_cache[self.cache_key(invocation_context)] = subject_token
106+
key = self.cache_key(invocation_context)
107+
self._subject_tokens[key] = subject_token
108+
if not self.sts_integration:
109+
# No STS -- propagate the subject token directly so
110+
# header_provider can return it on the first tool call.
111+
self.token_cache[key] = subject_token
112+
return None
113+
114+
@override
115+
async def before_tool_callback(
116+
self,
117+
*,
118+
tool: BaseTool,
119+
tool_args: dict[str, Any],
120+
tool_context: ToolContext,
121+
) -> Optional[dict]:
122+
"""Exchange the subject token via STS before each MCP tool call."""
123+
if not self.sts_integration:
124+
return None
125+
# Only exchange tokens for MCP tool calls. Other tool types
126+
# (memory tools, AgentTool, etc.) don't use header_provider and
127+
# have their own auth mechanisms, so exchanging here would be a
128+
# wasted HTTP round-trip to the STS.
129+
if not isinstance(tool, McpTool):
130+
return None
131+
132+
key = self.cache_key(tool_context._invocation_context)
133+
# Already exchanged for this session
134+
if key in self.token_cache:
135+
return None
136+
137+
subject_token = self._subject_tokens.get(key)
138+
if not subject_token:
139+
return None
140+
141+
try:
142+
access_token = await self.sts_integration.exchange_token(
143+
subject_token=subject_token,
144+
subject_token_type=TokenType.JWT,
145+
actor_token=self.sts_integration._actor_token,
146+
actor_token_type=TokenType.JWT if self.sts_integration._actor_token else None,
147+
)
148+
self.token_cache[key] = access_token
149+
except Exception as e:
150+
logger.warning(f"STS token exchange failed: {e}")
151+
107152
return None
108153

109154
def cache_key(self, invocation_context: InvocationContext) -> str:
@@ -116,8 +161,9 @@ async def after_run_callback(
116161
*,
117162
invocation_context: InvocationContext,
118163
) -> Optional[dict]:
119-
# delete token after run
120-
self.token_cache.pop(self.cache_key(invocation_context), None)
164+
key = self.cache_key(invocation_context)
165+
self.token_cache.pop(key, None)
166+
self._subject_tokens.pop(key, None)
121167
return None
122168

123169

python/packages/agentsts-adk/tests/test_adk_integration.py

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from google.adk.agents import LlmAgent
7+
from google.adk.tools.mcp_tool.mcp_tool import McpTool
78
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
89

910
from agentsts.adk import ADKSTSIntegration, ADKTokenPropagationPlugin
@@ -31,6 +32,16 @@ def _make_readonly_context(self, invocation_context):
3132
readonly_context._invocation_context = invocation_context
3233
return readonly_context
3334

35+
def _make_tool_context(self, invocation_context):
36+
tool_context = Mock()
37+
tool_context._invocation_context = invocation_context
38+
return tool_context
39+
40+
def _make_mcp_tool(self):
41+
tool = Mock(spec=McpTool)
42+
tool.name = "test_tool"
43+
return tool
44+
3445
def test_init(self):
3546
mock_sts_integration = Mock()
3647
plugin = ADKTokenPropagationPlugin(mock_sts_integration)
@@ -77,49 +88,102 @@ async def test_downstream_token_propagation_without_sts(self):
7788

7889
@pytest.mark.asyncio
7990
async def test_sts_token_exchange_success(self):
80-
"""Case: STS integration exchanges token -> access token cached and returned by header provider."""
91+
"""Case: STS integration -- before_run stores subject token, before_tool_callback exchanges it."""
8192
sts = Mock(spec=ADKSTSIntegration)
8293
sts._actor_token = "actor-token"
8394
sts.exchange_token = AsyncMock(return_value="access-token-XYZ")
8495
plugin = ADKTokenPropagationPlugin(sts)
8596
ic = self._make_invocation_context("sess-3", headers={"Authorization": "Bearer original-subject"})
86-
with patch("agentsts.adk._base.logger") as mock_logger:
87-
result = await plugin.before_run_callback(invocation_context=ic)
88-
assert result is None
89-
sts.exchange_token.assert_called_once_with(
90-
subject_token="original-subject",
91-
subject_token_type=TokenType.JWT,
92-
actor_token="actor-token",
93-
actor_token_type=TokenType.JWT,
94-
)
95-
# optional debug log length check
96-
mock_logger.debug.assert_called() # at least one debug log
97+
98+
# before_run_callback should store the subject token but NOT exchange
99+
result = await plugin.before_run_callback(invocation_context=ic)
100+
assert result is None
101+
sts.exchange_token.assert_not_called()
102+
assert "sess-3" not in plugin.token_cache
103+
assert plugin._subject_tokens["sess-3"] == "original-subject"
104+
105+
# before_tool_callback should exchange on first MCP tool call
106+
tool = self._make_mcp_tool()
107+
tc = self._make_tool_context(ic)
108+
result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc)
109+
assert result is None
110+
sts.exchange_token.assert_called_once_with(
111+
subject_token="original-subject",
112+
subject_token_type=TokenType.JWT,
113+
actor_token="actor-token",
114+
actor_token_type=TokenType.JWT,
115+
)
97116
assert plugin.token_cache["sess-3"] == "access-token-XYZ"
98117

118+
# header_provider should return the exchanged token
99119
ro_ctx = self._make_readonly_context(ic)
100120
headers = plugin.header_provider(ro_ctx)
101121
assert headers == {"Authorization": "Bearer access-token-XYZ"}
102122

123+
# second tool call should not exchange again (cached)
124+
sts.exchange_token.reset_mock()
125+
result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc)
126+
assert result is None
127+
sts.exchange_token.assert_not_called()
128+
103129
await plugin.after_run_callback(invocation_context=ic)
104130
assert "sess-3" not in plugin.token_cache
131+
assert "sess-3" not in plugin._subject_tokens
105132

106133
@pytest.mark.asyncio
107134
async def test_sts_token_exchange_failure(self):
108-
"""Case: STS exchange raises -> no cache entry, graceful warning."""
135+
"""Case: STS exchange raises in before_tool_callback -> no cache entry, graceful warning."""
109136
sts = Mock(spec=ADKSTSIntegration)
110137
sts._actor_token = "actor-token"
111138
sts.exchange_token = AsyncMock(side_effect=Exception("boom"))
112139
plugin = ADKTokenPropagationPlugin(sts)
113140
ic = self._make_invocation_context("sess-4", headers={"Authorization": "Bearer original-subject"})
141+
142+
await plugin.before_run_callback(invocation_context=ic)
143+
assert plugin._subject_tokens["sess-4"] == "original-subject"
144+
145+
tool = self._make_mcp_tool()
146+
tc = self._make_tool_context(ic)
114147
with patch("agentsts.adk._base.logger") as mock_logger:
115-
result = await plugin.before_run_callback(invocation_context=ic)
148+
result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc)
116149
assert result is None
117150
mock_logger.warning.assert_called_once()
118151
assert "sess-4" not in plugin.token_cache
152+
119153
# header provider should yield empty dict
120154
ro_ctx = self._make_readonly_context(ic)
121155
assert plugin.header_provider(ro_ctx) == {}
122156

157+
@pytest.mark.asyncio
158+
async def test_before_tool_callback_skips_non_mcp_tools(self):
159+
"""Case: before_tool_callback ignores non-MCP tools."""
160+
sts = Mock(spec=ADKSTSIntegration)
161+
sts._actor_token = "actor-token"
162+
sts.exchange_token = AsyncMock(return_value="access-token")
163+
plugin = ADKTokenPropagationPlugin(sts)
164+
ic = self._make_invocation_context("sess-7", headers={"Authorization": "Bearer subj"})
165+
await plugin.before_run_callback(invocation_context=ic)
166+
167+
non_mcp_tool = Mock() # not a McpTool
168+
tc = self._make_tool_context(ic)
169+
result = await plugin.before_tool_callback(tool=non_mcp_tool, tool_args={}, tool_context=tc)
170+
assert result is None
171+
sts.exchange_token.assert_not_called()
172+
173+
@pytest.mark.asyncio
174+
async def test_before_tool_callback_no_sts(self):
175+
"""Case: before_tool_callback is a no-op without STS integration."""
176+
plugin = ADKTokenPropagationPlugin(sts_integration=None)
177+
ic = self._make_invocation_context("sess-8", headers={"Authorization": "Bearer subj"})
178+
await plugin.before_run_callback(invocation_context=ic)
179+
180+
tool = self._make_mcp_tool()
181+
tc = self._make_tool_context(ic)
182+
result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc)
183+
assert result is None
184+
# token_cache should still have the subject token from before_run
185+
assert plugin.token_cache["sess-8"] == "subj"
186+
123187
def test_header_provider_no_entry(self):
124188
"""Case: header_provider called with no cached token -> returns empty dict."""
125189
plugin = ADKTokenPropagationPlugin()
@@ -131,13 +195,15 @@ def test_header_provider_no_entry(self):
131195

132196
@pytest.mark.asyncio
133197
async def test_after_run_callback_removes_token(self):
134-
"""Case: after_run_callback removes cached token."""
198+
"""Case: after_run_callback removes cached token and subject token."""
135199
plugin = ADKTokenPropagationPlugin()
136200
ic = self._make_invocation_context("sess-6", headers={"Authorization": "Bearer AAA"})
137201
await plugin.before_run_callback(invocation_context=ic)
138202
assert "sess-6" in plugin.token_cache
203+
assert "sess-6" in plugin._subject_tokens
139204
await plugin.after_run_callback(invocation_context=ic)
140205
assert "sess-6" not in plugin.token_cache
206+
assert "sess-6" not in plugin._subject_tokens
141207

142208
def test_extract_jwt_from_headers_success(self):
143209
"""Test successful JWT extraction from headers."""

0 commit comments

Comments
 (0)