Skip to content

Commit 37c038c

Browse files
committed
Validate resolver tool args once; key resolvers by method identity
Review follow-ups on #2969: - Tool.run validated arguments twice when resolvers were present (once to feed resolvers, once in call_fn_with_arg_validation). A field with default_factory or a stateful validator could hand a by-name resolver a different value than the tool body. Validate once and pass it through via a new pre_validated argument so both observe the same value. - Key the resolver cache/plans by (id(__func__), id(__self__)) for bound methods and id(fn) otherwise, instead of the callable's equality, so two distinct callables that compare equal can no longer share a plan/cache entry while bound-method memoization still works. Add regression tests.
1 parent aac86dc commit 37c038c

4 files changed

Lines changed: 55 additions & 13 deletions

File tree

src/mcp/server/mcpserver/resolve.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,16 @@ def _wants_union(type_arg: Any) -> bool:
132132

133133

134134
def _resolver_key(fn: Callable[..., Any]) -> Hashable:
135-
"""Stable, equality-based key for memoizing a resolver.
135+
"""Identity key for memoizing a resolver.
136136
137-
Bound methods are recreated on each attribute access (`id(auth.login)` differs
138-
every time) but hash/compare by `(__func__, __self__)`, so the callable itself
139-
is the right key. Falls back to `id` only for the rare unhashable callable.
137+
A bound method is recreated on each attribute access (`id(auth.login)` differs
138+
every time), so key it by `(id(__func__), id(__self__))` to keep `auth.login`
139+
referenced in two places memoized to one call. Everything else keys by `id`,
140+
so two distinct callables never collide even if they compare equal.
140141
"""
141-
try:
142-
hash(fn)
143-
except TypeError: # pragma: no cover - unhashable callables are pathological
144-
return id(fn)
145-
return fn
142+
if inspect.ismethod(fn):
143+
return (id(fn.__func__), id(fn.__self__))
144+
return id(fn)
146145

147146

148147
def build_resolver_plans(

src/mcp/server/mcpserver/tools/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,23 @@ async def run(
128128
pass_directly: dict[str, Any] = {}
129129
if self.context_kwarg is not None:
130130
pass_directly[self.context_kwarg] = context
131+
132+
# Resolvers see the same validated arguments the tool body receives:
133+
# validate once and reuse it, so a `default_factory`/stateful validator
134+
# can't hand a by-name resolver a different value than the body.
135+
pre_validated: dict[str, Any] | None = None
131136
if self.resolved_params:
132-
tool_args = self.fn_metadata.validate_arguments(arguments)
133-
pass_directly |= await resolve_arguments(self.resolved_params, self.resolver_plans, tool_args, context)
137+
pre_validated = self.fn_metadata.validate_arguments(arguments)
138+
pass_directly |= await resolve_arguments(
139+
self.resolved_params, self.resolver_plans, pre_validated, context
140+
)
134141

135142
result = await self.fn_metadata.call_fn_with_arg_validation(
136143
self.fn,
137144
self.is_async,
138145
arguments,
139146
pass_directly or None,
147+
pre_validated=pre_validated,
140148
)
141149

142150
if convert_result:

src/mcp/server/mcpserver/utilities/func_metadata.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,19 @@ async def call_fn_with_arg_validation(
8181
fn_is_async: bool,
8282
arguments_to_validate: dict[str, Any],
8383
arguments_to_pass_directly: dict[str, Any] | None,
84+
pre_validated: dict[str, Any] | None = None,
8485
) -> Any:
8586
"""Call the given function with arguments validated and injected.
8687
8788
Arguments are first attempted to be parsed from JSON, then validated against
88-
the argument model, before being passed to the function.
89+
the argument model, before being passed to the function. Pass `pre_validated`
90+
(the output of `validate_arguments`) to reuse an earlier validation pass -
91+
validating twice can re-run `default_factory`/stateful validators and hand the
92+
function different values than a caller already observed.
8993
"""
90-
arguments_parsed_dict = self.validate_arguments(arguments_to_validate)
94+
arguments_parsed_dict = (
95+
pre_validated if pre_validated is not None else self.validate_arguments(arguments_to_validate)
96+
)
9197

9298
arguments_parsed_dict |= arguments_to_pass_directly or {}
9399

tests/server/mcpserver/test_resolve.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,32 @@ async def tool(value: Annotated[Login, Resolve(service.a)]) -> str:
403403

404404
with pytest.raises(InvalidSignature, match="cyclic"):
405405
Tool.from_function(tool)
406+
407+
408+
@pytest.mark.anyio
409+
async def test_resolver_and_body_see_the_same_validated_default():
410+
mcp = MCPServer(name="DefaultFactory")
411+
counter = {"n": 0}
412+
413+
def next_id() -> int:
414+
counter["n"] += 1
415+
return counter["n"]
416+
417+
# A by-name resolver and the tool body must observe one validation pass, so the
418+
# `default_factory` runs once and both see the same generated value.
419+
async def echo_id(request_id: int) -> int:
420+
return request_id
421+
422+
@mcp.tool()
423+
async def run(
424+
request_id: Annotated[int, Field(default_factory=next_id)],
425+
resolved_id: Annotated[int, Resolve(echo_id)],
426+
) -> str:
427+
return f"{request_id}:{resolved_id}"
428+
429+
async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover
430+
raise AssertionError("should not elicit")
431+
432+
async with Client(mcp, mode="legacy", elicitation_callback=never) as client:
433+
assert await _text(client, "run", {}) == "1:1"
434+
assert counter["n"] == 1

0 commit comments

Comments
 (0)