Skip to content

Commit 3fa1324

Browse files
committed
feat(tools): support callable class instances and bound methods
Add support for using callable class instances and bound methods with beta_tool. This enables a more natural pattern for tools that need access to instance state or request-scoped context. Previously, users had to wrap their tools in closures to capture state: def setup_tools(user_id): @beta_tool def get_user(id: int) -> str: return fetch_user(user_id, id) return [get_user] Now they can use class instances directly: class GetUser: def __init__(self, user_id): self.user_id = user_id def __call__(self, id: int) -> str: return fetch_user(self.user_id, id) tool = beta_tool(GetUser(user_id), name="get_user") Changes: - Add _normalize_callable() to extract bound __call__ from instances - Update BaseFunctionTool to normalize callables before validation - Add unit tests for callable instances and bound methods - Add integration tests with tool_runner Closes anthropics#1087
1 parent 9b5ab24 commit 3fa1324

3 files changed

Lines changed: 187 additions & 4 deletions

File tree

src/anthropic/lib/tools/_beta_functions.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from abc import ABC, abstractmethod
55
from typing import Any, Union, Generic, TypeVar, Callable, Iterable, Coroutine, cast, overload
6-
from inspect import iscoroutinefunction
6+
from inspect import ismethod, isfunction, iscoroutinefunction
77
from typing_extensions import TypeAlias, override
88

99
import pydantic
@@ -23,6 +23,33 @@
2323

2424
BetaFunctionToolResultType: TypeAlias = Union[str, Iterable[BetaContent]]
2525

26+
27+
def _normalize_callable(func: Callable[..., Any]) -> Callable[..., Any]:
28+
"""Normalize a callable to a function that can be used with pydantic.validate_call.
29+
30+
If the callable is a class instance with a __call__ method (but not a function or method),
31+
this extracts the bound __call__ method. This allows callable class instances to be used
32+
as tools without requiring manual extraction of __call__.
33+
34+
Args:
35+
func: A function, method, or callable instance
36+
37+
Returns:
38+
A function or bound method suitable for use with pydantic.validate_call
39+
"""
40+
# If it's already a function or method, use it directly
41+
if isfunction(func) or ismethod(func):
42+
return func
43+
44+
# If it's a callable instance (class with __call__), extract the bound __call__ method
45+
if callable(func):
46+
call_method = func.__call__ # pyright: ignore[reportFunctionMemberAccess] # noqa: B004
47+
if ismethod(call_method):
48+
return call_method
49+
50+
return func
51+
52+
2653
Function = Callable[..., BetaFunctionToolResultType]
2754
FunctionT = TypeVar("FunctionT", bound=Function)
2855

@@ -83,9 +110,11 @@ def __init__(
83110
if _compat.PYDANTIC_V1:
84111
raise RuntimeError("Tool functions are only supported with Pydantic v2")
85112

86-
self.func = func
87-
self._func_with_validate = pydantic.validate_call(func)
88-
self.name = name or func.__name__
113+
# Normalize callable instances to their __call__ method
114+
normalized_func = _normalize_callable(func)
115+
self.func = cast(CallableT, normalized_func)
116+
self._func_with_validate = pydantic.validate_call(normalized_func)
117+
self.name = name or normalized_func.__name__
89118
self._defer_loading = defer_loading
90119

91120
self.description = description or self._get_description_from_docstring()

tests/lib/tools/test_functions.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,80 @@ def simple_add(a: int, b: int) -> str:
439439
assert function_tool.input_schema == expected_schema
440440

441441

442+
def test_callable_class_instance(self) -> None:
443+
"""Test that callable class instances can be used as tools."""
444+
445+
class FetchProduct:
446+
def __init__(self, ctx: dict[str, str]) -> None:
447+
self.ctx = ctx
448+
449+
def __call__(self, product_id: int) -> str:
450+
"""Fetch a product by ID."""
451+
return f"Product {product_id} from {self.ctx['session']}"
452+
453+
instance = FetchProduct({"session": "test-session"})
454+
tool = beta_tool(instance, name="fetch_product")
455+
456+
assert tool.name == "fetch_product"
457+
assert tool.description == "Fetch a product by ID."
458+
assert tool.call({"product_id": 123}) == "Product 123 from test-session"
459+
460+
# Check schema
461+
expected_schema = {
462+
"additionalProperties": False,
463+
"type": "object",
464+
"properties": {
465+
"product_id": {"title": "Product Id", "type": "integer"},
466+
},
467+
"required": ["product_id"],
468+
}
469+
assert tool.input_schema == expected_schema
470+
471+
def test_bound_method(self) -> None:
472+
"""Test that bound methods can be used as tools."""
473+
474+
class WeatherService:
475+
def __init__(self, api_key: str) -> None:
476+
self.api_key = api_key
477+
478+
def get_weather(self, location: str) -> str:
479+
"""Get weather for a location."""
480+
return f"Weather in {location} (using key: {self.api_key[:4]}...)"
481+
482+
service = WeatherService("secret-api-key")
483+
tool = beta_tool(service.get_weather, name="weather")
484+
485+
assert tool.name == "weather"
486+
assert tool.description == "Get weather for a location."
487+
assert tool.call({"location": "London"}) == "Weather in London (using key: secr...)"
488+
489+
# Check schema
490+
expected_schema = {
491+
"additionalProperties": False,
492+
"type": "object",
493+
"properties": {
494+
"location": {"title": "Location", "type": "string"},
495+
},
496+
"required": ["location"],
497+
}
498+
assert tool.input_schema == expected_schema
499+
500+
def test_callable_class_without_explicit_name(self) -> None:
501+
"""Test that callable class instances infer name from __call__ method."""
502+
503+
class MyTool:
504+
def __call__(self, x: int) -> str:
505+
"""Process x."""
506+
return str(x)
507+
508+
instance = MyTool()
509+
tool = beta_tool(instance)
510+
511+
# Should use __call__ as the name since that's the actual method
512+
assert tool.name == "__call__"
513+
assert tool.description == "Process x."
514+
515+
442516
def _get_parameters_info(fn: BaseFunctionTool[Any]) -> dict[str, str]:
443517
param_info: dict[str, str] = {}
444518
for param in fn._parsed_docstring.params:

tests/lib/tools/test_runners.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,86 @@ def tool_runner(client: Anthropic) -> BetaToolRunner[None]:
547547
respx_mock=respx_mock,
548548
)
549549

550+
@pytest.mark.respx(base_url=base_url)
551+
def test_callable_class_instance_tool(
552+
self, client: Anthropic, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
553+
) -> None:
554+
"""Test that callable class instances work with tool_runner."""
555+
556+
class WeatherTool:
557+
def __init__(self, default_units: str) -> None:
558+
self.default_units = default_units
559+
560+
def __call__(self, location: str, units: Literal["c", "f"] = "c") -> str:
561+
"""Lookup the weather for a given city in either celsius or fahrenheit
562+
563+
Args:
564+
location: The city and state, e.g. San Francisco, CA
565+
units: Unit for the output, either 'c' for celsius or 'f' for fahrenheit
566+
Returns:
567+
A dictionary containing the location, temperature, and weather condition.
568+
"""
569+
actual_units = units or self.default_units
570+
return json.dumps(_get_weather(location, actual_units))
571+
572+
weather_instance = WeatherTool(default_units="f")
573+
weather_tool = beta_tool(weather_instance, name="get_weather")
574+
575+
message = make_snapshot_request(
576+
lambda c: c.beta.messages.tool_runner(
577+
max_tokens=1024,
578+
model="claude-haiku-4-5",
579+
tools=[weather_tool],
580+
messages=[{"role": "user", "content": "What is the weather in SF?"}],
581+
).until_done(),
582+
content_snapshot=snapshots["basic"]["responses"],
583+
path="/v1/messages",
584+
mock_client=client,
585+
respx_mock=respx_mock,
586+
)
587+
588+
assert print_obj(message, monkeypatch) == snapshots["basic"]["result"]
589+
590+
@pytest.mark.respx(base_url=base_url)
591+
def test_bound_method_tool(
592+
self, client: Anthropic, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
593+
) -> None:
594+
"""Test that bound methods work with tool_runner."""
595+
596+
class WeatherService:
597+
def __init__(self, api_key: str) -> None:
598+
self.api_key = api_key
599+
600+
def get_weather(self, location: str, units: Literal["c", "f"]) -> str:
601+
"""Lookup the weather for a given city in either celsius or fahrenheit
602+
603+
Args:
604+
location: The city and state, e.g. San Francisco, CA
605+
units: Unit for the output, either 'c' for celsius or 'f' for fahrenheit
606+
Returns:
607+
A dictionary containing the location, temperature, and weather condition.
608+
"""
609+
# In a real scenario, self.api_key would be used
610+
return json.dumps(_get_weather(location, units))
611+
612+
service = WeatherService(api_key="secret-key")
613+
weather_tool = beta_tool(service.get_weather, name="get_weather")
614+
615+
message = make_snapshot_request(
616+
lambda c: c.beta.messages.tool_runner(
617+
max_tokens=1024,
618+
model="claude-haiku-4-5",
619+
tools=[weather_tool],
620+
messages=[{"role": "user", "content": "What is the weather in SF?"}],
621+
).until_done(),
622+
content_snapshot=snapshots["basic"]["responses"],
623+
path="/v1/messages",
624+
mock_client=client,
625+
respx_mock=respx_mock,
626+
)
627+
628+
assert print_obj(message, monkeypatch) == snapshots["basic"]["result"]
629+
550630

551631
@pytest.mark.skipif(PYDANTIC_V1, reason="tool runner not supported with pydantic v1")
552632
@pytest.mark.respx(base_url=base_url)

0 commit comments

Comments
 (0)