Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions statemachine/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Tuple

BindCacheKey = Tuple[int, FrozenSet[str]]
BindTemplate = Tuple[Tuple[str, ...], Optional[str]] # noqa: UP007
BindTemplate = Tuple[Tuple[str, ...], Optional[str], Optional[str]] # noqa: UP007


def _make_key(method):
Expand Down Expand Up @@ -89,12 +89,25 @@ def _fast_bind(
kwargs: dict[str, Any],
template: BindTemplate,
) -> BoundArguments:
param_names, kwargs_param_name = template
param_names, kwargs_param_name, var_positional_name = template
arguments: dict[str, Any] = {}
past_var_positional = False

for i, name in enumerate(param_names):
if i < len(args):
arguments[name] = args[i]
if name == var_positional_name:
# Collect all remaining positional args into a tuple
arguments[name] = args[i:]
past_var_positional = True
elif past_var_positional:
# After *args, remaining params are keyword-only
arguments[name] = kwargs.get(name)
elif i < len(args):
# Match _full_bind: if param is also in kwargs, kwargs wins
# (POSITIONAL_OR_KEYWORD params prefer kwargs over positional args)
if name in kwargs:
arguments[name] = kwargs[name]
else:
arguments[name] = args[i]
else:
arguments[name] = kwargs.get(name)

Expand Down Expand Up @@ -124,6 +137,7 @@ def _full_bind( # noqa: C901
parameters_ex: Any = ()
kwargs_param = None
kwargs_param_name: str | None = None
var_positional_name: str | None = None

while True:
# Let's iterate through the positional arguments and corresponding
Expand Down Expand Up @@ -192,6 +206,7 @@ def _full_bind( # noqa: C901
values.extend(arg_vals)
arguments[param.name] = tuple(values)
param_names_used.append(param.name)
var_positional_name = param.name
break

if param.name in kwargs and param.kind != Parameter.POSITIONAL_ONLY:
Expand Down Expand Up @@ -236,7 +251,7 @@ def _full_bind( # noqa: C901
# 'ignoring we got an unexpected keyword argument'
pass

template: BindTemplate = (tuple(param_names_used), kwargs_param_name)
template: BindTemplate = (tuple(param_names_used), kwargs_param_name, var_positional_name)
self._bind_cache[cache_key] = template

return BoundArguments(self, arguments) # type: ignore[arg-type]
95 changes: 95 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial

import pytest

from statemachine.dispatcher import callable_method
from statemachine.signature import SignatureAdapter

Expand Down Expand Up @@ -196,3 +197,97 @@ def test_kwargs_only_receives_unmatched_keys_with_positional(self):

result2 = wrapped("X", target="Y")
assert result2 == ("X", {"target": "Y"})

def test_var_positional_collected_as_tuple(self):
"""VAR_POSITIONAL (*args) must be collected into a tuple on cache hit."""

def fn(*args, **kwargs):
return args, kwargs

wrapped = callable_method(fn)

result1 = wrapped(1, 2, 3, key="val")
assert result1 == ((1, 2, 3), {"key": "val"})

result2 = wrapped(4, 5, key="other")
assert result2 == ((4, 5), {"key": "other"})

def test_keyword_only_after_var_positional(self):
"""KEYWORD_ONLY params after *args must be extracted from kwargs on cache hit."""

def fn(*args, event, **kwargs):
return args, event, kwargs

wrapped = callable_method(fn)

result1 = wrapped(100, event="ev1", source="s0")
assert result1 == ((100,), "ev1", {"source": "s0"})

result2 = wrapped(200, event="ev2", source="s1")
assert result2 == ((200,), "ev2", {"source": "s1"})

def test_positional_or_keyword_prefers_kwargs_over_positional(self):
"""When a POSITIONAL_OR_KEYWORD param is in both args and kwargs, kwargs wins."""

def fn(event, source, target):
return event, source, target

wrapped = callable_method(fn)

# 1st call: positional arg provided but 'event' also in kwargs -> kwargs wins
result1 = wrapped("discarded_content", event="ev1", source="s0", target="t0")
assert result1 == ("ev1", "s0", "t0")

# 2nd call: cache hit, same behavior expected
result2 = wrapped("other_content", event="ev2", source="s1", target="t1")
assert result2 == ("ev2", "s1", "t1")

def test_empty_var_positional(self):
"""Empty *args is handled correctly on cache hit."""

def fn(*args, **kwargs):
return args, kwargs

wrapped = callable_method(fn)

# 1st call with args
result1 = wrapped(1, key="val")
assert result1 == ((1,), {"key": "val"})

# 2nd call: only kwargs, no positional args — different cache key (len=0)
result2 = wrapped(key="val2")
assert result2 == ((), {"key": "val2"})

# 3rd call: hits cache for len=0
result3 = wrapped(key="val3")
assert result3 == ((), {"key": "val3"})

def test_named_params_before_var_positional(self):
"""Named params before *args are filled correctly on cache hit."""

def fn(a, b, *args, **kwargs):
return a, b, args, kwargs

wrapped = callable_method(fn)

result1 = wrapped(1, 2, 3, 4, key="val")
assert result1 == (1, 2, (3, 4), {"key": "val"})

result2 = wrapped(10, 20, 30, key="val2")
assert result2 == (10, 20, (30,), {"key": "val2"})

def test_kwargs_wins_with_var_positional_present(self):
"""POSITIONAL_OR_KEYWORD before *args prefers kwargs when name matches."""

def fn(event, *args, **kwargs):
return event, args, kwargs

wrapped = callable_method(fn)

# 1st call: 'event' in both positional and kwargs — kwargs wins
result1 = wrapped("discarded", "extra", event="ev1", key="a")
assert result1 == ("ev1", ("extra",), {"key": "a"})

# 2nd call: cache hit, same behavior
result2 = wrapped("other", "more", event="ev2", key="b")
assert result2 == ("ev2", ("more",), {"key": "b"})
Loading