Skip to content

Commit 58238b1

Browse files
committed
Memoize built-in bound-method resolvers; stop mutating pre_validated
Review follow-ups on #2969: - _resolver_key now keys any bound method (pure-python or built-in) by its underlying function/name plus __self__ identity, so a built-in bound method (no __func__, fresh object each access) referenced twice still memoizes to one call. - call_fn_with_arg_validation copies the validated args before merging the injected kwargs, so a caller-provided pre_validated dict is never mutated. Add regression tests.
1 parent 37c038c commit 58238b1

4 files changed

Lines changed: 62 additions & 8 deletions

File tree

src/mcp/server/mcpserver/resolve.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,20 @@ def _wants_union(type_arg: Any) -> bool:
134134
def _resolver_key(fn: Callable[..., Any]) -> Hashable:
135135
"""Identity key for memoizing a resolver.
136136
137-
A bound method is recreated on each attribute access (`id(auth.login)` differs
138-
every time), so key it by `(id(__func__), id(__self__))` to keep `auth.login`
139-
referenced in two places memoized to one call. Everything else keys by `id`,
140-
so two distinct callables never collide even if they compare equal.
137+
A bound method - pure-python (`inspect.ismethod`) or built-in (e.g. `obj.meth`
138+
on a C-extension type) - is recreated on each attribute access, so `id(fn)`
139+
differs every time. Key it by its underlying function (or name) plus its
140+
`__self__` identity so `auth.login` referenced in two places memoizes to one
141+
call. Everything else keys by `id`, so two distinct callables never collide
142+
even if they compare equal.
141143
"""
142-
if inspect.ismethod(fn):
143-
return (id(fn.__func__), id(fn.__self__))
144+
bound_self = getattr(fn, "__self__", None)
145+
if bound_self is not None:
146+
# `__func__` (pure-python) has a stable identity; built-ins expose only a
147+
# stable `__name__`. Use the function's id or the name's value accordingly.
148+
func = getattr(fn, "__func__", None)
149+
underlying: Hashable = id(func) if func is not None else getattr(fn, "__name__", id(fn))
150+
return (underlying, id(bound_self))
144151
return id(fn)
145152

146153

src/mcp/server/mcpserver/utilities/func_metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ async def call_fn_with_arg_validation(
9191
validating twice can re-run `default_factory`/stateful validators and hand the
9292
function different values than a caller already observed.
9393
"""
94-
arguments_parsed_dict = (
94+
# Copy so a caller-provided `pre_validated` dict is never mutated in place.
95+
arguments_parsed_dict = dict(
9596
pre_validated if pre_validated is not None else self.validate_arguments(arguments_to_validate)
9697
)
9798

tests/server/mcpserver/test_func_metadata.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,28 @@ async def test_complex_function_runtime_arg_validation_with_json():
155155
assert result == "ok!"
156156

157157

158+
@pytest.mark.anyio
159+
async def test_call_fn_does_not_mutate_pre_validated():
160+
"""A caller-provided `pre_validated` dict must not be mutated by the call."""
161+
162+
def fn(x: int, ctx: str) -> str:
163+
return f"{x}:{ctx}"
164+
165+
meta = func_metadata(fn, skip_names=["ctx"])
166+
pre_validated = meta.validate_arguments({"x": 1})
167+
snapshot = dict(pre_validated)
168+
169+
result = await meta.call_fn_with_arg_validation(
170+
fn,
171+
fn_is_async=False,
172+
arguments_to_validate={"x": 1},
173+
arguments_to_pass_directly={"ctx": "injected"},
174+
pre_validated=pre_validated,
175+
)
176+
assert result == "1:injected"
177+
assert pre_validated == snapshot # `ctx` was not leaked into the caller's dict
178+
179+
158180
def test_str_vs_list_str():
159181
"""Test handling of string vs list[str] type annotations.
160182

tests/server/mcpserver/test_resolve.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Resolve,
1818
)
1919
from mcp.server.mcpserver.exceptions import InvalidSignature
20-
from mcp.server.mcpserver.resolve import find_resolved_parameters
20+
from mcp.server.mcpserver.resolve import _resolver_key, find_resolved_parameters
2121
from mcp.server.mcpserver.tools.base import Tool
2222
from mcp.types import ElicitRequestParams, ElicitResult, TextContent
2323

@@ -432,3 +432,27 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E
432432
async with Client(mcp, mode="legacy", elicitation_callback=never) as client:
433433
assert await _text(client, "run", {}) == "1:1"
434434
assert counter["n"] == 1
435+
436+
437+
def test_resolver_key_is_stable_for_methods_and_distinct_callables():
438+
class Service:
439+
def handler(self) -> None: ... # pragma: no cover
440+
441+
a, b = Service(), Service()
442+
443+
# Pure-python bound methods: stable across accesses, distinct per instance.
444+
assert _resolver_key(a.handler) == _resolver_key(a.handler)
445+
assert _resolver_key(a.handler) != _resolver_key(b.handler)
446+
447+
# Built-in bound methods (no `__func__`): fresh object each access, but the key
448+
# is stable and keyed to `__self__`.
449+
items: list[int] = []
450+
others: list[int] = []
451+
assert _resolver_key(items.append) == _resolver_key(items.append)
452+
assert _resolver_key(items.append) != _resolver_key(others.append)
453+
assert _resolver_key(items.append) != _resolver_key(items.pop)
454+
455+
# Plain functions key by identity.
456+
def fn() -> None: ... # pragma: no cover
457+
458+
assert _resolver_key(fn) == _resolver_key(fn)

0 commit comments

Comments
 (0)