Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/strands/tools/mcp/mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

import logging
from datetime import timedelta
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from mcp.types import Tool as MCPTool
from typing_extensions import override

from ...types._events import ToolResultEvent
from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse
from ...types._events import ToolResultEvent, ToolStreamEvent
from ...types.tools import AgentTool, ToolGenerator, ToolResult, ToolSpec, ToolUse

if TYPE_CHECKING:
from .mcp_client import MCPClient
Expand Down Expand Up @@ -110,10 +110,15 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
"""
logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"])

result = await self.mcp_client.call_tool_async(
async for event in self.mcp_client.call_tool_stream(
tool_use_id=tool_use["toolUseId"],
name=self.mcp_tool.name, # Use original MCP name for server communication
arguments=tool_use["input"],
read_timeout_seconds=self.timeout,
)
yield ToolResultEvent(result)
):
if isinstance(event, dict) and event.get("status") in ["success", "error"]:
# It's a MCPToolResult
yield ToolResultEvent(cast(ToolResult, event))
else:
# It's a streaming chunk
yield ToolStreamEvent(tool_use, event)
75 changes: 73 additions & 2 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from concurrent import futures
from datetime import timedelta
from types import TracebackType
from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast
from typing import Any, AsyncGenerator, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast

import anyio
from mcp import ClientSession, ListToolsResult
from mcp.client.session import ElicitationFnT
from mcp.client.session import ElicitationFnT, ProgressFnT
from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import EmbeddedResource as MCPEmbeddedResource
Expand Down Expand Up @@ -523,6 +523,77 @@ async def _call_tool_async() -> MCPCallToolResult:
logger.exception("tool execution failed")
return self._handle_tool_execution_error(tool_use_id, e)

async def call_tool_stream(
self,
tool_use_id: str,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
) -> AsyncGenerator[Union[Any, MCPToolResult], None]:
"""Asynchronously calls a tool on the MCP server with streaming support.

This method calls the asynchronous call_tool method on the MCP session,
streaming progress updates as they arrive, and finally returning the full result.

Args:
tool_use_id: Unique identifier for this tool use
name: Name of the tool to call
arguments: Optional arguments to pass to the tool
read_timeout_seconds: Optional timeout for the tool call

Returns:
Any: Progress data chunks from the tool execution
MCPToolResult: The final result of the tool call
"""
self._log_debug_with_thread("streaming MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id)
if not self._is_session_active():
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

queue: asyncio.Queue[Any] = asyncio.Queue()
loop = asyncio.get_running_loop()

def progress_callback(progress_data: Any) -> None:
loop.call_soon_threadsafe(queue.put_nowait, progress_data)

async def _call_tool_async() -> MCPCallToolResult:
return await cast(ClientSession, self._background_thread_session).call_tool(
name, arguments, read_timeout_seconds, progress_callback=cast(ProgressFnT, progress_callback)
)

task: asyncio.Future[MCPCallToolResult] | None = None
try:
# Start the tool call on the background thread
future = self._invoke_on_background_thread(_call_tool_async())
task = asyncio.wrap_future(future)

# Consume the queue and wait for task completion
while True:
# Wait for either new data or task completion
get_coro = asyncio.create_task(queue.get())
done, _ = await asyncio.wait({task, get_coro}, return_when=asyncio.FIRST_COMPLETED)

# Process queue items first
if get_coro in done:
yield get_coro.result()
else:
# If we didn't consume the queue item, cancel the get
get_coro.cancel()

# Check if task is done
if task in done:
# Drain any remaining items in the queue
while not queue.empty():
yield queue.get_nowait()
break

# Yield the final result
call_tool_result: MCPCallToolResult = await task
yield self._handle_tool_result(tool_use_id, call_tool_result)

except Exception as e:
logger.exception("tool execution failed")
yield self._handle_tool_execution_error(tool_use_id, e)

def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult:
"""Create error ToolResult with consistent logging."""
return MCPToolResult(
Expand Down
68 changes: 56 additions & 12 deletions tests/strands/tools/mcp/test_mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mcp.types import Tool as MCPTool

from strands.tools.mcp import MCPAgentTool, MCPClient
from strands.types._events import ToolResultEvent
from strands.types._events import ToolResultEvent, ToolStreamEvent


@pytest.fixture
Expand All @@ -21,11 +21,16 @@ def mock_mcp_tool():
@pytest.fixture
def mock_mcp_client():
mock_server = MagicMock(spec=MCPClient)
mock_server.call_tool_sync.return_value = {
"status": "success",
"toolUseId": "test-123",
"content": [{"text": "Success result"}],
}

async def mock_stream(*args, **kwargs):
tool_use_id = kwargs.get("tool_use_id", "test-123")
yield {
"status": "success",
"toolUseId": tool_use_id,
"content": [{"text": "Success result"}],
}

mock_server.call_tool_stream.side_effect = mock_stream
return mock_server


Expand Down Expand Up @@ -85,14 +90,49 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}

tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]

expected_result = {
"status": "success",
"toolUseId": "test-123",
"content": [{"text": "Success result"}],
}
exp_events = [ToolResultEvent(expected_result)]
assert tru_events == exp_events
mock_mcp_client.call_tool_async.assert_called_once_with(
mock_mcp_client.call_tool_stream.assert_called_once_with(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None
)


@pytest.mark.asyncio
async def test_stream_yields_events(mcp_agent_tool, mock_mcp_client, alist):
tool_use = {"toolUseId": "test-stream", "name": "test_tool", "input": {}}

async def mock_streaming_generator(*_, **__):
yield "chunk 1"
yield "chunk 2"
yield {
"status": "success",
"toolUseId": "test-stream",
"content": [{"text": "final"}],
}

mock_mcp_client.call_tool_stream.side_effect = mock_streaming_generator

tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))

exp_events = [
ToolStreamEvent(tool_use, "chunk 1"),
ToolStreamEvent(tool_use, "chunk 2"),
ToolResultEvent(
{
"status": "success",
"toolUseId": "test-stream",
"content": [{"text": "final"}],
}
),
]
assert tru_events == exp_events


def test_timeout_initialization(mock_mcp_tool, mock_mcp_client):
timeout = timedelta(seconds=30)
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout)
Expand All @@ -111,9 +151,13 @@ async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist):
tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}}

tru_events = await alist(agent_tool.stream(tool_use, {}))
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]

expected_result = {
"status": "success",
"toolUseId": "test-456",
"content": [{"text": "Success result"}],
}
exp_events = [ToolResultEvent(expected_result)]
assert tru_events == exp_events
mock_mcp_client.call_tool_async.assert_called_once_with(
mock_mcp_client.call_tool_stream.assert_called_once_with(
tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout
)
70 changes: 70 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,73 @@ async def test_handle_error_message_non_exception():

# This should not raise an exception
await client._handle_error_message("normal message")


@pytest.mark.asyncio
async def test_call_tool_stream_success(mock_transport, mock_session):
"""Test that call_tool_stream yields progress and returns final result."""
import asyncio

mock_content = MCPTextContent(type="text", text="Final Result")
mock_result = MCPCallToolResult(isError=False, content=[mock_content])

# Setup call_tool mock to invoke progress_callback
async def mock_call_tool(*args, progress_callback=None, **kwargs):
if progress_callback:
# Simulate streaming chunks
progress_callback("chunk 1")
await asyncio.sleep(0.01)
progress_callback("chunk 2")
await asyncio.sleep(0.01)
return mock_result

mock_session.call_tool.side_effect = mock_call_tool

with MCPClient(mock_transport["transport_callable"]) as client:
# We run in the test event loop. The mock_call_tool simulates the background thread
# by invoking the callback, which call_tool_stream bridges to the async generator via a queue.

events = []
final_result = None

async for event in client.call_tool_stream(tool_use_id="test-stream-1", name="stream_tool", arguments={}):
events.append(event)
if isinstance(event, dict) and "toolUseId" in event:
final_result = event

# Check yielded events
assert len(events) == 3 # chunk 1, chunk 2, final result
assert events[0] == "chunk 1"
assert events[1] == "chunk 2"

# Check final result
assert final_result is not None
assert final_result["status"] == "success"
assert final_result["content"][0]["text"] == "Final Result"


@pytest.mark.asyncio
async def test_call_tool_stream_exception(mock_transport, mock_session):
"""Test that call_tool_stream handles exceptions during execution."""

async def mock_call_tool_error(*args, **kwargs):
progress_callback = kwargs.get("progress_callback")
if progress_callback:
progress_callback("chunk 1")
raise ValueError("Stream error")

mock_session.call_tool.side_effect = mock_call_tool_error

with MCPClient(mock_transport["transport_callable"]) as client:
events = []

async for event in client.call_tool_stream(tool_use_id="test-stream-error", name="stream_tool", arguments={}):
events.append(event)

# Should get the chunk first
assert events[0] == "chunk 1"

# Last event should be error result
last_event = events[-1]
assert last_event["status"] == "error"
assert "Stream error" in last_event["content"][0]["text"]