diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index bedd93f24..29b4307bf 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -7,13 +7,13 @@ import logging from datetime import timedelta -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from mcp.types import Tool as MCPTool from typing_extensions import override -from ...types._events import ToolResultEvent -from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse +from ...types._events import ToolResultEvent, ToolStreamEvent +from ...types.tools import AgentTool, ToolGenerator, ToolResult, ToolSpec, ToolUse if TYPE_CHECKING: from .mcp_client import MCPClient @@ -110,10 +110,15 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw """ logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"]) - result = await self.mcp_client.call_tool_async( + async for event in self.mcp_client.call_tool_stream( tool_use_id=tool_use["toolUseId"], name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], read_timeout_seconds=self.timeout, - ) - yield ToolResultEvent(result) + ): + if isinstance(event, dict) and event.get("status") in ["success", "error"]: + # It's a MCPToolResult + yield ToolResultEvent(cast(ToolResult, event)) + else: + # It's a streaming chunk + yield ToolStreamEvent(tool_use, event) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index bb5dca19c..9af20cbb3 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -16,11 +16,11 @@ from concurrent import futures from datetime import timedelta from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Callable, Coroutine, 
async def call_tool_stream(
    self,
    tool_use_id: str,
    name: str,
    arguments: dict[str, Any] | None = None,
    read_timeout_seconds: timedelta | None = None,
) -> AsyncGenerator[Union[Any, MCPToolResult], None]:
    """Asynchronously call a tool on the MCP server, streaming progress data as it arrives.

    The tool call is submitted through ``_invoke_on_background_thread``. Every value handed to the
    progress callback is forwarded to the caller's event loop through a queue and yielded in arrival
    order. Once the call finishes, any queued progress data is drained and the final result is
    yielded last.

    Args:
        tool_use_id: Unique identifier for this tool use.
        name: Name of the tool to call.
        arguments: Optional arguments to pass to the tool.
        read_timeout_seconds: Optional timeout for the tool call.

    Yields:
        Any: Progress data chunks, exactly as they were passed to the progress callback.
        MCPToolResult: The final result of the tool call. If the call (or result handling) raises an
            ``Exception``, the error result built by ``_handle_tool_execution_error`` is yielded instead.

    Raises:
        MCPClientInitializationError: On first iteration, if the client session is not active.
    """
    self._log_debug_with_thread("streaming MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id)
    if not self._is_session_active():
        raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

    queue: asyncio.Queue[Any] = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def progress_callback(progress_data: Any) -> None:
        # asyncio.Queue is not thread-safe; schedule the put on the consumer's loop instead of
        # touching the queue directly from wherever the session invokes this callback.
        loop.call_soon_threadsafe(queue.put_nowait, progress_data)

    async def _call_tool_async() -> MCPCallToolResult:
        # NOTE(review): the cast silences a signature mismatch. mcp's ProgressFnT is presumably an
        # async callable taking (progress, total, message); confirm that a sync, single-argument
        # callback is really invoked correctly by a live ClientSession.
        return await cast(ClientSession, self._background_thread_session).call_tool(
            name, arguments, read_timeout_seconds, progress_callback=cast(ProgressFnT, progress_callback)
        )

    task: asyncio.Future[MCPCallToolResult] | None = None
    getter: asyncio.Task[Any] | None = None
    try:
        # Start the tool call on the background thread and make its future awaitable on this loop.
        future = self._invoke_on_background_thread(_call_tool_async())
        task = asyncio.wrap_future(future)

        while True:
            # Wait for either a new progress item or completion of the tool call.
            getter = asyncio.create_task(queue.get())
            done, _ = await asyncio.wait({task, getter}, return_when=asyncio.FIRST_COMPLETED)

            # Progress items are emitted before the completion check so ordering is preserved.
            if getter in done:
                yield getter.result()
            else:
                # Only the tool call finished; a cancelled Queue.get() leaves any item in the queue,
                # so the drain below still sees it.
                getter.cancel()

            if task in done:
                # Flush whatever progress data was queued before the call completed.
                while not queue.empty():
                    yield queue.get_nowait()
                break

        # Re-raises here if the tool call failed; handled by the except clause below.
        call_tool_result: MCPCallToolResult = await task
        yield self._handle_tool_result(tool_use_id, call_tool_result)

    except Exception as e:
        logger.exception("tool execution failed")
        yield self._handle_tool_execution_error(tool_use_id, e)
    finally:
        # Reached on normal exit, but also when the consumer stops iterating early or the awaiting
        # task is cancelled (GeneratorExit / CancelledError bypass `except Exception`). Without this
        # the pending queue.get() task would be leaked and the tool call would keep running with
        # nobody left to read its progress data.
        if getter is not None and not getter.done():
            getter.cancel()
        if task is not None and not task.done():
            task.cancel()
@pytest.mark.asyncio
async def test_stream_yields_events(mcp_agent_tool, mock_mcp_client, alist):
    """Progress chunks become ToolStreamEvents; the terminal result dict becomes a ToolResultEvent."""
    tool_use = {"toolUseId": "test-stream", "name": "test_tool", "input": {}}
    chunks = ["chunk 1", "chunk 2"]

    async def fake_call_tool_stream(*_, **__):
        # Emit the progress chunks first, then a result-shaped dict as the last item.
        for chunk in chunks:
            yield chunk
        yield {
            "status": "success",
            "toolUseId": "test-stream",
            "content": [{"text": "final"}],
        }

    mock_mcp_client.call_tool_stream.side_effect = fake_call_tool_stream

    tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))

    exp_result = ToolResultEvent(
        {
            "status": "success",
            "toolUseId": "test-stream",
            "content": [{"text": "final"}],
        }
    )
    exp_events = [ToolStreamEvent(tool_use, chunk) for chunk in chunks] + [exp_result]
    assert tru_events == exp_events
@pytest.mark.asyncio
async def test_call_tool_stream_success(mock_transport, mock_session):
    """call_tool_stream yields each progress chunk in order, then the final success result."""
    import asyncio

    expected_result = MCPCallToolResult(
        isError=False,
        content=[MCPTextContent(type="text", text="Final Result")],
    )

    async def fake_call_tool(*args, progress_callback=None, **kwargs):
        # Report two chunks through the supplied callback, pausing briefly after each one.
        if progress_callback:
            for chunk in ("chunk 1", "chunk 2"):
                progress_callback(chunk)
                await asyncio.sleep(0.01)
        return expected_result

    mock_session.call_tool.side_effect = fake_call_tool

    with MCPClient(mock_transport["transport_callable"]) as client:
        # The callback invoked by the mocked session is bridged into this async generator via a queue.
        stream = client.call_tool_stream(tool_use_id="test-stream-1", name="stream_tool", arguments={})
        events = [event async for event in stream]

    # The final result is the last result-shaped dict that was yielded.
    final_result = next(
        (event for event in reversed(events) if isinstance(event, dict) and "toolUseId" in event),
        None,
    )

    # Two chunks followed by the final result.
    assert len(events) == 3
    assert events[0] == "chunk 1"
    assert events[1] == "chunk 2"

    assert final_result is not None
    assert final_result["status"] == "success"
    assert final_result["content"][0]["text"] == "Final Result"


@pytest.mark.asyncio
async def test_call_tool_stream_exception(mock_transport, mock_session):
    """A failure inside the tool call is surfaced as a trailing error result, after earlier chunks."""

    async def failing_call_tool(*args, **kwargs):
        # Emit one chunk through the callback (if provided) and then fail.
        on_progress = kwargs.get("progress_callback")
        if on_progress:
            on_progress("chunk 1")
        raise ValueError("Stream error")

    mock_session.call_tool.side_effect = failing_call_tool

    with MCPClient(mock_transport["transport_callable"]) as client:
        stream = client.call_tool_stream(tool_use_id="test-stream-error", name="stream_tool", arguments={})
        events = [event async for event in stream]

    # The chunk reported before the failure arrives first.
    assert events[0] == "chunk 1"

    # The stream ends with the error result instead of raising.
    last_event = events[-1]
    assert last_event["status"] == "error"
    assert "Stream error" in last_event["content"][0]["text"]