44
55import pytest
66from google .adk .agents import LlmAgent
7+ from google .adk .tools .mcp_tool .mcp_tool import McpTool
78from google .adk .tools .mcp_tool .mcp_toolset import MCPToolset
89
910from agentsts .adk import ADKSTSIntegration , ADKTokenPropagationPlugin
@@ -31,6 +32,16 @@ def _make_readonly_context(self, invocation_context):
3132 readonly_context ._invocation_context = invocation_context
3233 return readonly_context
3334
35+ def _make_tool_context (self , invocation_context ):
36+ tool_context = Mock ()
37+ tool_context ._invocation_context = invocation_context
38+ return tool_context
39+
40+ def _make_mcp_tool (self ):
41+ tool = Mock (spec = McpTool )
42+ tool .name = "test_tool"
43+ return tool
44+
3445 def test_init (self ):
3546 mock_sts_integration = Mock ()
3647 plugin = ADKTokenPropagationPlugin (mock_sts_integration )
@@ -77,49 +88,102 @@ async def test_downstream_token_propagation_without_sts(self):
7788
7889 @pytest .mark .asyncio
7990 async def test_sts_token_exchange_success (self ):
80- """Case: STS integration exchanges token -> access token cached and returned by header provider ."""
91+ """Case: STS integration -- before_run stores subject token, before_tool_callback exchanges it ."""
8192 sts = Mock (spec = ADKSTSIntegration )
8293 sts ._actor_token = "actor-token"
8394 sts .exchange_token = AsyncMock (return_value = "access-token-XYZ" )
8495 plugin = ADKTokenPropagationPlugin (sts )
8596 ic = self ._make_invocation_context ("sess-3" , headers = {"Authorization" : "Bearer original-subject" })
86- with patch ("agentsts.adk._base.logger" ) as mock_logger :
87- result = await plugin .before_run_callback (invocation_context = ic )
88- assert result is None
89- sts .exchange_token .assert_called_once_with (
90- subject_token = "original-subject" ,
91- subject_token_type = TokenType .JWT ,
92- actor_token = "actor-token" ,
93- actor_token_type = TokenType .JWT ,
94- )
95- # optional debug log length check
96- mock_logger .debug .assert_called () # at least one debug log
97+
98+ # before_run_callback should store the subject token but NOT exchange
99+ result = await plugin .before_run_callback (invocation_context = ic )
100+ assert result is None
101+ sts .exchange_token .assert_not_called ()
102+ assert "sess-3" not in plugin .token_cache
103+ assert plugin ._subject_tokens ["sess-3" ] == "original-subject"
104+
105+ # before_tool_callback should exchange on first MCP tool call
106+ tool = self ._make_mcp_tool ()
107+ tc = self ._make_tool_context (ic )
108+ result = await plugin .before_tool_callback (tool = tool , tool_args = {}, tool_context = tc )
109+ assert result is None
110+ sts .exchange_token .assert_called_once_with (
111+ subject_token = "original-subject" ,
112+ subject_token_type = TokenType .JWT ,
113+ actor_token = "actor-token" ,
114+ actor_token_type = TokenType .JWT ,
115+ )
97116 assert plugin .token_cache ["sess-3" ] == "access-token-XYZ"
98117
118+ # header_provider should return the exchanged token
99119 ro_ctx = self ._make_readonly_context (ic )
100120 headers = plugin .header_provider (ro_ctx )
101121 assert headers == {"Authorization" : "Bearer access-token-XYZ" }
102122
123+ # second tool call should not exchange again (cached)
124+ sts .exchange_token .reset_mock ()
125+ result = await plugin .before_tool_callback (tool = tool , tool_args = {}, tool_context = tc )
126+ assert result is None
127+ sts .exchange_token .assert_not_called ()
128+
103129 await plugin .after_run_callback (invocation_context = ic )
104130 assert "sess-3" not in plugin .token_cache
131+ assert "sess-3" not in plugin ._subject_tokens
105132
106133 @pytest .mark .asyncio
107134 async def test_sts_token_exchange_failure (self ):
108- """Case: STS exchange raises -> no cache entry, graceful warning."""
135+ """Case: STS exchange raises in before_tool_callback -> no cache entry, graceful warning."""
109136 sts = Mock (spec = ADKSTSIntegration )
110137 sts ._actor_token = "actor-token"
111138 sts .exchange_token = AsyncMock (side_effect = Exception ("boom" ))
112139 plugin = ADKTokenPropagationPlugin (sts )
113140 ic = self ._make_invocation_context ("sess-4" , headers = {"Authorization" : "Bearer original-subject" })
141+
142+ await plugin .before_run_callback (invocation_context = ic )
143+ assert plugin ._subject_tokens ["sess-4" ] == "original-subject"
144+
145+ tool = self ._make_mcp_tool ()
146+ tc = self ._make_tool_context (ic )
114147 with patch ("agentsts.adk._base.logger" ) as mock_logger :
115- result = await plugin .before_run_callback ( invocation_context = ic )
148+ result = await plugin .before_tool_callback ( tool = tool , tool_args = {}, tool_context = tc )
116149 assert result is None
117150 mock_logger .warning .assert_called_once ()
118151 assert "sess-4" not in plugin .token_cache
152+
119153 # header provider should yield empty dict
120154 ro_ctx = self ._make_readonly_context (ic )
121155 assert plugin .header_provider (ro_ctx ) == {}
122156
157+ @pytest .mark .asyncio
158+ async def test_before_tool_callback_skips_non_mcp_tools (self ):
159+ """Case: before_tool_callback ignores non-MCP tools."""
160+ sts = Mock (spec = ADKSTSIntegration )
161+ sts ._actor_token = "actor-token"
162+ sts .exchange_token = AsyncMock (return_value = "access-token" )
163+ plugin = ADKTokenPropagationPlugin (sts )
164+ ic = self ._make_invocation_context ("sess-7" , headers = {"Authorization" : "Bearer subj" })
165+ await plugin .before_run_callback (invocation_context = ic )
166+
167+ non_mcp_tool = Mock () # not a McpTool
168+ tc = self ._make_tool_context (ic )
169+ result = await plugin .before_tool_callback (tool = non_mcp_tool , tool_args = {}, tool_context = tc )
170+ assert result is None
171+ sts .exchange_token .assert_not_called ()
172+
173+ @pytest .mark .asyncio
174+ async def test_before_tool_callback_no_sts (self ):
175+ """Case: before_tool_callback is a no-op without STS integration."""
176+ plugin = ADKTokenPropagationPlugin (sts_integration = None )
177+ ic = self ._make_invocation_context ("sess-8" , headers = {"Authorization" : "Bearer subj" })
178+ await plugin .before_run_callback (invocation_context = ic )
179+
180+ tool = self ._make_mcp_tool ()
181+ tc = self ._make_tool_context (ic )
182+ result = await plugin .before_tool_callback (tool = tool , tool_args = {}, tool_context = tc )
183+ assert result is None
184+ # token_cache should still have the subject token from before_run
185+ assert plugin .token_cache ["sess-8" ] == "subj"
186+
123187 def test_header_provider_no_entry (self ):
124188 """Case: header_provider called with no cached token -> returns empty dict."""
125189 plugin = ADKTokenPropagationPlugin ()
@@ -131,13 +195,15 @@ def test_header_provider_no_entry(self):
131195
132196 @pytest .mark .asyncio
133197 async def test_after_run_callback_removes_token (self ):
134- """Case: after_run_callback removes cached token."""
198+ """Case: after_run_callback removes cached token and subject token ."""
135199 plugin = ADKTokenPropagationPlugin ()
136200 ic = self ._make_invocation_context ("sess-6" , headers = {"Authorization" : "Bearer AAA" })
137201 await plugin .before_run_callback (invocation_context = ic )
138202 assert "sess-6" in plugin .token_cache
203+ assert "sess-6" in plugin ._subject_tokens
139204 await plugin .after_run_callback (invocation_context = ic )
140205 assert "sess-6" not in plugin .token_cache
206+ assert "sess-6" not in plugin ._subject_tokens
141207
142208 def test_extract_jwt_from_headers_success (self ):
143209 """Test successful JWT extraction from headers."""
0 commit comments