diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 378b6ebd..b39fb525 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: exclude: docs/auto_examples - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: v0.8.1 + rev: v0.15.0 hooks: # Run the linter. - id: ruff diff --git a/statemachine/spec_parser.py b/statemachine/spec_parser.py index 7596c083..1899d0a4 100644 --- a/statemachine/spec_parser.py +++ b/statemachine/spec_parser.py @@ -2,6 +2,7 @@ import operator import re from functools import reduce +from inspect import isawaitable from typing import Callable replacements = {"!": "not ", "^": " and ", "v": " or "} @@ -33,8 +34,15 @@ def match_func(match): def custom_not(predicate: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return not predicate(*args, **kwargs) + def decorated(*args, **kwargs): + result = predicate(*args, **kwargs) + if isawaitable(result): + + async def _negate(): + return not await result + + return _negate() + return not result decorated.__name__ = f"not({predicate.__name__})" unique_key = getattr(predicate, "unique_key", "") @@ -43,8 +51,26 @@ def decorated(*args, **kwargs) -> bool: def custom_and(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return] + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + if isawaitable(left_result): + + async def _async_and(): + lr = await left_result + if not lr: + return lr + rr = right(*args, **kwargs) + if isawaitable(rr): + return await rr + return rr + + return _async_and() + if not left_result: + return left_result + right_result = right(*args, **kwargs) + if isawaitable(right_result): + return right_result + return right_result decorated.__name__ = f"({left.__name__} and {right.__name__})" decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined] @@ -52,8 +78,26 @@ def decorated(*args, **kwargs) -> bool: def custom_or(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return] + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + if isawaitable(left_result): + + async def _async_or(): + lr = await left_result + if lr: + return lr + rr = right(*args, **kwargs) + if isawaitable(rr): + return await rr + return rr + + return _async_or() + if left_result: + return left_result + right_result = right(*args, **kwargs) + if isawaitable(right_result): + return right_result + return right_result decorated.__name__ = f"({left.__name__} or {right.__name__})" decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined] @@ -73,8 +117,18 @@ def build_custom_operator(operator) -> Callable: operator_repr = comparison_repr[operator] def custom_comparator(left: Callable, right: Callable) -> Callable: - def decorated(*args, **kwargs) -> bool: - return bool(operator(left(*args, **kwargs), right(*args, **kwargs))) + def decorated(*args, **kwargs): + left_result = left(*args, **kwargs) + right_result = right(*args, **kwargs) + if isawaitable(left_result) or isawaitable(right_result): + + async def _async_compare(): + lr = (await left_result) if isawaitable(left_result) else left_result + rr = (await right_result) if isawaitable(right_result) else right_result + return bool(operator(lr, rr)) + + return _async_compare() + return bool(operator(left_result, right_result)) decorated.__name__ = f"({left.__name__} {operator_repr} {right.__name__})" decorated.unique_key = _unique_key(left, right, operator_repr) # type: ignore[attr-defined] diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index ac402159..e1e2a1e7 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -147,6 +147,7 @@ def __setstate__(self, state): self._register_callbacks([]) self.add_listener(*listeners.keys()) self._engine = self._get_engine(rtc) + self._engine.start() def _get_initial_state(self): initial_state_value = self.start_value if self.start_value else self.initial_state.value diff --git a/tests/examples/user_machine.py b/tests/examples/user_machine.py index a9fcb193..ad0320a8 100644 --- a/tests/examples/user_machine.py +++ b/tests/examples/user_machine.py @@ -88,7 +88,7 @@ class UserStatusMachine(StateMachine): def on_signup(self, token: str): if token == "": raise ValueError("Token is required") - self.model.verified = True + self.model.verified = True # type: ignore[union-attr] class UserExperienceMachine(StateMachine): diff --git a/tests/test_async.py b/tests/test_async.py index 36af9126..48267c08 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -96,6 +96,86 @@ def test_async_state_from_sync_context(async_order_control_machine): assert sm.completed.is_active +class AsyncConditionExpressionMachine(StateMachine): + """Regression test for issue #535: async conditions in boolean expressions.""" + + s1 = State(initial=True) + + go_not = s1.to.itself(cond="not cond_false") + go_and = s1.to.itself(cond="cond_true and cond_true") + go_or_false_first = s1.to.itself(cond="cond_false or cond_true") + go_or_true_first = s1.to.itself(cond="cond_true or cond_false") + go_blocked = s1.to.itself(cond="not cond_true") + go_and_blocked = s1.to.itself(cond="cond_true and cond_false") + go_or_both_false = s1.to.itself(cond="cond_false or cond_false") + + async def cond_true(self): + return True + + async def cond_false(self): + return False + + async def on_enter_state(self, target): + """Async callback to ensure the SM uses AsyncEngine.""" + + +async def test_async_condition_not(recwarn): + """Issue #535: 'not cond_false' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_not() + assert sm.s1.is_active + assert not any("coroutine" in str(w.message) for w in recwarn.list) + + +async def test_async_condition_not_blocked(): + """Issue #535: 'not cond_true' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_blocked() + + +async def test_async_condition_and(): + """Issue #535: 'cond_true and cond_true' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_and() + assert sm.s1.is_active + + +async def test_async_condition_and_blocked(): + """Issue #535: 'cond_true and cond_false' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_and_blocked() + + +async def test_async_condition_or_false_first(): + """Issue #535: 'cond_false or cond_true' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_or_false_first() + assert sm.s1.is_active + + +async def test_async_condition_or_true_first(): + """'cond_true or cond_false' should allow the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + await sm.go_or_true_first() + assert sm.s1.is_active + + +async def test_async_condition_or_both_false(): + """'cond_false or cond_false' should block the transition.""" + sm = AsyncConditionExpressionMachine() + await sm.activate_initial_state() + with pytest.raises(sm.TransitionNotAllowed): + await sm.go_or_both_false() + + async def test_async_state_should_be_initialized(async_order_control_machine): """Test that the state machine is initialized before any event is triggered diff --git a/tests/test_copy.py b/tests/test_copy.py index 15e2c358..b2af2819 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,3 +1,4 @@ +import asyncio import logging import pickle from copy import deepcopy @@ -181,3 +182,51 @@ def test_copy_with_custom_init_and_vars(copy_method): assert sm2.custom == 1 assert sm2.value == [1, 2, 3] assert sm2.current_state == MyStateMachine.started + + +class AsyncTrafficLightMachine(StateMachine): + green = State(initial=True) + yellow = State() + red = State() + + cycle = green.to(yellow) | yellow.to(red) | red.to(green) + + async def on_enter_state(self, target): + """Async callback to ensure the SM uses AsyncEngine.""" + + +def test_copy_async_statemachine_before_activation(copy_method): + """Regression test for issue #544: async SM fails after pickle/deepcopy. + + When an async SM is copied before activation, the copy must still be + activatable because ``__setstate__`` re-enqueues the ``__initial__`` event. + """ + sm = AsyncTrafficLightMachine() + sm_copy = copy_method(sm) + + async def verify(): + await sm_copy.activate_initial_state() + assert sm_copy.current_state == AsyncTrafficLightMachine.green + await sm_copy.cycle() + assert sm_copy.current_state == AsyncTrafficLightMachine.yellow + + asyncio.run(verify()) + + +def test_copy_async_statemachine_after_activation(copy_method): + """Copying an async SM that is already activated preserves its current state.""" + + async def setup_and_verify(): + sm = AsyncTrafficLightMachine() + await sm.activate_initial_state() + await sm.cycle() + assert sm.current_state == AsyncTrafficLightMachine.yellow + + sm_copy = copy_method(sm) + + await sm_copy.activate_initial_state() + assert sm_copy.current_state == AsyncTrafficLightMachine.yellow + await sm_copy.cycle() + assert sm_copy.current_state == AsyncTrafficLightMachine.red + + asyncio.run(setup_and_verify()) diff --git a/tests/test_signature.py b/tests/test_signature.py index 1cb68673..ccf30232 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -2,7 +2,6 @@ from functools import partial import pytest - from statemachine.dispatcher import callable_method from statemachine.signature import SignatureAdapter diff --git a/tests/test_spec_parser.py b/tests/test_spec_parser.py index ace3a4b1..569090d9 100644 --- a/tests/test_spec_parser.py +++ b/tests/test_spec_parser.py @@ -1,3 +1,4 @@ +import asyncio import logging import pytest @@ -247,6 +248,101 @@ def variable_hook(var_name): ("height > 1 and height < 2", True, ["height"]), ], ) +def async_variable_hook(var_name): + """Variable hook that returns async callables, for testing issue #535.""" + values = { + "cond_true": True, + "cond_false": False, + "val_10": 10, + "val_20": 20, + } + + value = values.get(var_name, False) + + async def decorated(*args, **kwargs): + await asyncio.sleep(0) + return value + + decorated.__name__ = var_name + return decorated + + +@pytest.mark.parametrize( + ("expression", "expected"), + [ + ("not cond_false", True), + ("not cond_true", False), + ("cond_true and cond_true", True), + ("cond_true and cond_false", False), + ("cond_false and cond_true", False), + ("cond_false or cond_true", True), + ("cond_true or cond_false", True), + ("cond_false or cond_false", False), + ("not cond_false and cond_true", True), + ("not (cond_true and cond_false)", True), + ("not (cond_false or cond_false)", True), + ("cond_true and not cond_false", True), + ("val_10 == 10", True), + ("val_10 != 20", True), + ("val_10 < val_20", True), + ("val_20 > val_10", True), + ("val_10 >= 10", True), + ("val_10 <= val_20", True), + ], +) +def test_async_expressions(expression, expected): + """Issue #535: condition expressions with async predicates must await results.""" + parsed_expr = parse_boolean_expr(expression, async_variable_hook, operator_mapping) + result = parsed_expr() + assert asyncio.iscoroutine(result), f"Expected coroutine for async expression: {expression}" + assert asyncio.run(result) is expected, expression + + +def mixed_variable_hook(var_name): + """Variable hook where some vars are sync and some are async.""" + sync_values = {"sync_true": True, "sync_false": False, "sync_10": 10} + async_values = {"async_true": True, "async_false": False, "async_20": 20} + + if var_name in async_values: + value = async_values[var_name] + + async def async_decorated(*args, **kwargs): + await asyncio.sleep(0) + return value + + async_decorated.__name__ = var_name + return async_decorated + + def sync_decorated(*args, **kwargs): + return sync_values.get(var_name, False) + + sync_decorated.__name__ = var_name + return sync_decorated + + +@pytest.mark.parametrize( + ("expression", "expected"), + [ + # async left, sync right + ("async_true and sync_true", True), + ("async_false or sync_true", True), + # sync left, async right + ("sync_true and async_true", True), + ("sync_false or async_true", True), + ("sync_true and async_false", False), + ("sync_false or async_false", False), + ], +) +def test_mixed_sync_async_expressions(expression, expected): + """Expressions mixing sync and async predicates must handle both correctly.""" + parsed_expr = parse_boolean_expr(expression, mixed_variable_hook, operator_mapping) + result = parsed_expr() + if asyncio.iscoroutine(result): + assert asyncio.run(result) is expected, expression + else: + assert result is expected, expression + + @pytest.mark.xfail(reason="TODO: Optimize so that expressios are evaluated only once") def test_should_evaluate_values_only_once(expression, expected, caplog, hooks_called): caplog.set_level(logging.DEBUG, logger="tests")