Skip to content

Commit d9a6317

Browse files
committed
Client call_tool: accept input_responses/request_state; return InputRequiredResult via allow_input_required opt-in
- ClientSession.send_request: accept TypeAdapter[T] alongside type[T] for result_type so callers can parse union results. - ClientSession.call_tool (mechanics): add input_responses= / request_state= retry kwargs; return CallToolResult | InputRequiredResult; gate output-schema validation on isinstance(result, CallToolResult). - Client.call_tool / ClientSessionGroup.call_tool (policy): @overload on allow_input_required — Literal[False] (default) returns CallToolResult; Literal[True] returns the union. Default raises RuntimeError on InputRequiredResult with a retry steer (TODO(L80) marks where the auto-loop driver replaces this). - Examples and tests that call ClientSession.call_tool directly narrow with isinstance(result, CallToolResult); README.v2.md regenerated from snippets.
1 parent f226d00 commit d9a6317

15 files changed

Lines changed: 239 additions & 17 deletions

File tree

README.v2.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2166,6 +2166,7 @@ async def run():
21662166

21672167
# Call a tool (add tool from mcpserver_quickstart)
21682168
result = await session.call_tool("add", arguments={"a": 5, "b": 3})
2169+
assert isinstance(result, types.CallToolResult)
21692170
result_unstructured = result.content[0]
21702171
if isinstance(result_unstructured, types.TextContent):
21712172
print(f"Tool result: {result_unstructured.text}")
@@ -2431,19 +2432,22 @@ async def parse_tool_results():
24312432

24322433
# Example 1: Parsing text content
24332434
result = await session.call_tool("get_data", {"format": "text"})
2435+
assert isinstance(result, types.CallToolResult)
24342436
for content in result.content:
24352437
if isinstance(content, types.TextContent):
24362438
print(f"Text: {content.text}")
24372439

24382440
# Example 2: Parsing structured content from JSON tools
24392441
result = await session.call_tool("get_user", {"id": "123"})
2440-
if hasattr(result, "structured_content") and result.structured_content:
2442+
assert isinstance(result, types.CallToolResult)
2443+
if result.structured_content:
24412444
# Access structured data directly
24422445
user_data = result.structured_content
24432446
print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}")
24442447

24452448
# Example 3: Parsing embedded resources
24462449
result = await session.call_tool("read_config", {})
2450+
assert isinstance(result, types.CallToolResult)
24472451
for content in result.content:
24482452
if isinstance(content, types.EmbeddedResource):
24492453
resource = content.resource
@@ -2454,12 +2458,14 @@ async def parse_tool_results():
24542458

24552459
# Example 4: Parsing image content
24562460
result = await session.call_tool("generate_chart", {"data": [1, 2, 3]})
2461+
assert isinstance(result, types.CallToolResult)
24572462
for content in result.content:
24582463
if isinstance(content, types.ImageContent):
24592464
print(f"Image ({content.mime_type}): {len(content.data)} bytes")
24602465

24612466
# Example 5: Handling errors
24622467
result = await session.call_tool("failing_tool", {})
2468+
assert isinstance(result, types.CallToolResult)
24632469
if result.is_error:
24642470
print("Tool execution failed!")
24652471
for content in result.content:

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from mcp.client.streamable_http import streamable_http_client
2626
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
2727
from mcp.shared.message import SessionMessage
28+
from mcp.types import CallToolResult
2829

2930

3031
class InMemoryTokenStorage(TokenStorage):
@@ -293,7 +294,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = Non
293294
try:
294295
result = await self.session.call_tool(tool_name, arguments or {})
295296
print(f"\n🔧 Tool '{tool_name}' result:")
296-
if hasattr(result, "content"):
297+
if isinstance(result, CallToolResult):
297298
for content in result.content:
298299
if content.type == "text":
299300
print(content.text)

examples/clients/sse-polling-client/mcp_sse_polling_client/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import click
2121
from mcp import ClientSession
2222
from mcp.client.streamable_http import streamable_http_client
23+
from mcp.types import CallToolResult
2324

2425

2526
async def run_demo(url: str, items: int, checkpoint_every: int) -> None:
@@ -55,6 +56,7 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None:
5556
)
5657

5758
print("-" * 40)
59+
assert isinstance(result, CallToolResult)
5860
if result.content:
5961
content = result.content[0]
6062
text = getattr(content, "text", str(content))

examples/snippets/clients/parsing_tool_results.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,22 @@ async def parse_tool_results():
1616

1717
# Example 1: Parsing text content
1818
result = await session.call_tool("get_data", {"format": "text"})
19+
assert isinstance(result, types.CallToolResult)
1920
for content in result.content:
2021
if isinstance(content, types.TextContent):
2122
print(f"Text: {content.text}")
2223

2324
# Example 2: Parsing structured content from JSON tools
2425
result = await session.call_tool("get_user", {"id": "123"})
25-
if hasattr(result, "structured_content") and result.structured_content:
26+
assert isinstance(result, types.CallToolResult)
27+
if result.structured_content:
2628
# Access structured data directly
2729
user_data = result.structured_content
2830
print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}")
2931

3032
# Example 3: Parsing embedded resources
3133
result = await session.call_tool("read_config", {})
34+
assert isinstance(result, types.CallToolResult)
3235
for content in result.content:
3336
if isinstance(content, types.EmbeddedResource):
3437
resource = content.resource
@@ -39,12 +42,14 @@ async def parse_tool_results():
3942

4043
# Example 4: Parsing image content
4144
result = await session.call_tool("generate_chart", {"data": [1, 2, 3]})
45+
assert isinstance(result, types.CallToolResult)
4246
for content in result.content:
4347
if isinstance(content, types.ImageContent):
4448
print(f"Image ({content.mime_type}): {len(content.data)} bytes")
4549

4650
# Example 5: Handling errors
4751
result = await session.call_tool("failing_tool", {})
52+
assert isinstance(result, types.CallToolResult)
4853
if result.is_error:
4954
print("Tool execution failed!")
5055
for content in result.content:

examples/snippets/clients/stdio_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ async def run():
6464

6565
# Call a tool (add tool from mcpserver_quickstart)
6666
result = await session.call_tool("add", arguments={"a": 5, "b": 3})
67+
assert isinstance(result, types.CallToolResult)
6768
result_unstructured = result.content[0]
6869
if isinstance(result_unstructured, types.TextContent):
6970
print(f"Tool result: {result_unstructured.text}")

examples/snippets/clients/url_elicitation_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ async def call_tool_with_error_handling(
150150
"""
151151
try:
152152
result = await session.call_tool(tool_name, arguments)
153+
assert isinstance(result, types.CallToolResult)
153154

154155
# Check if the tool returned an error in the result
155156
if result.is_error:

src/mcp/client/client.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Awaitable, Callable, Mapping
66
from contextlib import AsyncExitStack
77
from dataclasses import KW_ONLY, dataclass, field
8-
from typing import Any, Literal, TypeVar
8+
from typing import Any, Literal, TypeVar, overload
99

1010
import anyio
1111
from typing_extensions import deprecated
@@ -30,6 +30,8 @@
3030
EmptyResult,
3131
GetPromptResult,
3232
Implementation,
33+
InputRequiredResult,
34+
InputResponses,
3335
ListPromptsResult,
3436
ListResourcesResult,
3537
ListResourceTemplatesResult,
@@ -374,34 +376,85 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None
374376
"""Unsubscribe from resource updates."""
375377
return await self.session.unsubscribe_resource(uri, meta=meta)
376378

379+
@overload
377380
async def call_tool(
378381
self,
379382
name: str,
380383
arguments: dict[str, Any] | None = None,
381384
read_timeout_seconds: float | None = None,
382385
progress_callback: ProgressFnT | None = None,
383386
*,
387+
input_responses: InputResponses | None = None,
388+
request_state: str | None = None,
384389
meta: RequestParamsMeta | None = None,
385-
) -> CallToolResult:
390+
allow_input_required: Literal[False] = False,
391+
) -> CallToolResult: ...
392+
393+
@overload
394+
async def call_tool(
395+
self,
396+
name: str,
397+
arguments: dict[str, Any] | None = None,
398+
read_timeout_seconds: float | None = None,
399+
progress_callback: ProgressFnT | None = None,
400+
*,
401+
input_responses: InputResponses | None = None,
402+
request_state: str | None = None,
403+
meta: RequestParamsMeta | None = None,
404+
allow_input_required: Literal[True],
405+
) -> CallToolResult | InputRequiredResult: ...
406+
407+
async def call_tool(
408+
self,
409+
name: str,
410+
arguments: dict[str, Any] | None = None,
411+
read_timeout_seconds: float | None = None,
412+
progress_callback: ProgressFnT | None = None,
413+
*,
414+
input_responses: InputResponses | None = None,
415+
request_state: str | None = None,
416+
meta: RequestParamsMeta | None = None,
417+
allow_input_required: bool = False,
418+
) -> CallToolResult | InputRequiredResult:
386419
"""Call a tool on the server.
387420
388421
Args:
389422
name: The name of the tool to call
390423
arguments: Arguments to pass to the tool
391424
read_timeout_seconds: Timeout for the tool call
392425
progress_callback: Callback for progress updates
426+
input_responses: Responses to a prior `InputRequiredResult.input_requests`
427+
request_state: Opaque state echoed from a prior `InputRequiredResult`
393428
meta: Additional metadata for the request
429+
allow_input_required: When ``False`` (default), an `InputRequiredResult`
430+
from the server raises `RuntimeError`; when ``True``, it is returned
431+
so the caller can resolve the requests and retry.
394432
395433
Returns:
396-
The tool result.
434+
The tool result. When ``allow_input_required=True``, may instead be an
435+
`InputRequiredResult` carrying the server's input requests and opaque
436+
``request_state`` for the retry.
437+
438+
Raises:
439+
RuntimeError: If the server returns an `InputRequiredResult` and
440+
``allow_input_required`` is ``False``.
397441
"""
398-
return await self.session.call_tool(
442+
result = await self.session.call_tool(
399443
name=name,
400444
arguments=arguments,
401445
read_timeout_seconds=read_timeout_seconds,
402446
progress_callback=progress_callback,
447+
input_responses=input_responses,
448+
request_state=request_state,
403449
meta=meta,
404450
)
451+
if isinstance(result, InputRequiredResult) and not allow_input_required:
452+
# TODO(L80): replace this raise with the MRTR auto-loop driver (S6).
453+
raise RuntimeError(
454+
"Server returned InputRequiredResult; pass allow_input_required=True to receive it "
455+
"and retry call_tool(..., input_responses=..., request_state=result.request_state)."
456+
)
457+
return result
405458

406459
async def list_prompts(
407460
self,

src/mcp/client/session.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ async def _default_logging_callback(
173173

174174
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
175175

176+
_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult] = TypeAdapter(
177+
types.CallToolResult | types.InputRequiredResult
178+
)
179+
176180

177181
class ClientSession:
178182
"""Client half of an MCP connection, running on a `Dispatcher`.
@@ -269,7 +273,7 @@ async def __aexit__(
269273
async def send_request(
270274
self,
271275
request: types.ClientRequest,
272-
result_type: type[ReceiveResultT],
276+
result_type: type[ReceiveResultT] | TypeAdapter[ReceiveResultT],
273277
request_read_timeout_seconds: float | None = None,
274278
metadata: ClientMessageMetadata | None = None,
275279
progress_callback: ProgressFnT | None = None,
@@ -308,6 +312,8 @@ async def send_request(
308312
_methods.validate_server_result(method, version, raw)
309313
except KeyError:
310314
pass
315+
if isinstance(result_type, TypeAdapter):
316+
return result_type.validate_python(raw)
311317
return result_type.model_validate(raw, by_name=False)
312318

313319
async def send_notification(self, notification: types.ClientNotification) -> None:
@@ -603,20 +609,28 @@ async def call_tool(
603609
read_timeout_seconds: float | None = None,
604610
progress_callback: ProgressFnT | None = None,
605611
*,
612+
input_responses: types.InputResponses | None = None,
613+
request_state: str | None = None,
606614
meta: RequestParamsMeta | None = None,
607-
) -> types.CallToolResult:
615+
) -> types.CallToolResult | types.InputRequiredResult:
608616
"""Send a tools/call request with optional progress callback support."""
609617

610618
result = await self.send_request(
611619
types.CallToolRequest(
612-
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta),
620+
params=types.CallToolRequestParams(
621+
name=name,
622+
arguments=arguments,
623+
input_responses=input_responses,
624+
request_state=request_state,
625+
_meta=meta,
626+
),
613627
),
614-
types.CallToolResult,
628+
_CallToolResultAdapter,
615629
request_read_timeout_seconds=read_timeout_seconds,
616630
progress_callback=progress_callback,
617631
)
618632

619-
if not result.is_error:
633+
if isinstance(result, types.CallToolResult) and not result.is_error:
620634
await self._validate_tool_result(name, result)
621635

622636
return result

src/mcp/client/session_group.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable
1212
from dataclasses import dataclass
1313
from types import TracebackType
14-
from typing import Any, TypeAlias
14+
from typing import Any, Literal, TypeAlias, overload
1515

1616
import anyio
1717
import httpx
@@ -190,25 +190,70 @@ def tools(self) -> dict[str, types.Tool]:
190190
"""Returns the tools as a dictionary of names to tools."""
191191
return self._tools
192192

193+
@overload
193194
async def call_tool(
194195
self,
195196
name: str,
196197
arguments: dict[str, Any] | None = None,
197198
read_timeout_seconds: float | None = None,
198199
progress_callback: ProgressFnT | None = None,
199200
*,
201+
input_responses: types.InputResponses | None = None,
202+
request_state: str | None = None,
200203
meta: types.RequestParamsMeta | None = None,
201-
) -> types.CallToolResult:
202-
"""Executes a tool given its name and arguments."""
204+
allow_input_required: Literal[False] = False,
205+
) -> types.CallToolResult: ...
206+
207+
@overload
208+
async def call_tool(
209+
self,
210+
name: str,
211+
arguments: dict[str, Any] | None = None,
212+
read_timeout_seconds: float | None = None,
213+
progress_callback: ProgressFnT | None = None,
214+
*,
215+
input_responses: types.InputResponses | None = None,
216+
request_state: str | None = None,
217+
meta: types.RequestParamsMeta | None = None,
218+
allow_input_required: Literal[True],
219+
) -> types.CallToolResult | types.InputRequiredResult: ...
220+
221+
async def call_tool(
222+
self,
223+
name: str,
224+
arguments: dict[str, Any] | None = None,
225+
read_timeout_seconds: float | None = None,
226+
progress_callback: ProgressFnT | None = None,
227+
*,
228+
input_responses: types.InputResponses | None = None,
229+
request_state: str | None = None,
230+
meta: types.RequestParamsMeta | None = None,
231+
allow_input_required: bool = False,
232+
) -> types.CallToolResult | types.InputRequiredResult:
233+
"""Executes a tool given its name and arguments.
234+
235+
Raises:
236+
RuntimeError: If the server returns an `InputRequiredResult` and
237+
``allow_input_required`` is ``False``.
238+
"""
203239
session = self._tool_to_session[name]
204240
session_tool_name = self.tools[name].name
205-
return await session.call_tool(
241+
result = await session.call_tool(
206242
session_tool_name,
207243
arguments=arguments,
208244
read_timeout_seconds=read_timeout_seconds,
209245
progress_callback=progress_callback,
246+
input_responses=input_responses,
247+
request_state=request_state,
210248
meta=meta,
211249
)
250+
if isinstance(result, types.InputRequiredResult) and not allow_input_required:
251+
# TODO(L80): replace this raise with the MRTR auto-loop driver (S6).
252+
raise RuntimeError(
253+
"Server returned InputRequiredResult; pass allow_input_required=True to receive it "
254+
"and retry call_tool(..., input_responses=..., request_state=result.request_state)."
255+
)
256+
return result
212257

213258
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
214259
"""Disconnects from a single MCP server."""

tests/client/test_http_unicode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ async def test_streamable_http_client_unicode_tool_call() -> None:
142142
# Test 2: Send Unicode text in tool call (client→server→client)
143143
for test_name, test_string in UNICODE_TEST_STRINGS.items():
144144
result = await session.call_tool("echo_unicode", arguments={"text": test_string})
145+
assert isinstance(result, types.CallToolResult)
145146

146147
# Verify server correctly received and echoed back Unicode
147148
assert len(result.content) == 1

0 commit comments

Comments
 (0)