diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 378b6ebd..b39fb525 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: exclude: docs/auto_examples - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: v0.8.1 + rev: v0.15.0 hooks: # Run the linter. - id: ruff diff --git a/docs/guards.md b/docs/guards.md index 65eaf9e6..84140657 100644 --- a/docs/guards.md +++ b/docs/guards.md @@ -47,7 +47,7 @@ To control the evaluation order, declare transitions in the desired order: ```python # Declare in the order you want them checked: first = state_a.to(state_b, cond="check1") # Checked FIRST -second = state_a.to(state_c, cond="check2") # Checked SECOND +second = state_a.to(state_c, cond="check2") # Checked SECOND third = state_a.to(state_d, cond="check3") # Checked THIRD my_event = first | second | third # Order matches declaration @@ -159,6 +159,79 @@ So, a condition `s1.to(s2, cond=lambda: [])` will evaluate as `False`, as an emp **falsy** value. ``` +### Checking enabled events + +The {ref}`StateMachine.allowed_events` property returns events reachable from the current state, +but it does **not** evaluate `cond`/`unless` guards. To check which events actually have their +conditions satisfied, use {ref}`StateMachine.enabled_events`. + +```{testsetup} + +>>> from statemachine import StateMachine, State + +``` + +```py +>>> class ApprovalMachine(StateMachine): +... pending = State(initial=True) +... approved = State(final=True) +... rejected = State(final=True) +... +... approve = pending.to(approved, cond="is_manager") +... reject = pending.to(rejected) +... +... is_manager = False + +>>> sm = ApprovalMachine() + +>>> [e.id for e in sm.allowed_events] +['approve', 'reject'] + +>>> [e.id for e in sm.enabled_events()] +['reject'] + +>>> sm.is_manager = True + +>>> [e.id for e in sm.enabled_events()] +['approve', 'reject'] + +``` + +`enabled_events` is a method (not a property) because conditions may depend on runtime +arguments. Any `*args`/`**kwargs` passed to `enabled_events()` are forwarded to the +condition callbacks, just like when triggering an event: + +```py +>>> class TaskMachine(StateMachine): +... idle = State(initial=True) +... running = State(final=True) +... +... start = idle.to(running, cond="has_enough_resources") +... +... def has_enough_resources(self, cpu=0): +... return cpu >= 4 + +>>> sm = TaskMachine() + +>>> sm.enabled_events() +[] + +>>> [e.id for e in sm.enabled_events(cpu=8)] +['start'] + +``` + +```{tip} +This is useful for UI scenarios where you want to show or hide buttons based on whether +an event's conditions are currently satisfied. +``` + +```{note} +An event is considered **enabled** if at least one of its transitions from the current state +has all conditions satisfied. If a condition raises an exception, the event is treated as +enabled (permissive behavior). +``` + ## Validators diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 9d2b3f9f..ccc88496 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -97,6 +97,34 @@ async def _trigger(self, trigger_data: TriggerData): return result if executed else None + async def enabled_events(self, *args, **kwargs): + sm = self.sm + enabled = {} + for transition in sm.current_state.transitions: + for event in transition.events: + if event in enabled: + continue + extended_kwargs = kwargs.copy() + extended_kwargs.update( + { + "machine": sm, + "model": sm.model, + "event": getattr(sm, event), + "source": transition.source, + "target": transition.target, + "state": sm.current_state, + "transition": transition, + } + ) + try: + if await sm._callbacks.async_all( + transition.cond.key, *args, **extended_kwargs + ): + enabled[event] = getattr(sm, event) + except Exception: + enabled[event] = getattr(sm, event) + return list(enabled.values()) + async def _activate(self, trigger_data: TriggerData, transition: "Transition"): event_data = EventData(trigger_data=trigger_data, transition=transition) args, kwargs = event_data.args, event_data.extended_kwargs diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index 4400cd08..d65ef119 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -99,6 +99,32 @@ def _trigger(self, trigger_data: TriggerData): return result if executed else None + def enabled_events(self, *args, **kwargs): + sm = self.sm + enabled = {} + for transition in sm.current_state.transitions: + for event in transition.events: + if event in enabled: + continue + extended_kwargs = kwargs.copy() + extended_kwargs.update( + { + "machine": sm, + "model": sm.model, + "event": getattr(sm, event), + "source": transition.source, + "target": transition.target, + "state": sm.current_state, + "transition": transition, + } + ) + try: + if sm._callbacks.all(transition.cond.key, *args, **extended_kwargs): + enabled[event] = getattr(sm, event) + except Exception: + enabled[event] = getattr(sm, event) + return list(enabled.values()) + def _activate(self, trigger_data: TriggerData, transition: "Transition"): event_data = EventData(trigger_data=trigger_data, transition=transition) args, kwargs = event_data.args, event_data.extended_kwargs diff --git a/statemachine/spec_parser.py b/statemachine/spec_parser.py index 7596c083..1899d0a4 100644 --- a/statemachine/spec_parser.py +++ b/statemachine/spec_parser.py @@ -2,6 +2,7 @@ import operator import re from functools import reduce +from inspect import isawaitable from typing import Callable replacements = {"!": "not ", "^": " and ", "v": " or "} @@ -33,8 +34,15 @@ def match_func(match): def custom_not(predicate: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return not predicate(*args, **kwargs) + def decorated(*args, **kwargs): + result = predicate(*args, **kwargs) + if isawaitable(result): + + async def _negate(): + return not await result + + return _negate() + return not result decorated.__name__ = f"not({predicate.__name__})" unique_key = getattr(predicate, "unique_key", "") @@ -43,8 +51,26 @@ def decorated(*args, **kwargs) -> bool: def custom_and(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return] + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + if isawaitable(left_result): + + async def _async_and(): + lr = await left_result + if not lr: + return lr + rr = right(*args, **kwargs) + if isawaitable(rr): + return await rr + return rr + + return _async_and() + if not left_result: + return left_result + right_result = right(*args, **kwargs) + if isawaitable(right_result): + return right_result + return right_result decorated.__name__ = f"({left.__name__} and {right.__name__})" decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined] @@ -52,8 +78,26 @@ def decorated(*args, **kwargs) -> bool: def custom_or(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return] + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + if isawaitable(left_result): + + async def _async_or(): + lr = await left_result + if lr: + return lr + rr = right(*args, **kwargs) + if isawaitable(rr): + return await rr + return rr + + return _async_or() + if left_result: + return left_result + right_result = right(*args, **kwargs) + if isawaitable(right_result): + return right_result + return right_result decorated.__name__ = f"({left.__name__} or {right.__name__})" decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined] @@ -73,8 +117,18 @@ def build_custom_operator(operator) -> Callable: operator_repr = comparison_repr[operator] def custom_comparator(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return bool(operator(left(*args, **kwargs), right(*args, **kwargs))) + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + right_result = right(*args, **kwargs) + if isawaitable(left_result) or isawaitable(right_result): + + async def _async_compare(): + lr = (await left_result) if isawaitable(left_result) else left_result + rr = (await right_result) if isawaitable(right_result) else right_result + return bool(operator(lr, rr)) + + return _async_compare() + return bool(operator(left_result, right_result)) decorated.__name__ = f"({left.__name__} {operator_repr} {right.__name__})" decorated.unique_key = _unique_key(left, right, operator_repr) # type: ignore[attr-defined] diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index ac402159..3e14edaf 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -147,6 +147,7 @@ def __setstate__(self, state): self._register_callbacks([]) self.add_listener(*listeners.keys()) self._engine = self._get_engine(rtc) + self._engine.start() def _get_initial_state(self): initial_state_value = self.start_value if self.start_value else self.initial_state.value @@ -294,6 +295,24 @@ def allowed_events(self) -> "List[Event]": """List of the current allowed events.""" return [getattr(self, event) for event in self.current_state.transitions.unique_events] + def enabled_events(self, *args, **kwargs): + """List of the current enabled events, considering guard conditions. + + An event is **enabled** if at least one of its transitions from the current + state has all ``cond``/``unless`` guards satisfied. + + Args: + *args: Positional arguments forwarded to condition callbacks. + **kwargs: Keyword arguments forwarded to condition callbacks. + + Returns: + A list of enabled :ref:`Event` instances. + """ + result = self._engine.enabled_events(*args, **kwargs) + if not isawaitable(result): + return result + return run_async_from_sync(result) + def _put_nonblocking(self, trigger_data: TriggerData): """Put the trigger on the queue without blocking the caller.""" self._engine.put(trigger_data) diff --git a/tests/examples/user_machine.py b/tests/examples/user_machine.py index a9fcb193..ad0320a8 100644 --- a/tests/examples/user_machine.py +++ b/tests/examples/user_machine.py @@ -88,7 +88,7 @@ class UserStatusMachine(StateMachine): def on_signup(self, token: str): if token == "": raise ValueError("Token is required") - self.model.verified = True + self.model.verified = True # type: ignore[union-attr] class UserExperienceMachine(StateMachine): diff --git a/tests/test_async.py b/tests/test_async.py index 36af9126..7a995e88 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -96,6 +96,86 @@ def test_async_state_from_sync_context(async_order_control_machine): assert sm.completed.is_active +class AsyncConditionExpressionMachine(StateMachine): + """Regression test for issue #535: async conditions in boolean expressions.""" + + s1 = State(initial=True) + + go_not = s1.to.itself(cond="not cond_false") + go_and = s1.to.itself(cond="cond_true and cond_true") + go_or_false_first = s1.to.itself(cond="cond_false or cond_true") + go_or_true_first = s1.to.itself(cond="cond_true or cond_false") + go_blocked = s1.to.itself(cond="not cond_true") + go_and_blocked = s1.to.itself(cond="cond_true and cond_false") + go_or_both_false = s1.to.itself(cond="cond_false or cond_false") + + async def cond_true(self): + return True + + async def cond_false(self): + return False + + async def on_enter_state(self, target): + """Async callback to ensure the SM uses AsyncEngine.""" + + +async def test_async_condition_not(recwarn): + """Issue #535: 'not cond_false' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_not() + assert sm.s1.is_active + assert not any("coroutine" in str(w.message) for w in recwarn.list) + + +async def test_async_condition_not_blocked(): + """Issue #535: 'not cond_true' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_blocked() + + +async def test_async_condition_and(): + """Issue #535: 'cond_true and cond_true' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_and() + assert sm.s1.is_active + + +async def test_async_condition_and_blocked(): + """Issue #535: 'cond_true and cond_false' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_and_blocked() + + +async def test_async_condition_or_false_first(): + """Issue #535: 'cond_false or cond_true' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_or_false_first() + assert sm.s1.is_active + + +async def test_async_condition_or_true_first(): + """'cond_true or cond_false' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_or_true_first() + assert sm.s1.is_active + + +async def test_async_condition_or_both_false(): + """'cond_false or cond_false' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_or_both_false() + + async def test_async_state_should_be_initialized(async_order_control_machine): """Test that the state machine is initialized before any event is triggered @@ -119,3 +199,81 @@ async def test_async_state_should_be_initialized(async_order_control_machine): await sm.activate_initial_state() assert sm.current_state == sm.waiting_for_payment + + +class TestAsyncEnabledEvents: + async def test_passing_async_condition(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="is_ready") + + async def is_ready(self): + return True + + sm = MyMachine() + await sm.activate_initial_state() + assert [e.id for e in await sm.enabled_events()] == ["go"] + + async def test_failing_async_condition(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="is_ready") + + async def is_ready(self): + return False + + sm = MyMachine() + await sm.activate_initial_state() + assert await sm.enabled_events() == [] + + async def test_kwargs_forwarded_to_async_conditions(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="check_value") + + async def check_value(self, value=0): + return value > 10 + + sm = MyMachine() + await sm.activate_initial_state() + assert await sm.enabled_events() == [] + assert [e.id for e in await sm.enabled_events(value=20)] == ["go"] + + async def test_async_condition_exception_treated_as_enabled(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="bad_cond") + + async def bad_cond(self): + raise RuntimeError("boom") + + sm = MyMachine() + await sm.activate_initial_state() + assert [e.id for e in await sm.enabled_events()] == ["go"] + + async def test_mixed_enabled_and_disabled_async(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State() + s2 = State(final=True) + + go = s0.to(s1, cond="cond_true") + stop = s0.to(s2, cond="cond_false") + + async def cond_true(self): + return True + + async def cond_false(self): + return False + + sm = MyMachine() + await sm.activate_initial_state() + assert [e.id for e in await sm.enabled_events()] == ["go"] diff --git a/tests/test_copy.py b/tests/test_copy.py index 15e2c358..b2af2819 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,3 +1,4 @@ +import asyncio import logging import pickle from copy import deepcopy @@ -181,3 +182,51 @@ def test_copy_with_custom_init_and_vars(copy_method): assert sm2.custom == 1 assert sm2.value == [1, 2, 3] assert sm2.current_state == MyStateMachine.started + + +class AsyncTrafficLightMachine(StateMachine): + green = State(initial=True) + yellow = State() + red = State() + + cycle = green.to(yellow) | yellow.to(red) | red.to(green) + + async def on_enter_state(self, target): + """Async callback to ensure the SM uses AsyncEngine.""" + + +def test_copy_async_statemachine_before_activation(copy_method): + """Regression test for issue #544: async SM fails after pickle/deepcopy. + + When an async SM is copied before activation, the copy must still be + activatable because ``__setstate__`` re-enqueues the ``__initial__`` event. + """ + sm = AsyncTrafficLightMachine() + sm_copy = copy_method(sm) + + async def verify(): + await sm_copy.activate_initial_state() + assert sm_copy.current_state == AsyncTrafficLightMachine.green + await sm_copy.cycle() + assert sm_copy.current_state == AsyncTrafficLightMachine.yellow + + asyncio.run(verify()) + + +def test_copy_async_statemachine_after_activation(copy_method): + """Copying an async SM that is already activated preserves its current state.""" + + async def setup_and_verify(): + sm = AsyncTrafficLightMachine() + await sm.activate_initial_state() + await sm.cycle() + assert sm.current_state == AsyncTrafficLightMachine.yellow + + sm_copy = copy_method(sm) + + await sm_copy.activate_initial_state() + assert sm_copy.current_state == AsyncTrafficLightMachine.yellow + await sm_copy.cycle() + assert sm_copy.current_state == AsyncTrafficLightMachine.red + + asyncio.run(setup_and_verify()) diff --git a/tests/test_signature.py b/tests/test_signature.py index 1cb68673..ccf30232 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -2,7 +2,6 @@ from functools import partial import pytest - from statemachine.dispatcher import callable_method from statemachine.signature import SignatureAdapter diff --git a/tests/test_spec_parser.py b/tests/test_spec_parser.py index ace3a4b1..569090d9 100644 --- a/tests/test_spec_parser.py +++ b/tests/test_spec_parser.py @@ -1,3 +1,4 @@ +import asyncio import logging import pytest @@ -247,6 +248,101 @@ def variable_hook(var_name): ("height > 1 and height < 2", True, ["height"]), ], ) +def async_variable_hook(var_name): + """Variable hook that returns async callables, for testing issue #535.""" + values = { + "cond_true": True, + "cond_false": False, + "val_10": 10, + "val_20": 20, + } + + value = values.get(var_name, False) + + async def decorated(*args, **kwargs): + await asyncio.sleep(0) + return value + + decorated.__name__ = var_name + return decorated + + +@pytest.mark.parametrize( + ("expression", "expected"), + [ + ("not cond_false", True), + ("not cond_true", False), + ("cond_true and cond_true", True), + ("cond_true and cond_false", False), + ("cond_false and cond_true", False), + ("cond_false or cond_true", True), + ("cond_true or cond_false", True), + ("cond_false or cond_false", False), + ("not cond_false and cond_true", True), + ("not (cond_true and cond_false)", True), + ("not (cond_false or cond_false)", True), + ("cond_true and not cond_false", True), + ("val_10 == 10", True), + ("val_10 != 20", True), + ("val_10 < val_20", True), + ("val_20 > val_10", True), + ("val_10 >= 10", True), + ("val_10 <= val_20", True), + ], +) +def test_async_expressions(expression, expected): + """Issue #535: condition expressions with async predicates must await results.""" + parsed_expr = parse_boolean_expr(expression, async_variable_hook, operator_mapping) + result = parsed_expr() + assert asyncio.iscoroutine(result), f"Expected coroutine for async expression: {expression}" + assert asyncio.run(result) is expected, expression + + +def mixed_variable_hook(var_name): + """Variable hook where some vars are sync and some are async.""" + sync_values = {"sync_true": True, "sync_false": False, "sync_10": 10} + async_values = {"async_true": True, "async_false": False, "async_20": 20} + + if var_name in async_values: + value = async_values[var_name] + + async def async_decorated(*args, **kwargs): + await asyncio.sleep(0) + return value + + async_decorated.__name__ = var_name + return async_decorated + + def sync_decorated(*args, **kwargs): + return sync_values.get(var_name, False) + + sync_decorated.__name__ = var_name + return sync_decorated + + +@pytest.mark.parametrize( + ("expression", "expected"), + [ + # async left, sync right + ("async_true and sync_true", True), + ("async_false or sync_true", True), + # sync left, async right + ("sync_true and async_true", True), + ("sync_false or async_true", True), + ("sync_true and async_false", False), + ("sync_false or async_false", False), + ], +) +def test_mixed_sync_async_expressions(expression, expected): + """Expressions mixing sync and async predicates must handle both correctly.""" + parsed_expr = parse_boolean_expr(expression, mixed_variable_hook, operator_mapping) + result = parsed_expr() + if asyncio.iscoroutine(result): + assert asyncio.run(result) is expected, expression + else: + assert result is expected, expression + + @pytest.mark.xfail(reason="TODO: Optimize so that expressios are evaluated only once") def test_should_evaluate_values_only_once(expression, expected, caplog, hooks_called): caplog.set_level(logging.DEBUG, logger="tests") diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py index 9ff82897..ea1531f7 100644 --- a/tests/test_statemachine.py +++ b/tests/test_statemachine.py @@ -503,3 +503,134 @@ def __bool__(self): machine.produce() assert model.state == "producing" + + +class TestEnabledEvents: + def test_no_conditions_same_as_allowed_events(self, campaign_machine): + """Without conditions, enabled_events should match allowed_events.""" + sm = campaign_machine() + assert [e.id for e in sm.enabled_events()] == [e.id for e in sm.allowed_events] + + def test_passing_condition_returns_event(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="is_ready") + + def is_ready(self): + return True + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"] + + def test_failing_condition_excludes_event(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="is_ready") + + def is_ready(self): + return False + + sm = MyMachine() + assert sm.enabled_events() == [] + + def test_multiple_transitions_one_passes(self): + """Same event with multiple transitions: included if at least one passes.""" + + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State() + s2 = State(final=True) + + go = s0.to(s1, cond="cond_false") | s0.to(s2, cond="cond_true") + + def cond_false(self): + return False + + def cond_true(self): + return True + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"] + + def test_final_state_returns_empty(self, campaign_machine): + sm = campaign_machine() + sm.produce() + sm.deliver() + assert sm.enabled_events() == [] + + def test_kwargs_forwarded_to_conditions(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="check_value") + + def check_value(self, value=0): + return value > 10 + + sm = MyMachine() + assert sm.enabled_events() == [] + assert [e.id for e in sm.enabled_events(value=20)] == ["go"] + + def test_condition_exception_treated_as_enabled(self): + """If a condition raises, the event is treated as enabled (permissive).""" + + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, cond="bad_cond") + + def bad_cond(self): + raise RuntimeError("boom") + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"] + + def test_mixed_enabled_and_disabled(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State() + s2 = State(final=True) + + go = s0.to(s1, cond="cond_true") + stop = s0.to(s2, cond="cond_false") + + def cond_true(self): + return True + + def cond_false(self): + return False + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"] + + def test_unless_condition(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, unless="is_blocked") + + def is_blocked(self): + return True + + sm = MyMachine() + assert sm.enabled_events() == [] + + def test_unless_condition_passes(self): + class MyMachine(StateMachine): + s0 = State(initial=True) + s1 = State(final=True) + + go = s0.to(s1, unless="is_blocked") + + def is_blocked(self): + return False + + sm = MyMachine() + assert [e.id for e in sm.enabled_events()] == ["go"]