diff --git a/docs/authors.md b/docs/authors.md index f225ad1a..8f3a5754 100644 --- a/docs/authors.md +++ b/docs/authors.md @@ -10,6 +10,7 @@ * [Rafael Rêgo](mailto:crafards@gmail.com) * [Raphael Schrader](mailto:raphael@schradercloud.de) * [João S. O. Bueno](mailto:gwidion@gmail.com) +* [Rodrigo Nogueira](mailto:rodrigo.b.nogueira@gmail.com) ## Scaffolding diff --git a/statemachine/signature.py b/statemachine/signature.py index 28e407be..254fee52 100644 --- a/statemachine/signature.py +++ b/statemachine/signature.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from functools import partial from inspect import BoundArguments from inspect import Parameter @@ -6,6 +8,12 @@ from itertools import chain from types import MethodType from typing import Any +from typing import FrozenSet +from typing import Optional +from typing import Tuple + +BindCacheKey = Tuple[int, FrozenSet[str]] +BindTemplate = Tuple[Tuple[str, ...], Optional[str]] # noqa: UP007 def _make_key(method): @@ -44,6 +52,11 @@ def cached_function(cls, method): class SignatureAdapter(Signature): is_coroutine: bool = False + _bind_cache: dict[BindCacheKey, BindTemplate] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._bind_cache = {} @classmethod @signature_cache @@ -60,19 +73,57 @@ def from_callable(cls, method): adapter.is_coroutine = iscoroutinefunction(method) return adapter - def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C901 + def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: + cache_key: BindCacheKey = (len(args), frozenset(kwargs.keys())) + template = self._bind_cache.get(cache_key) + + if template is not None: + return self._fast_bind(args, kwargs, template) + + result = self._full_bind(cache_key, *args, **kwargs) + return result + + def _fast_bind( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any], + template: BindTemplate, + ) -> BoundArguments: + param_names, kwargs_param_name = template + arguments: dict[str, Any] = {} + + for i, name in enumerate(param_names): + if i < len(args): + arguments[name] = args[i] + else: + arguments[name] = kwargs.get(name) + + if kwargs_param_name is not None: + matched = set(param_names) + arguments[kwargs_param_name] = {k: v for k, v in kwargs.items() if k not in matched} + + return BoundArguments(self, arguments) # type: ignore[arg-type] + + def _full_bind( # noqa: C901 + self, + cache_key: BindCacheKey, + *args: Any, + **kwargs: Any, + ) -> BoundArguments: """Get a BoundArguments object, that maps the passed `args` and `kwargs` to the function's signature. It avoids to raise `TypeError` trying to fill all the required arguments and ignoring the unknown ones. Adapted from the internal `inspect.Signature._bind`. """ - arguments = {} + arguments: dict[str, Any] = {} + param_names_used: list[str] = [] parameters = iter(self.parameters.values()) arg_vals = iter(args) parameters_ex: Any = () kwargs_param = None + kwargs_param_name: str | None = None while True: # Let's iterate through the positional arguments and corresponding @@ -95,8 +146,7 @@ def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C elif param.name in kwargs: if param.kind == Parameter.POSITIONAL_ONLY: msg = ( - "{arg!r} parameter is positional only, " - "but was passed as a keyword" + "{arg!r} parameter is positional only, but was passed as a keyword" ) msg = msg.format(arg=param.name) raise TypeError(msg) from None @@ -141,12 +191,14 @@ def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C values = [arg_val] values.extend(arg_vals) arguments[param.name] = tuple(values) + param_names_used.append(param.name) break if param.name in kwargs and param.kind != Parameter.POSITIONAL_ONLY: arguments[param.name] = kwargs.pop(param.name) else: arguments[param.name] = arg_val + param_names_used.append(param.name) # Now, we iterate through the remaining parameters to process # keyword arguments @@ -172,14 +224,19 @@ def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C # arguments. pass else: - arguments[param_name] = arg_val # + arguments[param_name] = arg_val + param_names_used.append(param_name) if kwargs: if kwargs_param is not None: # Process our '**kwargs'-like parameter - arguments[kwargs_param.name] = kwargs # type: ignore [assignment] + arguments[kwargs_param.name] = kwargs # type: ignore[assignment] + kwargs_param_name = kwargs_param.name else: # 'ignoring we got an unexpected keyword argument' pass - return BoundArguments(self, arguments) # type: ignore [arg-type] + template: BindTemplate = (tuple(param_names_used), kwargs_param_name) + self._bind_cache[cache_key] = template + + return BoundArguments(self, arguments) # type: ignore[arg-type] diff --git a/tests/test_signature.py b/tests/test_signature.py index a36a9d59..f8d37fbf 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -2,8 +2,8 @@ from functools import partial import pytest - from statemachine.dispatcher import callable_method +from statemachine.signature import SignatureAdapter def single_positional_param(a): @@ -162,3 +162,37 @@ def test_support_for_partial(self): assert wrapped_func("A", "B") == ("A", "B", "activated") assert wrapped_func.__name__ == positional_and_kw_arguments.__name__ + + +def named_and_kwargs(source, **kwargs): + return source, kwargs + + +class TestCachedBindExpected: + """Tests that exercise the cache fast-path by calling the same + wrapped function twice with the same argument shape.""" + + def setup_method(self): + SignatureAdapter.from_callable.clear_cache() + + def test_named_param_not_leaked_into_kwargs(self): + """Named params should not appear in the **kwargs dict on cache hit.""" + wrapped = callable_method(named_and_kwargs) + + # 1st call: cache miss -> _full_bind + result1 = wrapped(source="A", target="B", event="go") + assert result1 == ("A", {"target": "B", "event": "go"}) + + # 2nd call: cache hit -> _fast_bind + result2 = wrapped(source="X", target="Y", event="stop") + assert result2 == ("X", {"target": "Y", "event": "stop"}) + + def test_kwargs_only_receives_unmatched_keys_with_positional(self): + """When mixing positional and keyword args with **kwargs.""" + wrapped = callable_method(named_and_kwargs) + + result1 = wrapped("A", target="B") + assert result1 == ("A", {"target": "B"}) + + result2 = wrapped("X", target="Y") + assert result2 == ("X", {"target": "Y"})