Skip to content

Commit 18559fe

Browse files
committed
Move allow_input_required overload+gate down to ClientSession.call_tool
The gate now lives once on ClientSession (mechanics layer); Client and ClientSessionGroup are pure passthroughs that forward allow_input_required. Third 'bool' overload on ClientSession.call_tool lets the passthrough impls type-check. Reverts the isinstance narrowing in examples/ and tests/shared/; default return is CallToolResult everywhere, so the change is additive.
1 parent 0b4d8a9 commit 18559fe

15 files changed

Lines changed: 76 additions & 64 deletions

File tree

README.v2.md

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,7 +2166,6 @@ async def run():
21662166

21672167
# Call a tool (add tool from mcpserver_quickstart)
21682168
result = await session.call_tool("add", arguments={"a": 5, "b": 3})
2169-
assert isinstance(result, types.CallToolResult)
21702169
result_unstructured = result.content[0]
21712170
if isinstance(result_unstructured, types.TextContent):
21722171
print(f"Tool result: {result_unstructured.text}")
@@ -2432,22 +2431,19 @@ async def parse_tool_results():
24322431

24332432
# Example 1: Parsing text content
24342433
result = await session.call_tool("get_data", {"format": "text"})
2435-
assert isinstance(result, types.CallToolResult)
24362434
for content in result.content:
24372435
if isinstance(content, types.TextContent):
24382436
print(f"Text: {content.text}")
24392437

24402438
# Example 2: Parsing structured content from JSON tools
24412439
result = await session.call_tool("get_user", {"id": "123"})
2442-
assert isinstance(result, types.CallToolResult)
2443-
if result.structured_content:
2440+
if hasattr(result, "structured_content") and result.structured_content:
24442441
# Access structured data directly
24452442
user_data = result.structured_content
24462443
print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}")
24472444

24482445
# Example 3: Parsing embedded resources
24492446
result = await session.call_tool("read_config", {})
2450-
assert isinstance(result, types.CallToolResult)
24512447
for content in result.content:
24522448
if isinstance(content, types.EmbeddedResource):
24532449
resource = content.resource
@@ -2458,14 +2454,12 @@ async def parse_tool_results():
24582454

24592455
# Example 4: Parsing image content
24602456
result = await session.call_tool("generate_chart", {"data": [1, 2, 3]})
2461-
assert isinstance(result, types.CallToolResult)
24622457
for content in result.content:
24632458
if isinstance(content, types.ImageContent):
24642459
print(f"Image ({content.mime_type}): {len(content.data)} bytes")
24652460

24662461
# Example 5: Handling errors
24672462
result = await session.call_tool("failing_tool", {})
2468-
assert isinstance(result, types.CallToolResult)
24692463
if result.is_error:
24702464
print("Tool execution failed!")
24712465
for content in result.content:

docs/migration.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,11 +364,9 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer`
364364

365365
`Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it.
366366

367-
### `ClientSession.call_tool` returns `CallToolResult | InputRequiredResult`
367+
### `call_tool` can return `InputRequiredResult` (opt-in)
368368

369-
For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. `ClientSession.call_tool` now returns `CallToolResult | InputRequiredResult` to reflect this; narrow with `isinstance(result, CallToolResult)` before reading `.content` / `.is_error` / `.structured_content`.
370-
371-
The high-level `Client.call_tool` still returns `CallToolResult` by default (and raises `RuntimeError` if the server requests input). Pass `allow_input_required=True` to receive the `InputRequiredResult` and retry with `input_responses=` / `request_state=`.
369+
For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. By default `call_tool` (on `ClientSession`, `Client`, and `ClientSessionGroup`) still returns `CallToolResult` and raises `RuntimeError` if the server requests input. Pass `allow_input_required=True` to receive the `InputRequiredResult` instead, then retry with `input_responses=` / `request_state=`.
372370

373371
### `McpError` renamed to `MCPError`
374372

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from mcp.client.streamable_http import streamable_http_client
2626
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
2727
from mcp.shared.message import SessionMessage
28-
from mcp.types import CallToolResult
2928

3029

3130
class InMemoryTokenStorage(TokenStorage):
@@ -294,7 +293,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = Non
294293
try:
295294
result = await self.session.call_tool(tool_name, arguments or {})
296295
print(f"\n🔧 Tool '{tool_name}' result:")
297-
if isinstance(result, CallToolResult):
296+
if hasattr(result, "content"):
298297
for content in result.content:
299298
if content.type == "text":
300299
print(content.text)

examples/clients/sse-polling-client/mcp_sse_polling_client/main.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import click
2121
from mcp import ClientSession
2222
from mcp.client.streamable_http import streamable_http_client
23-
from mcp.types import CallToolResult
2423

2524

2625
async def run_demo(url: str, items: int, checkpoint_every: int) -> None:
@@ -56,7 +55,6 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None:
5655
)
5756

5857
print("-" * 40)
59-
assert isinstance(result, CallToolResult)
6058
if result.content:
6159
content = result.content[0]
6260
text = getattr(content, "text", str(content))

examples/snippets/clients/parsing_tool_results.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,19 @@ async def parse_tool_results():
1616

1717
# Example 1: Parsing text content
1818
result = await session.call_tool("get_data", {"format": "text"})
19-
assert isinstance(result, types.CallToolResult)
2019
for content in result.content:
2120
if isinstance(content, types.TextContent):
2221
print(f"Text: {content.text}")
2322

2423
# Example 2: Parsing structured content from JSON tools
2524
result = await session.call_tool("get_user", {"id": "123"})
26-
assert isinstance(result, types.CallToolResult)
27-
if result.structured_content:
25+
if hasattr(result, "structured_content") and result.structured_content:
2826
# Access structured data directly
2927
user_data = result.structured_content
3028
print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}")
3129

3230
# Example 3: Parsing embedded resources
3331
result = await session.call_tool("read_config", {})
34-
assert isinstance(result, types.CallToolResult)
3532
for content in result.content:
3633
if isinstance(content, types.EmbeddedResource):
3734
resource = content.resource
@@ -42,14 +39,12 @@ async def parse_tool_results():
4239

4340
# Example 4: Parsing image content
4441
result = await session.call_tool("generate_chart", {"data": [1, 2, 3]})
45-
assert isinstance(result, types.CallToolResult)
4642
for content in result.content:
4743
if isinstance(content, types.ImageContent):
4844
print(f"Image ({content.mime_type}): {len(content.data)} bytes")
4945

5046
# Example 5: Handling errors
5147
result = await session.call_tool("failing_tool", {})
52-
assert isinstance(result, types.CallToolResult)
5348
if result.is_error:
5449
print("Tool execution failed!")
5550
for content in result.content:

examples/snippets/clients/stdio_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ async def run():
6464

6565
# Call a tool (add tool from mcpserver_quickstart)
6666
result = await session.call_tool("add", arguments={"a": 5, "b": 3})
67-
assert isinstance(result, types.CallToolResult)
6867
result_unstructured = result.content[0]
6968
if isinstance(result_unstructured, types.TextContent):
7069
print(f"Tool result: {result_unstructured.text}")

examples/snippets/clients/url_elicitation_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ async def call_tool_with_error_handling(
150150
"""
151151
try:
152152
result = await session.call_tool(tool_name, arguments)
153-
assert isinstance(result, types.CallToolResult)
154153

155154
# Check if the tool returned an error in the result
156155
if result.is_error:

src/mcp/client/client.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -439,22 +439,17 @@ async def call_tool(
439439
RuntimeError: If the server returns an `InputRequiredResult` and
440440
``allow_input_required`` is ``False``.
441441
"""
442-
result = await self.session.call_tool(
442+
# TODO(L84): stop forwarding allow_input_required; run the MRTR auto-loop driver here (S6).
443+
return await self.session.call_tool(
443444
name=name,
444445
arguments=arguments,
445446
read_timeout_seconds=read_timeout_seconds,
446447
progress_callback=progress_callback,
447448
input_responses=input_responses,
448449
request_state=request_state,
449450
meta=meta,
451+
allow_input_required=allow_input_required,
450452
)
451-
if isinstance(result, InputRequiredResult) and not allow_input_required:
452-
# TODO(L84): replace this raise with the MRTR auto-loop driver (S6).
453-
raise RuntimeError(
454-
"Server returned InputRequiredResult; pass allow_input_required=True to receive it "
455-
"and retry call_tool(..., input_responses=..., request_state=result.request_state)."
456-
)
457-
return result
458453

459454
async def list_prompts(
460455
self,

src/mcp/client/session.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Callable, Mapping
55
from dataclasses import dataclass
66
from types import TracebackType
7-
from typing import Any, Protocol, cast
7+
from typing import Any, Literal, Protocol, cast, overload
88

99
import anyio
1010
import anyio.abc
@@ -602,6 +602,7 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None
602602
types.EmptyResult,
603603
)
604604

605+
@overload
605606
async def call_tool(
606607
self,
607608
name: str,
@@ -612,8 +613,62 @@ async def call_tool(
612613
input_responses: types.InputResponses | None = None,
613614
request_state: str | None = None,
614615
meta: RequestParamsMeta | None = None,
616+
allow_input_required: Literal[False] = False,
617+
) -> types.CallToolResult: ...
618+
619+
@overload
620+
async def call_tool(
621+
self,
622+
name: str,
623+
arguments: dict[str, Any] | None = None,
624+
read_timeout_seconds: float | None = None,
625+
progress_callback: ProgressFnT | None = None,
626+
*,
627+
input_responses: types.InputResponses | None = None,
628+
request_state: str | None = None,
629+
meta: RequestParamsMeta | None = None,
630+
allow_input_required: Literal[True],
631+
) -> types.CallToolResult | types.InputRequiredResult: ...
632+
633+
@overload
634+
async def call_tool(
635+
self,
636+
name: str,
637+
arguments: dict[str, Any] | None = None,
638+
read_timeout_seconds: float | None = None,
639+
progress_callback: ProgressFnT | None = None,
640+
*,
641+
input_responses: types.InputResponses | None = None,
642+
request_state: str | None = None,
643+
meta: RequestParamsMeta | None = None,
644+
allow_input_required: bool,
645+
) -> types.CallToolResult | types.InputRequiredResult: ...
646+
647+
async def call_tool(
648+
self,
649+
name: str,
650+
arguments: dict[str, Any] | None = None,
651+
read_timeout_seconds: float | None = None,
652+
progress_callback: ProgressFnT | None = None,
653+
*,
654+
input_responses: types.InputResponses | None = None,
655+
request_state: str | None = None,
656+
meta: RequestParamsMeta | None = None,
657+
allow_input_required: bool = False,
615658
) -> types.CallToolResult | types.InputRequiredResult:
616-
"""Send a tools/call request with optional progress callback support."""
659+
"""Send a tools/call request with optional progress callback support.
660+
661+
Args:
662+
input_responses: Responses to a prior `InputRequiredResult.input_requests`.
663+
request_state: Opaque state echoed from a prior `InputRequiredResult`.
664+
allow_input_required: When ``False`` (default), an `InputRequiredResult`
665+
from the server raises `RuntimeError`; when ``True``, it is returned
666+
so the caller can resolve the requests and retry.
667+
668+
Raises:
669+
RuntimeError: If the server returns an `InputRequiredResult` and
670+
``allow_input_required`` is ``False``.
671+
"""
617672

618673
result = await self.send_request(
619674
types.CallToolRequest(
@@ -633,6 +688,11 @@ async def call_tool(
633688
if isinstance(result, types.CallToolResult) and not result.is_error:
634689
await self._validate_tool_result(name, result)
635690

691+
if isinstance(result, types.InputRequiredResult) and not allow_input_required:
692+
raise RuntimeError(
693+
"Server returned InputRequiredResult; pass allow_input_required=True to receive it "
694+
"and retry call_tool(..., input_responses=..., request_state=result.request_state)."
695+
)
636696
return result
637697

638698
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:

src/mcp/client/session_group.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,22 +238,16 @@ async def call_tool(
238238
"""
239239
session = self._tool_to_session[name]
240240
session_tool_name = self.tools[name].name
241-
result = await session.call_tool(
241+
return await session.call_tool(
242242
session_tool_name,
243243
arguments=arguments,
244244
read_timeout_seconds=read_timeout_seconds,
245245
progress_callback=progress_callback,
246246
input_responses=input_responses,
247247
request_state=request_state,
248248
meta=meta,
249+
allow_input_required=allow_input_required,
249250
)
250-
if isinstance(result, types.InputRequiredResult) and not allow_input_required:
251-
# TODO(L84): replace this raise with the MRTR auto-loop driver (S6).
252-
raise RuntimeError(
253-
"Server returned InputRequiredResult; pass allow_input_required=True to receive it "
254-
"and retry call_tool(..., input_responses=..., request_state=result.request_state)."
255-
)
256-
return result
257251

258252
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
259253
"""Disconnects from a single MCP server."""

0 commit comments

Comments
 (0)