From d5f134d6f0291d22bafac84f9fd8feda223a8775 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Fri, 13 Feb 2026 13:04:34 -0300 Subject: [PATCH 1/5] fix: re-enqueue initial event when deserializing async state machine (#544) When an async SM is pickled/deepcopied (e.g. via multiprocessing), the engine queue is not preserved. __setstate__ recreated the engine but never called start(), so the __initial__ event was never enqueued and activate_initial_state() would fail with InvalidStateValue. Closes #544 --- statemachine/statemachine.py | 1 + tests/examples/user_machine.py | 5 ++-- tests/test_copy.py | 53 ++++++++++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index ac402159..e1e2a1e7 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -147,6 +147,7 @@ def __setstate__(self, state): self._register_callbacks([]) self.add_listener(*listeners.keys()) self._engine = self._get_engine(rtc) + self._engine.start() def _get_initial_state(self): initial_state_value = self.start_value if self.start_value else self.initial_state.value diff --git a/tests/examples/user_machine.py b/tests/examples/user_machine.py index a9fcb193..a3aa9fb3 100644 --- a/tests/examples/user_machine.py +++ b/tests/examples/user_machine.py @@ -14,10 +14,9 @@ from dataclasses import dataclass from enum import Enum -from statemachine.states import States - from statemachine import State from statemachine import StateMachine +from statemachine.states import States class UserStatus(str, Enum): @@ -88,7 +87,7 @@ class UserStatusMachine(StateMachine): def on_signup(self, token: str): if token == "": raise ValueError("Token is required") - self.model.verified = True + self.model.verified = True # type: ignore[union-attr] class UserExperienceMachine(StateMachine): diff --git a/tests/test_copy.py b/tests/test_copy.py index 15e2c358..54292505 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,3 +1,4 @@ +import asyncio import logging import pickle from copy import deepcopy @@ -5,11 +6,11 @@ from enum import auto import pytest -from statemachine.exceptions import TransitionNotAllowed -from statemachine.states import States from statemachine import State from statemachine import StateMachine +from statemachine.exceptions import TransitionNotAllowed +from statemachine.states import States logger = logging.getLogger(__name__) DEBUG = logging.DEBUG @@ -181,3 +182,51 @@ def test_copy_with_custom_init_and_vars(copy_method): assert sm2.custom == 1 assert sm2.value == [1, 2, 3] assert sm2.current_state == MyStateMachine.started + + +class AsyncTrafficLightMachine(StateMachine): + green = State(initial=True) + yellow = State() + red = State() + + cycle = green.to(yellow) | yellow.to(red) | red.to(green) + + async def on_enter_state(self, target): + pass + + +def test_copy_async_statemachine_before_activation(copy_method): + """Regression test for issue #544: async SM fails after pickle/deepcopy. + + When an async SM is copied before activation, the copy must still be + activatable because ``__setstate__`` re-enqueues the ``__initial__`` event. + """ + sm = AsyncTrafficLightMachine() + sm_copy = copy_method(sm) + + async def verify(): + await sm_copy.activate_initial_state() + assert sm_copy.current_state == AsyncTrafficLightMachine.green + await sm_copy.cycle() + assert sm_copy.current_state == AsyncTrafficLightMachine.yellow + + asyncio.run(verify()) + + +def test_copy_async_statemachine_after_activation(copy_method): + """Copying an async SM that is already activated preserves its current state.""" + + async def setup_and_verify(): + sm = AsyncTrafficLightMachine() + await sm.activate_initial_state() + await sm.cycle() + assert sm.current_state == AsyncTrafficLightMachine.yellow + + sm_copy = copy_method(sm) + + await sm_copy.activate_initial_state() + assert sm_copy.current_state == AsyncTrafficLightMachine.yellow + await sm_copy.cycle() + assert sm_copy.current_state == AsyncTrafficLightMachine.red + + asyncio.run(setup_and_verify()) From 95ac619fa41f80b2ba6afdc796c86879c3062ced Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Fri, 13 Feb 2026 13:18:21 -0300 Subject: [PATCH 2/5] fix: await async predicates in condition expressions (#535) The boolean expression combinators (custom_not, custom_and, custom_or, build_custom_operator) called predicates synchronously. When predicates were async, they returned unawaited coroutine objects which are always truthy, causing `not` to always return False, `and` to skip evaluation, and `or` to short-circuit incorrectly. Each combinator now checks `isawaitable()` on predicate results and returns a coroutine when needed, which CallbackWrapper.__call__ already knows how to await. Closes #535 --- statemachine/spec_parser.py | 70 ++++++++++++++++++++++++---- tests/test_async.py | 82 ++++++++++++++++++++++++++++++++- tests/test_spec_parser.py | 92 +++++++++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 9 deletions(-) diff --git a/statemachine/spec_parser.py b/statemachine/spec_parser.py index 7596c083..1899d0a4 100644 --- a/statemachine/spec_parser.py +++ b/statemachine/spec_parser.py @@ -2,6 +2,7 @@ import operator import re from functools import reduce +from inspect import isawaitable from typing import Callable replacements = {"!": "not ", "^": " and ", "v": " or "} @@ -33,8 +34,15 @@ def match_func(match): def custom_not(predicate: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return not predicate(*args, **kwargs) + def decorated(*args, **kwargs): + result = predicate(*args, **kwargs) + if isawaitable(result): + + async def _negate(): + return not await result + + return _negate() + return not result decorated.__name__ = f"not({predicate.__name__})" unique_key = getattr(predicate, "unique_key", "") @@ -43,8 +51,26 @@ def decorated(*args, **kwargs) -> bool: def custom_and(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return] + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + if isawaitable(left_result): + + async def _async_and(): + lr = await left_result + if not lr: + return lr + rr = right(*args, **kwargs) + if isawaitable(rr): + return await rr + return rr + + return _async_and() + if not left_result: + return left_result + right_result = right(*args, **kwargs) + if isawaitable(right_result): + return right_result + return right_result decorated.__name__ = f"({left.__name__} and {right.__name__})" decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined] @@ -52,8 +78,26 @@ def decorated(*args, **kwargs) -> bool: def custom_or(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return] + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + if isawaitable(left_result): + + async def _async_or(): + lr = await left_result + if lr: + return lr + rr = right(*args, **kwargs) + if isawaitable(rr): + return await rr + return rr + + return _async_or() + if left_result: + return left_result + right_result = right(*args, **kwargs) + if isawaitable(right_result): + return right_result + return right_result decorated.__name__ = f"({left.__name__} or {right.__name__})" decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined] @@ -73,8 +117,18 @@ def build_custom_operator(operator) -> Callable: operator_repr = comparison_repr[operator] def custom_comparator(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return bool(operator(left(*args, **kwargs), right(*args, **kwargs))) + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + right_result = right(*args, **kwargs) + if isawaitable(left_result) or isawaitable(right_result): + + async def _async_compare(): + lr = (await left_result) if isawaitable(left_result) else left_result + rr = (await right_result) if isawaitable(right_result) else right_result + return bool(operator(lr, rr)) + + return _async_compare() + return bool(operator(left_result, right_result)) decorated.__name__ = f"({left.__name__} {operator_repr} {right.__name__})" decorated.unique_key = _unique_key(left, right, operator_repr) # type: ignore[attr-defined] diff --git a/tests/test_async.py b/tests/test_async.py index 36af9126..a6a85ff7 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,10 +1,10 @@ import re import pytest -from statemachine.exceptions import InvalidStateValue from statemachine import State from statemachine import StateMachine +from statemachine.exceptions import InvalidStateValue @pytest.fixture() @@ -96,6 +96,86 @@ def test_async_state_from_sync_context(async_order_control_machine): assert sm.completed.is_active +class AsyncConditionExpressionMachine(StateMachine): + """Regression test for issue #535: async conditions in boolean expressions.""" + + s1 = State(initial=True) + + go_not = s1.to.itself(cond="not cond_false") + go_and = s1.to.itself(cond="cond_true and cond_true") + go_or_false_first = s1.to.itself(cond="cond_false or cond_true") + go_or_true_first = s1.to.itself(cond="cond_true or cond_false") + go_blocked = s1.to.itself(cond="not cond_true") + go_and_blocked = s1.to.itself(cond="cond_true and cond_false") + go_or_both_false = s1.to.itself(cond="cond_false or cond_false") + + async def cond_true(self): + return True + + async def cond_false(self): + return False + + async def on_enter_state(self, target): + pass + + +async def test_async_condition_not(recwarn): + """Issue #535: 'not cond_false' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_not() + assert sm.s1.is_active + assert not any("coroutine" in str(w.message) for w in recwarn.list) + + +async def test_async_condition_not_blocked(): + """Issue #535: 'not cond_true' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_blocked() + + +async def test_async_condition_and(): + """Issue #535: 'cond_true and cond_true' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_and() + assert sm.s1.is_active + + +async def test_async_condition_and_blocked(): + """Issue #535: 'cond_true and cond_false' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_and_blocked() + + +async def test_async_condition_or_false_first(): + """Issue #535: 'cond_false or cond_true' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_or_false_first() + assert sm.s1.is_active + + +async def test_async_condition_or_true_first(): + """'cond_true or cond_false' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_or_true_first() + assert sm.s1.is_active + + +async def test_async_condition_or_both_false(): + """'cond_false or cond_false' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_or_both_false() + + async def test_async_state_should_be_initialized(async_order_control_machine): """Test that the state machine is initialized before any event is triggered diff --git a/tests/test_spec_parser.py b/tests/test_spec_parser.py index ace3a4b1..8fa1aae4 100644 --- a/tests/test_spec_parser.py +++ b/tests/test_spec_parser.py @@ -1,6 +1,8 @@ +import asyncio import logging import pytest + from statemachine.spec_parser import operator_mapping from statemachine.spec_parser import parse_boolean_expr @@ -247,6 +249,96 @@ def variable_hook(var_name): ("height > 1 and height < 2", True, ["height"]), ], ) +def async_variable_hook(var_name): + """Variable hook that returns async callables, for testing issue #535.""" + values = { + "cond_true": True, + "cond_false": False, + "val_10": 10, + "val_20": 20, + } + + async def decorated(*args, **kwargs): + return values.get(var_name, False) + + decorated.__name__ = var_name + return decorated + + +@pytest.mark.parametrize( + ("expression", "expected"), + [ + ("not cond_false", True), + ("not cond_true", False), + ("cond_true and cond_true", True), + ("cond_true and cond_false", False), + ("cond_false and cond_true", False), + ("cond_false or cond_true", True), + ("cond_true or cond_false", True), + ("cond_false or cond_false", False), + ("not cond_false and cond_true", True), + ("not (cond_true and cond_false)", True), + ("not (cond_false or cond_false)", True), + ("cond_true and not cond_false", True), + ("val_10 == 10", True), + ("val_10 != 20", True), + ("val_10 < val_20", True), + ("val_20 > val_10", True), + ("val_10 >= 10", True), + ("val_10 <= val_20", True), + ], +) +def test_async_expressions(expression, expected): + """Issue #535: condition expressions with async predicates must await results.""" + parsed_expr = parse_boolean_expr(expression, async_variable_hook, operator_mapping) + result = parsed_expr() + assert asyncio.iscoroutine(result), f"Expected coroutine for async expression: {expression}" + assert asyncio.run(result) is expected, expression + + +def mixed_variable_hook(var_name): + """Variable hook where some vars are sync and some are async.""" + sync_values = {"sync_true": True, "sync_false": False, "sync_10": 10} + async_values = {"async_true": True, "async_false": False, "async_20": 20} + + if var_name in async_values: + + async def async_decorated(*args, **kwargs): + return async_values[var_name] + + async_decorated.__name__ = var_name + return async_decorated + + def sync_decorated(*args, **kwargs): + return sync_values.get(var_name, False) + + sync_decorated.__name__ = var_name + return sync_decorated + + +@pytest.mark.parametrize( + ("expression", "expected"), + [ + # async left, sync right + ("async_true and sync_true", True), + ("async_false or sync_true", True), + # sync left, async right + ("sync_true and async_true", True), + ("sync_false or async_true", True), + ("sync_true and async_false", False), + ("sync_false or async_false", False), + ], +) +def test_mixed_sync_async_expressions(expression, expected): + """Expressions mixing sync and async predicates must handle both correctly.""" + parsed_expr = parse_boolean_expr(expression, mixed_variable_hook, operator_mapping) + result = parsed_expr() + if asyncio.iscoroutine(result): + assert asyncio.run(result) is expected, expression + else: + assert result is expected, expression + + @pytest.mark.xfail(reason="TODO: Optimize so that expressios are evaluated only once") def test_should_evaluate_values_only_once(expression, expected, caplog, hooks_called): caplog.set_level(logging.DEBUG, logger="tests") From affcf67546b902cf8c341dfeb9dae5e348da1c00 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Fri, 13 Feb 2026 13:26:58 -0300 Subject: [PATCH 3/5] chore: sync pre-commit ruff rev with lockfile (v0.15.0) The pre-commit hook was using ruff v0.8.1 while the lockfile had v0.15.0, causing import sorting differences between local and CI. --- .pre-commit-config.yaml | 2 +- tests/examples/user_machine.py | 3 ++- tests/test_async.py | 2 +- tests/test_copy.py | 4 ++-- tests/test_signature.py | 1 - tests/test_spec_parser.py | 1 - 6 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 378b6ebd..b39fb525 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: exclude: docs/auto_examples - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: v0.8.1 + rev: v0.15.0 hooks: # Run the linter. - id: ruff diff --git a/tests/examples/user_machine.py b/tests/examples/user_machine.py index a3aa9fb3..ad0320a8 100644 --- a/tests/examples/user_machine.py +++ b/tests/examples/user_machine.py @@ -14,9 +14,10 @@ from dataclasses import dataclass from enum import Enum +from statemachine.states import States + from statemachine import State from statemachine import StateMachine -from statemachine.states import States class UserStatus(str, Enum): diff --git a/tests/test_async.py b/tests/test_async.py index a6a85ff7..93c2205b 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,10 +1,10 @@ import re import pytest +from statemachine.exceptions import InvalidStateValue from statemachine import State from statemachine import StateMachine -from statemachine.exceptions import InvalidStateValue @pytest.fixture() diff --git a/tests/test_copy.py b/tests/test_copy.py index 54292505..9ccd5408 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -6,11 +6,11 @@ from enum import auto import pytest +from statemachine.exceptions import TransitionNotAllowed +from statemachine.states import States from statemachine import State from statemachine import StateMachine -from statemachine.exceptions import TransitionNotAllowed -from statemachine.states import States logger = logging.getLogger(__name__) DEBUG = logging.DEBUG diff --git a/tests/test_signature.py b/tests/test_signature.py index 1cb68673..ccf30232 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -2,7 +2,6 @@ from functools import partial import pytest - from statemachine.dispatcher import callable_method from statemachine.signature import SignatureAdapter diff --git a/tests/test_spec_parser.py b/tests/test_spec_parser.py index 8fa1aae4..056a647a 100644 --- a/tests/test_spec_parser.py +++ b/tests/test_spec_parser.py @@ -2,7 +2,6 @@ import logging import pytest - from statemachine.spec_parser import operator_mapping from statemachine.spec_parser import parse_boolean_expr From f428f9fdea4469fea718e06cbae22f1707393187 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Fri, 13 Feb 2026 13:30:26 -0300 Subject: [PATCH 4/5] fix: address SonarCloud code smells in tests - Add docstrings to empty async on_enter_state methods (S1186) - Use await asyncio.sleep(0) in async test hooks to satisfy S7503 --- tests/test_async.py | 2 +- tests/test_copy.py | 2 +- tests/test_spec_parser.py | 9 +++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_async.py b/tests/test_async.py index 93c2205b..48267c08 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -116,7 +116,7 @@ async def cond_false(self): return False async def on_enter_state(self, target): - pass + """Async callback to ensure the SM uses AsyncEngine.""" async def test_async_condition_not(recwarn): diff --git a/tests/test_copy.py b/tests/test_copy.py index 9ccd5408..b2af2819 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -192,7 +192,7 @@ class AsyncTrafficLightMachine(StateMachine): cycle = green.to(yellow) | yellow.to(red) | red.to(green) async def on_enter_state(self, target): - pass + """Async callback to ensure the SM uses AsyncEngine.""" def test_copy_async_statemachine_before_activation(copy_method): diff --git a/tests/test_spec_parser.py b/tests/test_spec_parser.py index 056a647a..569090d9 100644 --- a/tests/test_spec_parser.py +++ b/tests/test_spec_parser.py @@ -257,8 +257,11 @@ def async_variable_hook(var_name): "val_20": 20, } + value = values.get(var_name, False) + async def decorated(*args, **kwargs): - return values.get(var_name, False) + await asyncio.sleep(0) + return value decorated.__name__ = var_name return decorated @@ -301,9 +304,11 @@ def mixed_variable_hook(var_name): async_values = {"async_true": True, "async_false": False, "async_20": 20} if var_name in async_values: + value = async_values[var_name] async def async_decorated(*args, **kwargs): - return async_values[var_name] + await asyncio.sleep(0) + return value async_decorated.__name__ = var_name return async_decorated From cec1b9afe4de2d10039b2b7ac72a1263b5a38d84 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Fri, 13 Feb 2026 15:23:18 -0300 Subject: [PATCH 5/5] feat: add `enabled_events()` method to check guard conditions (#520) `allowed_events` returns events reachable from the current state but does not evaluate `cond`/`unless` guards. The new `enabled_events()` method evaluates conditions and returns only events that can actually fire. It accepts `*args`/`**kwargs` forwarded to condition callbacks, works with both sync and async engines, and treats condition exceptions as enabled (permissive behavior). Closes #520 --- docs/guards.md | 75 ++++++++++++++++++- statemachine/engines/async_.py | 28 +++++++ statemachine/engines/sync.py | 26 +++++++ statemachine/statemachine.py | 18 +++++ tests/test_async.py | 78 ++++++++++++++++++++ tests/test_statemachine.py | 131 +++++++++++++++++++++++++++++++++ 6 files changed, 355 insertions(+), 1 deletion(-) diff --git a/docs/guards.md b/docs/guards.md index 65eaf9e6..84140657 100644 --- a/docs/guards.md +++ b/docs/guards.md @@ -47,7 +47,7 @@ To control the evaluation order, declare transitions in the desired order: ```python # Declare in the order you want them checked: first = state_a.to(state_b, cond="check1") # Checked FIRST -second = state_a.to(state_c, cond="check2") # Checked SECOND +second = state_a.to(state_c, cond="check2") # Checked SECOND third = state_a.to(state_d, cond="check3") # Checked THIRD my_event = first | second | third # Order matches declaration @@ -159,6 +159,79 @@ So, a condition `s1.to(s2, cond=lambda: [])` will evaluate as `False`, as an emp **falsy** value. ``` +### Checking enabled events + +The {ref}`StateMachine.allowed_events` property returns events reachable from the current state, +but it does **not** evaluate `cond`/`unless` guards. To check which events actually have their +conditions satisfied, use {ref}`StateMachine.enabled_events`. + +```{testsetup} + +>>> from statemachine import StateMachine, State + +``` + +```py +>>> class ApprovalMachine(StateMachine): +... pending = State(initial=True) +... approved = State(final=True) +... rejected = State(final=True) +... +... approve = pending.to(approved, cond="is_manager") +... reject = pending.to(rejected) +... +... is_manager = False + +>>> sm = ApprovalMachine() + +>>> [e.id for e in sm.allowed_events] +['approve', 'reject'] + +>>> [e.id for e in sm.enabled_events()] +['reject'] + +>>> sm.is_manager = True + +>>> [e.id for e in sm.enabled_events()] +['approve', 'reject'] + +``` + +`enabled_events` is a method (not a property) because conditions may depend on runtime +arguments. Any `*args`/`**kwargs` passed to `enabled_events()` are forwarded to the +condition callbacks, just like when triggering an event: + +```py +>>> class TaskMachine(StateMachine): +... idle = State(initial=True) +... running = State(final=True) +... +... start = idle.to(running, cond="has_enough_resources") +... +... def has_enough_resources(self, cpu=0): +... return cpu >= 4 + +>>> sm = TaskMachine() + +>>> sm.enabled_events() +[] + +>>> [e.id for e in sm.enabled_events(cpu=8)] +['start'] + +``` + +```{tip} +This is useful for UI scenarios where you want to show or hide buttons based on whether +an event's conditions are currently satisfied. +``` + +```{note} +An event is considered **enabled** if at least one of its transitions from the current state +has all conditions satisfied. If a condition raises an exception, the event is treated as +enabled (permissive behavior). +``` + ## Validators diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 9d2b3f9f..ccc88496 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -97,6 +97,34 @@ async def _trigger(self, trigger_data: TriggerData): return result if executed else None + async def enabled_events(self, *args, **kwargs): + sm = self.sm + enabled = {} + for transition in sm.current_state.transitions: + for event in transition.events: + if event in enabled: + continue + extended_kwargs = kwargs.copy() + extended_kwargs.update( + { + "machine": sm, + "model": sm.model, + "event": getattr(sm, event), + "source": transition.source, + "target": transition.target, + "state": sm.current_state, + "transition": transition, + } + ) + try: + if await sm._callbacks.async_all( + transition.cond.key, *args, **extended_kwargs + ): + enabled[event] = getattr(sm, event) + except Exception: + enabled[event] = getattr(sm, event) + return list(enabled.values()) + async def _activate(self, trigger_data: TriggerData, transition: "Transition"): event_data = EventData(trigger_data=trigger_data, transition=transition) args, kwargs = event_data.args, event_data.extended_kwargs diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index 4400cd08..d65ef119 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -99,6 +99,32 @@ def _trigger(self, trigger_data: TriggerData): return result if executed else None + def enabled_events(self, *args, **kwargs): + sm = self.sm + enabled = {} + for transition in sm.current_state.transitions: + for event in transition.events: + if event in enabled: + continue + extended_kwargs = kwargs.copy() + extended_kwargs.update( + { + "machine": sm, + "model": sm.model, + "event": getattr(sm, event), + "source": transition.source, + "target": transition.target, + "state": sm.current_state, + "transition": transition, + } + ) + try: + if sm._callbacks.all(transition.cond.key, *args, **extended_kwargs): + enabled[event] = getattr(sm, event) + except Exception: + enabled[event] = getattr(sm, event) + return list(enabled.values()) + def _activate(self, trigger_data: TriggerData, transition: "Transition"): event_data = EventData(trigger_data=trigger_data, transition=transition) args, kwargs = event_data.args, event_data.extended_kwargs diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index e1e2a1e7..3e14edaf 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -295,6 +295,24 @@ def allowed_events(self) -> "List[Event]": """List of the current allowed events.""" return [getattr(self, event) for event in self.current_state.transitions.unique_events] + def enabled_events(self, *args, **kwargs): + """List of the current enabled events, considering guard conditions. + + An event is **enabled** if at least one of its transitions from the current + state has all ``cond``/``unless`` guards satisfied. + + Args: + *args: Positional arguments forwarded to condition callbacks. + **kwargs: Keyword arguments forwarded to condition callbacks. + + Returns: + A list of enabled :ref:`Event` instances. + """ + result = self._engine.enabled_events(*args, **kwargs) + if not isawaitable(result): + return result + return run_async_from_sync(result) + def _put_nonblocking(self, trigger_data: TriggerData): """Put the trigger on the queue without blocking the caller.""" self._engine.put(trigger_data) diff --git a/tests/test_async.py b/tests/test_async.py index 48267c08..7a995e88 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -199,3 +199,81 @@ async def test_async_state_should_be_initialized(async_order_control_machine): await sm.activate_initial_state() assert sm.current_state == sm.waiting_for_payment + + +class TestAsyncEnabledEvents: + async def test_passing_async_condition(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="is_ready") + + async def is_ready(self): + return True + + sm = MyMachine() + await sm.activate_initial_state() + assert [e.id for e in await sm.enabled_events()] == ["go"] + + async def test_failing_async_condition(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="is_ready") + + async def is_ready(self): + return False + + sm = MyMachine() + await sm.activate_initial_state() + assert await sm.enabled_events() == [] + + async def test_kwargs_forwarded_to_async_conditions(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="check_value") + + async def check_value(self, value=0): + return value > 10 + + sm = MyMachine() + await sm.activate_initial_state() + assert await sm.enabled_events() == [] + assert [e.id for e in await sm.enabled_events(value=20)] == ["go"] + + async def test_async_condition_exception_treated_as_enabled(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="bad_cond") + + async def bad_cond(self): + raise RuntimeError("boom") + + sm = MyMachine() + await sm.activate_initial_state() + assert [e.id for e in await sm.enabled_events()] == ["go"] + + async def test_mixed_enabled_and_disabled_async(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State() + s2 = State(final=True) + + go = s0.to(s1, cond="cond_true") + stop = s0.to(s2, cond="cond_false") + + async def cond_true(self): + return True + + async def cond_false(self): + return False + + sm = MyMachine() + await sm.activate_initial_state() + assert [e.id for e in await sm.enabled_events()] == ["go"] diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py index 9ff82897..ea1531f7 100644 --- a/tests/test_statemachine.py +++ b/tests/test_statemachine.py @@ -503,3 +503,134 @@ def __bool__(self): machine.produce() assert model.state == "producing" + + +class TestEnabledEvents: + def test_no_conditions_same_as_allowed_events(self, campaign_machine): + """Without conditions, enabled_events should match allowed_events.""" + sm = campaign_machine() + assert [e.id for e in sm.enabled_events()] == [e.id for e in sm.allowed_events] + + def test_passing_condition_returns_event(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="is_ready") + + def is_ready(self): + return True + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"] + + def test_failing_condition_excludes_event(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="is_ready") + + def is_ready(self): + return False + + sm = MyMachine() + assert sm.enabled_events() == [] + + def test_multiple_transitions_one_passes(self): + """Same event with multiple transitions: included if at least one passes.""" + + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State() + s2 = State(final=True) + + go = s0.to(s1, cond="cond_false") | s0.to(s2, cond="cond_true") + + def cond_false(self): + return False + + def cond_true(self): + return True + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"] + + def test_final_state_returns_empty(self, campaign_machine): + sm = campaign_machine() + sm.produce() + sm.deliver() + assert sm.enabled_events() == [] + + def test_kwargs_forwarded_to_conditions(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="check_value") + + def check_value(self, value=0): + return value > 10 + + sm = MyMachine() + assert sm.enabled_events() == [] + assert [e.id for e in sm.enabled_events(value=20)] == ["go"] + + def test_condition_exception_treated_as_enabled(self): + """If a condition raises, the event is treated as enabled (permissive).""" + + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="bad_cond") + + def bad_cond(self): + raise RuntimeError("boom") + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"] + + def test_mixed_enabled_and_disabled(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State() + s2 = State(final=True) + + go = s0.to(s1, cond="cond_true") + stop = s0.to(s2, cond="cond_false") + + def cond_true(self): + return True + + def cond_false(self): + return False + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"] + + def test_unless_condition(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, unless="is_blocked") + + def is_blocked(self): + return True + + sm = MyMachine() + assert sm.enabled_events() == [] + + def test_unless_condition_passes(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, unless="is_blocked") + + def is_blocked(self): + return False + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"]