diff --git a/python/packages/kagent-adk/src/kagent/adk/_constants.py b/python/packages/kagent-adk/src/kagent/adk/_constants.py
new file mode 100644
index 000000000..f18b6cf4c
--- /dev/null
+++ b/python/packages/kagent-adk/src/kagent/adk/_constants.py
@@ -0,0 +1,3 @@
+# Header used for Gateway API proxy routing: the kagent proxy reads this
+# to determine which upstream MCP server should receive the request.
+PROXY_HOST_HEADER = "x-kagent-host"
diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_capability_tools.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_capability_tools.py
new file mode 100644
index 000000000..953671868
--- /dev/null
+++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_capability_tools.py
@@ -0,0 +1,328 @@
+from __future__ import annotations
+
+import base64
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+from google.adk.tools import BaseTool
+from google.genai import types
+from typing_extensions import override
+
+if TYPE_CHECKING:
+    from google.adk.models.llm_request import LlmRequest
+    from google.adk.tools.tool_context import ToolContext
+
+logger = logging.getLogger("kagent_adk." + __name__)
+
+
+class _BaseKAgentMcpLoaderTool(BaseTool):
+    def __init__(self, *, mcp_toolset: Any, name: str, description: str, server_label: str):
+        super().__init__(name=name, description=description)
+        self._mcp_toolset = mcp_toolset
+        self._server_label = server_label
+
+    def _latest_function_responses(self, llm_request: "LlmRequest") -> list[dict[str, Any]]:
+        if not llm_request.contents:
+            return []
+
+        # Search backwards: other helpers may have appended content after
+        # the function response block during the same process_llm_request pass.
+        for content in reversed(llm_request.contents):
+            if not content.parts:
+                continue
+
+            matching_responses: list[dict[str, Any]] = []
+            has_any_function_response = False
+            for part in content.parts:
+                function_response = part.function_response
+                if function_response is None:
+                    continue
+                has_any_function_response = True
+                if function_response.name == self.name:
+                    matching_responses.append(function_response.response or {})
+
+            if has_any_function_response:
+                return matching_responses
+
+        return []
+
+    def _block_to_part(self, block: dict[str, Any], fallback_name: str) -> types.Part:
+        block_type = block.get("type")
+        if block.get("text") is not None and block_type in {None, "text"}:
+            return types.Part.from_text(text=block["text"])
+
+        if block.get("blob") is not None and block_type is None:
+            return self._binary_part_from_base64(
+                payload=block["blob"],
+                mime_type=block.get("mimeType") or "application/octet-stream",
+                fallback_name=fallback_name,
+            )
+
+        if block_type in {"image", "audio"} and block.get("data") is not None:
+            return self._binary_part_from_base64(
+                payload=block["data"],
+                mime_type=block.get("mimeType") or "application/octet-stream",
+                fallback_name=fallback_name,
+            )
+
+        if block_type == "resource":
+            resource = block.get("resource") or {}
+            if resource.get("text") is not None:
+                return types.Part.from_text(text=resource["text"])
+            if resource.get("blob") is not None:
+                return self._binary_part_from_base64(
+                    payload=resource["blob"],
+                    mime_type=resource.get("mimeType") or "application/octet-stream",
+                    fallback_name=fallback_name,
+                )
+            return types.Part.from_text(text=f"[Resource content for {fallback_name} could not be rendered]")
+
+        # Resource links and any unrecognized block types are rendered as JSON.
+        return types.Part.from_text(text=json.dumps(block, indent=2, sort_keys=True))
+
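+    # Base64 decoding is deliberately defensive: a malformed payload from a
+    # server degrades to a readable placeholder instead of failing the turn.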
+    def _binary_part_from_base64(self, payload: str, mime_type: str, fallback_name: str) -> types.Part:
+        try:
+            return types.Part.from_bytes(data=base64.b64decode(payload), mime_type=mime_type)
+        except Exception:
+            return types.Part.from_text(text=f"[Binary content for {fallback_name} could not be decoded]")
+
+
+class LoadKAgentMcpResourceTool(_BaseKAgentMcpLoaderTool):
+    def __init__(self, *, mcp_toolset: Any, name: str, server_label: str):
+        super().__init__(
+            mcp_toolset=mcp_toolset,
+            name=name,
+            server_label=server_label,
+            description=(
+                f"Loads named resources from the MCP server '{server_label}' into the current context. "
+                "Call this before answering questions that depend on MCP resources."
+            ),
+        )
+
+    def _get_declaration(self) -> types.FunctionDeclaration | None:
+        return types.FunctionDeclaration(
+            name=self.name,
+            description=self.description,
+            parameters=types.Schema(
+                type=types.Type.OBJECT,
+                properties={
+                    "resource_names": types.Schema(
+                        type=types.Type.ARRAY,
+                        description="The MCP resource names to load into context.",
+                        items=types.Schema(type=types.Type.STRING),
+                    )
+                },
+                required=["resource_names"],
+            ),
+        )
+
+    @override
+    async def run_async(self, *, args: dict[str, Any], tool_context: "ToolContext") -> Any:
+        raw_resource_names = args.get("resource_names", [])
+        if isinstance(raw_resource_names, str):
+            raw_resource_names = [raw_resource_names]
+        if not isinstance(raw_resource_names, (list, tuple)):
+            raw_resource_names = []
+        resource_names = [str(name) for name in raw_resource_names if name]
+        return {
+            "resource_names": resource_names,
+            "status": "Requested MCP resources have been staged into the next model turn.",
+        }
+
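+    # run_async only echoes the request back; the actual loading happens in
+    # process_llm_request, which runs while the next model turn is assembled.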
+    @override
+    async def process_llm_request(self, *, tool_context: "ToolContext", llm_request: "LlmRequest") -> None:
+        await super().process_llm_request(tool_context=tool_context, llm_request=llm_request)
+        await self._append_resource_catalog(tool_context=tool_context, llm_request=llm_request)
+        await self._append_selected_resources(tool_context=tool_context, llm_request=llm_request)
+
+    async def _append_resource_catalog(self, *, tool_context: "ToolContext", llm_request: "LlmRequest") -> None:
+        try:
+            resource_names = await self._mcp_toolset.list_resources(tool_context)
+        except Exception as error:
+            logger.warning("Failed to list MCP resources from %s: %s", self._server_label, error)
+            return
+
+        if not resource_names:
+            return
+
+        llm_request.append_instructions(
+            [
+                (
+                    f"You have MCP resources available from server '{self._server_label}':\n"
+                    f"{json.dumps(resource_names)}\n\n"
+                    f"When the user asks about one of these resources, call `{self.name}` first "
+                    "with the relevant resource name or names."
+                )
+            ]
+        )
+
+    async def _append_selected_resources(self, *, tool_context: "ToolContext", llm_request: "LlmRequest") -> None:
+        responses = self._latest_function_responses(llm_request)
+        if not responses:
+            return
+
+        for response in responses:
+            for resource_name in response.get("resource_names", []):
+                try:
+                    contents = await self._mcp_toolset.read_resource(resource_name, tool_context)
+                except Exception as error:
+                    logger.warning(
+                        "Failed to read MCP resource '%s' from %s: %s", resource_name, self._server_label, error
+                    )
+                    continue
+
+                for content in contents:
+                    llm_request.contents.append(
+                        types.Content(
+                            role="user",
+                            parts=[
+                                types.Part.from_text(
+                                    text=f"MCP resource '{resource_name}' from server '{self._server_label}' is:"
+                                ),
+                                self._block_to_part(content, resource_name),
+                            ],
+                        )
+                    )
+
+
+class LoadKAgentMcpPromptTool(_BaseKAgentMcpLoaderTool):
+    def __init__(self, *, mcp_toolset: Any, name: str, server_label: str):
+        super().__init__(
+            mcp_toolset=mcp_toolset,
+            name=name,
+            server_label=server_label,
+            description=(
+                f"Loads a named prompt from the MCP server '{server_label}' into the current context. "
+                "Pass any required prompt arguments as string values."
+            ),
+        )
+
+    def _get_declaration(self) -> types.FunctionDeclaration | None:
+        return types.FunctionDeclaration(
+            name=self.name,
+            description=self.description,
+            parameters=types.Schema(
+                type=types.Type.OBJECT,
+                properties={
+                    "prompt_name": types.Schema(
+                        type=types.Type.STRING,
+                        description="The MCP prompt name to load.",
+                    ),
+                    "arguments": types.Schema(
+                        type=types.Type.OBJECT,
+                        description="Optional string arguments for the MCP prompt template.",
+                    ),
+                },
+                required=["prompt_name"],
+            ),
+        )
+
+    @override
+    async def run_async(self, *, args: dict[str, Any], tool_context: "ToolContext") -> Any:
+        raw_arguments = args.get("arguments") or {}
+        if not isinstance(raw_arguments, dict):
+            raw_arguments = {}
+
+        arguments = {str(key): str(value) for key, value in raw_arguments.items()}
+        prompt_name = str(args.get("prompt_name", "")).strip()
+        return {
+            "prompt_name": prompt_name,
+            "arguments": arguments,
+            "status": "Requested MCP prompt has been staged into the next model turn.",
+        }
+
+    @override
+    async def process_llm_request(self, *, tool_context: "ToolContext", llm_request: "LlmRequest") -> None:
+        await super().process_llm_request(tool_context=tool_context, llm_request=llm_request)
+        await self._append_prompt_catalog(tool_context=tool_context, llm_request=llm_request)
+        await self._append_selected_prompt(tool_context=tool_context, llm_request=llm_request)
+
+    async def _append_prompt_catalog(self, *, tool_context: "ToolContext", llm_request: "LlmRequest") -> None:
+        try:
+            prompt_info = await self._mcp_toolset.list_prompt_info(tool_context)
+        except Exception as error:
+            logger.warning("Failed to list MCP prompts from %s: %s", self._server_label, error)
+            return
+
+        if not prompt_info:
+            return
+
+        prompt_catalog = []
+        for prompt in prompt_info:
+            prompt_catalog.append(
+                {
+                    "name": prompt.get("name"),
+                    "description": prompt.get("description"),
+                    "arguments": [
+                        {
+                            "name": argument.get("name"),
+                            "description": argument.get("description"),
+                            "required": argument.get("required"),
+                        }
+                        for argument in prompt.get("arguments", [])
+                    ],
+                }
+            )
+
+        llm_request.append_instructions(
+            [
+                (
+                    f"You have MCP prompts available from server '{self._server_label}':\n"
+                    f"{json.dumps(prompt_catalog, indent=2)}\n\n"
+                    f"When a prompt is relevant, call `{self.name}` with `prompt_name` and any required string arguments "
+                    "before composing your final answer."
+                )
+            ]
+        )
+
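+    # Replays the messages of the selected prompt into the request, mapping
+    # MCP roles onto the Gemini "user"/"model" roles as it goes.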
+    async def _append_selected_prompt(self, *, tool_context: "ToolContext", llm_request: "LlmRequest") -> None:
+        responses = self._latest_function_responses(llm_request)
+        if not responses:
+            return
+
+        for response in responses:
+            prompt_name = str(response.get("prompt_name", "")).strip()
+            if not prompt_name:
+                continue
+
+            raw_arguments = response.get("arguments") or {}
+            if not isinstance(raw_arguments, dict):
+                raw_arguments = {}
+            arguments = {str(key): str(value) for key, value in raw_arguments.items()}
+
+            try:
+                prompt = await self._mcp_toolset.get_prompt(prompt_name, arguments, tool_context)
+            except Exception as error:
+                logger.warning("Failed to load MCP prompt '%s' from %s: %s", prompt_name, self._server_label, error)
+                continue
+
+            for index, message in enumerate(prompt.get("messages", []), start=1):
+                content = message.get("content") or {}
+                role = _map_mcp_prompt_role(message.get("role"))
+                llm_request.contents.append(
+                    types.Content(
+                        role=role,
+                        parts=[
+                            types.Part.from_text(
+                                text=(
+                                    f"MCP prompt '{prompt_name}' from server '{self._server_label}' "
+                                    f"returned message {index} with role '{message.get('role') or 'user'}':"
+                                )
+                            ),
+                            self._block_to_part(content, f"{prompt_name}_{index}"),
+                        ],
+                    )
+                )
+
+
+def _map_mcp_prompt_role(role: Any) -> str:
+    # MCP prompt messages use "user"/"assistant"; Gemini contents use
+    # "user"/"model". Anything unrecognized is treated as user input.
+    return "model" if role == "assistant" else "user"
diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py
index 26c4c6df7..2668c60d8 100644
--- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py
+++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py
@@ -1,12 +1,18 @@
 from __future__ import annotations
 
 import asyncio
+import hashlib
 import logging
-from typing import Optional
+import re
+from typing import Any, Optional
+from urllib.parse import urlparse
 
 from google.adk.tools import BaseTool
 from google.adk.tools.mcp_tool.mcp_toolset import McpToolset, ReadonlyContext
 
+from ._constants import PROXY_HOST_HEADER
+from ._mcp_capability_tools import LoadKAgentMcpPromptTool, LoadKAgentMcpResourceTool
+
 logger = logging.getLogger("kagent_adk." + __name__)
 
 
@@ -27,10 +33,15 @@ class KAgentMcpToolset(McpToolset):
 
     async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]:
         try:
-            return await super().get_tools(readonly_context)
+            tools = await super().get_tools(readonly_context)
        except asyncio.CancelledError as error:
             raise _enrich_cancelled_error(error) from error
 
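+        # Augment the server's own tools with loader helpers for resources and
+        # prompts, when the server actually advertises those capabilities.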
+        tools.extend(await self._create_capability_tools(readonly_context))
+        return tools
+
     async def close(self) -> None:
         """Close MCP sessions and suppress known anyio cancel scope cleanup errors.
 
@@ -55,6 +64,96 @@ async def close(self) -> None:
                 raise
             raise
 
+    async def list_prompt_info(self, readonly_context: Optional[ReadonlyContext] = None) -> list[dict[str, Any]]:
+        """Return prompt metadata exposed by the MCP server."""
+
+        result = await self._execute_with_session(
+            lambda session: session.list_prompts(),
+            "Failed to list prompts from MCP server",
+            readonly_context,
+        )
+        return [_model_dump(prompt) for prompt in result.prompts]
+
+    async def get_prompt(
+        self,
+        name: str,
+        arguments: dict[str, str] | None = None,
+        readonly_context: Optional[ReadonlyContext] = None,
+    ) -> dict[str, Any]:
+        """Fetch prompt contents from the MCP server."""
+
+        result = await self._execute_with_session(
+            lambda session: session.get_prompt(name, arguments=arguments or None),
+            f"Failed to get prompt {name} from MCP server",
+            readonly_context,
+        )
+        return _model_dump(result)
+
+    async def read_resource(
+        self,
+        name: str,
+        readonly_context: Optional[ReadonlyContext] = None,
+    ) -> list[dict[str, Any]]:
+        """Fetch resource contents from the MCP server."""
+
+        resource_info = await self.get_resource_info(name, readonly_context)
+        if "uri" not in resource_info:
+            raise ValueError(f"Resource '{name}' has no URI.")
+
+        result = await self._execute_with_session(
+            lambda session: session.read_resource(uri=resource_info["uri"]),
+            f"Failed to read resource {name} from MCP server",
+            readonly_context,
+        )
+        return [_model_dump(content) for content in result.contents]
+
+    async def _create_capability_tools(self, readonly_context: Optional[ReadonlyContext]) -> list[BaseTool]:
+        capability_tools: list[BaseTool] = []
+
+        try:
+            if await self._has_resources(readonly_context):
+                capability_tools.append(
+                    LoadKAgentMcpResourceTool(
+                        mcp_toolset=self,
+                        name=self.resource_tool_name,
+                        server_label=self.server_label,
+                    )
+                )
+        except Exception as error:
+            logger.info("Skipping MCP resource helper tool: %s", error)
+
+        try:
+            if await self._has_prompts(readonly_context):
+                capability_tools.append(
+                    LoadKAgentMcpPromptTool(
+                        mcp_toolset=self,
+                        name=self.prompt_tool_name,
+                        server_label=self.server_label,
+                    )
+                )
+        except Exception as error:
+            logger.info("Skipping MCP prompt helper tool: %s", error)
+
+        return capability_tools
+
+    async def _has_resources(self, readonly_context: Optional[ReadonlyContext]) -> bool:
+        return bool(await self.list_resources(readonly_context))
+
+    async def _has_prompts(self, readonly_context: Optional[ReadonlyContext]) -> bool:
+        return bool(await self.list_prompt_info(readonly_context))
+
+    @property
+    def server_label(self) -> str:
+        return _server_identity(self._connection_params)[0]
+
+    @property
+    def resource_tool_name(self) -> str:
+        return f"load_mcp_resource_{_server_identity(self._connection_params)[1]}"
+
+    @property
+    def prompt_tool_name(self) -> str:
+        return f"load_mcp_prompt_{_server_identity(self._connection_params)[1]}"
+
 
 def is_anyio_cross_task_cancel_scope_error(error: BaseException) -> bool:
     current: BaseException | None = error
@@ -67,3 +166,37 @@ def is_anyio_cross_task_cancel_scope_error(error: BaseException) -> bool:
             return True
        current = current.__cause__ or current.__context__
     return False
+
+
+def _server_identity(connection_params: Any) -> tuple[str, str]:
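+    # Returns (label, slug): a human-readable server label and a stable slug
+    # used to suffix tool names. Behind the proxy, the x-kagent-host header
+    # names the real upstream, so it takes precedence over the proxy URL host.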
+    parsed = urlparse(str(getattr(connection_params, "url", "")))
+    path = parsed.path.rstrip("/")
+    headers = getattr(connection_params, "headers", None) or {}
+    proxy_host = headers.get(PROXY_HOST_HEADER)
+    if proxy_host:
+        label = f"{proxy_host}{path}" if path else str(proxy_host)
+    else:
+        host = parsed.netloc
+        if host and path:
+            label = f"{host}{path}"
+        else:
+            label = host or path or "mcp_server"
+
+    slug_source = re.sub(r"[^a-zA-Z0-9]+", "_", label).strip("_").lower() or "mcp_server"
+    digest = hashlib.sha1(label.encode("utf-8")).hexdigest()[:8]
+    return label, f"{slug_source[:32]}_{digest}"
+
+
+def _model_dump(value: Any) -> dict[str, Any]:
+    if hasattr(value, "model_dump"):
+        return value.model_dump(mode="json", exclude_none=True)
+    if isinstance(value, dict):
+        return value
+    try:
+        return dict(value)
+    except (TypeError, ValueError):
+        logger.warning("Cannot convert %s to dict, returning string representation", type(value).__name__)
+        return {"raw": str(value)}
diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py
index 8dbab31fc..98c559986 100644
--- a/python/packages/kagent-adk/src/kagent/adk/types.py
+++ b/python/packages/kagent-adk/src/kagent/adk/types.py
@@ -14,6 +14,7 @@
 from pydantic import BaseModel, Field
 
 from kagent.adk._approval import make_approval_callback
+from kagent.adk._constants import PROXY_HOST_HEADER
 from kagent.adk._mcp_toolset import KAgentMcpToolset
 from kagent.adk._remote_a2a_tool import KAgentRemoteA2AToolset
 from kagent.adk.models._litellm import KAgentLiteLlm
@@ -25,9 +26,6 @@
 
 logger = logging.getLogger(__name__)
 
-# Proxy host header used for Gateway API routing when using a proxy
-PROXY_HOST_HEADER = "x-kagent-host"
-
 # Key used to store headers in session state
 HEADERS_STATE_KEY = "headers"
 
diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_capability_tools.py b/python/packages/kagent-adk/tests/unittests/test_mcp_capability_tools.py
new file mode 100644
index 000000000..cdf841019
--- /dev/null
+++ b/python/packages/kagent-adk/tests/unittests/test_mcp_capability_tools.py
@@ -0,0 +1,276 @@
+from __future__ import annotations
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+from google.adk.tools.mcp_tool import StreamableHTTPConnectionParams
+from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
+from google.genai import types
+
+from kagent.adk._mcp_capability_tools import LoadKAgentMcpPromptTool, LoadKAgentMcpResourceTool
+from kagent.adk._mcp_toolset import KAgentMcpToolset
+from kagent.adk.types import AgentConfig, HttpMcpServerConfig, OpenAI
+
+
+class FakeLlmRequest:
+    def __init__(self, responses: list[tuple[str, dict[str, Any]]]):
+        self.contents = [
+            types.Content(
+                role="user",
+                parts=[
+                    types.Part(
+                        function_response=types.FunctionResponse(
+                            name=response_name,
+                            response=response_payload,
+                        )
+                    )
+                    for response_name, response_payload in responses
+                ],
+            )
+        ]
+        self.instructions: list[str] = []
+        self.appended_tools: list[Any] = []
+
+    def append_instructions(self, instructions: list[str]) -> None:
+        self.instructions.extend(instructions)
+
+    def append_tools(self, tools: list[Any]) -> None:
+        self.appended_tools.extend(tools)
+
+
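+# Minimal toolset stand-ins: each stub implements only the surface the loader
+# tools touch (list/get for prompts, list/read for resources).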
"assistant", + "content": { + "type": "text", + "text": "Investigate the controller logs first.", + }, + } + ] + } + + +class StubResourceToolset: + async def list_resources(self, readonly_context: Any = None) -> list[str]: + return ["cluster_runbook"] + + async def read_resource(self, name: str, readonly_context: Any = None) -> list[dict[str, Any]]: + assert name == "cluster_runbook" + return [{"text": "Restart the controller deployment if reconciliation is stuck."}] + + +class StubCombinedToolset(StubPromptToolset, StubResourceToolset): + """Supports both prompts and resources for combined-call tests.""" + + pass + + +def _make_agent_config(url: str) -> AgentConfig: + return AgentConfig( + model=OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake"), + description="Test agent", + instruction="You are a test agent", + http_tools=[ + HttpMcpServerConfig( + params=StreamableHTTPConnectionParams(url=url, headers=None), + tools=["test-tool"], + ) + ], + ) + + +@pytest.mark.asyncio +async def test_kagent_mcp_toolset_adds_prompt_and_resource_helpers(monkeypatch): + async def _base_get_tools(self, readonly_context=None): + return [] + + async def _list_resources(self, readonly_context=None): + return ["cluster_runbook"] + + async def _list_prompt_info(self, readonly_context=None): + return [{"name": "incident_triage", "arguments": []}] + + monkeypatch.setattr(McpToolset, "get_tools", _base_get_tools) + monkeypatch.setattr(KAgentMcpToolset, "list_resources", _list_resources) + monkeypatch.setattr(KAgentMcpToolset, "list_prompt_info", _list_prompt_info) + + agent = _make_agent_config("http://tools.kagent:8080/mcp").to_agent("test_agent") + mcp_toolset = next(tool for tool in agent.tools if isinstance(tool, KAgentMcpToolset)) + + helper_tools = await mcp_toolset.get_tools() + helper_names = {tool.name for tool in helper_tools} + + assert mcp_toolset.resource_tool_name in helper_names + assert mcp_toolset.prompt_tool_name in helper_names + + +@pytest.mark.asyncio +async def test_kagent_mcp_toolset_skips_helpers_when_capabilities_are_missing(monkeypatch): + async def _base_get_tools(self, readonly_context=None): + return [] + + async def _list_resources(self, readonly_context=None): + return [] + + async def _list_prompt_info(self, readonly_context=None): + return [] + + monkeypatch.setattr(McpToolset, "get_tools", _base_get_tools) + monkeypatch.setattr(KAgentMcpToolset, "list_resources", _list_resources) + monkeypatch.setattr(KAgentMcpToolset, "list_prompt_info", _list_prompt_info) + + agent = _make_agent_config("http://tools.kagent:8080/mcp").to_agent("test_agent") + mcp_toolset = next(tool for tool in agent.tools if isinstance(tool, KAgentMcpToolset)) + + assert await mcp_toolset.get_tools() == [] + + +def test_kagent_mcp_toolset_generates_unique_helper_names_per_server(): + first_agent = _make_agent_config("https://gateway.example/mcp/team-a").to_agent("first_agent") + second_agent = _make_agent_config("https://gateway.example/mcp/team-b").to_agent("second_agent") + + first_toolset = next(tool for tool in first_agent.tools if isinstance(tool, KAgentMcpToolset)) + second_toolset = next(tool for tool in second_agent.tools if isinstance(tool, KAgentMcpToolset)) + + assert first_toolset.resource_tool_name != second_toolset.resource_tool_name + assert first_toolset.prompt_tool_name != second_toolset.prompt_tool_name + + +@pytest.mark.asyncio +async def test_prompt_loader_adds_prompt_catalog_and_contents(): + tool = LoadKAgentMcpPromptTool( + mcp_toolset=StubPromptToolset(), + 
name="load_mcp_prompt_incident", + server_label="incident-mcp", + ) + llm_request = FakeLlmRequest( + responses=[ + ("other_tool", {"status": "ignored"}), + (tool.name, {"prompt_name": "incident_triage", "arguments": {"ticket_id": "INC-42"}}), + ], + ) + + await tool.process_llm_request(tool_context=MagicMock(), llm_request=llm_request) + + assert llm_request.appended_tools == [tool] + assert any("incident_triage" in instruction for instruction in llm_request.instructions) + assert llm_request.contents[1].role == "model" + assert any( + part.text and "Investigate the controller logs first." in part.text + for content in llm_request.contents[1:] + for part in content.parts + if getattr(part, "text", None) + ) + + +@pytest.mark.asyncio +async def test_resource_loader_coerces_string_resource_names(): + tool = LoadKAgentMcpResourceTool( + mcp_toolset=StubResourceToolset(), + name="load_mcp_resource_cluster", + server_label="cluster-mcp", + ) + + result = await tool.run_async(args={"resource_names": "cluster_runbook"}, tool_context=MagicMock()) + assert result["resource_names"] == ["cluster_runbook"] + + result = await tool.run_async(args={"resource_names": 12345}, tool_context=MagicMock()) + assert result["resource_names"] == [] + + result = await tool.run_async(args={}, tool_context=MagicMock()) + assert result["resource_names"] == [] + + +@pytest.mark.asyncio +async def test_resource_loader_adds_resource_catalog_and_contents(): + tool = LoadKAgentMcpResourceTool( + mcp_toolset=StubResourceToolset(), + name="load_mcp_resource_cluster", + server_label="cluster-mcp", + ) + llm_request = FakeLlmRequest( + responses=[ + ("other_tool", {"status": "ignored"}), + (tool.name, {"resource_names": ["cluster_runbook"]}), + ], + ) + + await tool.process_llm_request(tool_context=MagicMock(), llm_request=llm_request) + + assert llm_request.appended_tools == [tool] + assert any("cluster_runbook" in instruction for instruction in llm_request.instructions) + assert any( + part.text and "Restart the controller deployment" in part.text + for content in llm_request.contents[1:] + for part in content.parts + if getattr(part, "text", None) + ) + + +@pytest.mark.asyncio +async def test_combined_resource_and_prompt_helpers_in_same_turn(): + """Both helpers called in one turn — the second must still find its function response.""" + toolset = StubCombinedToolset() + + resource_tool = LoadKAgentMcpResourceTool( + mcp_toolset=toolset, + name="load_mcp_resource_cluster", + server_label="cluster-mcp", + ) + prompt_tool = LoadKAgentMcpPromptTool( + mcp_toolset=toolset, + name="load_mcp_prompt_incident", + server_label="cluster-mcp", + ) + + llm_request = FakeLlmRequest( + responses=[ + (resource_tool.name, {"resource_names": ["cluster_runbook"]}), + (prompt_tool.name, {"prompt_name": "incident_triage", "arguments": {"ticket_id": "INC-42"}}), + ], + ) + + # Resource helper runs first and appends content to llm_request.contents + await resource_tool.process_llm_request(tool_context=MagicMock(), llm_request=llm_request) + # Prompt helper runs second — must still find its function response + await prompt_tool.process_llm_request(tool_context=MagicMock(), llm_request=llm_request) + + # Verify resource content was loaded + assert any( + part.text and "Restart the controller deployment" in part.text + for content in llm_request.contents[1:] + for part in content.parts + if getattr(part, "text", None) + ) + + # Verify prompt content was loaded (this would fail before the fix) + assert any( + part.text and "Investigate the 
controller logs first." in part.text + for content in llm_request.contents[1:] + for part in content.parts + if getattr(part, "text", None) + )