Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
exclude: docs/auto_examples
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: v0.8.1
rev: v0.15.0
hooks:
# Run the linter.
- id: ruff
Expand Down
70 changes: 62 additions & 8 deletions statemachine/spec_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import operator
import re
from functools import reduce
from inspect import isawaitable
from typing import Callable

replacements = {"!": "not ", "^": " and ", "v": " or "}
Expand Down Expand Up @@ -33,8 +34,15 @@ def match_func(match):


def custom_not(predicate: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return not predicate(*args, **kwargs)
def decorated(*args, **kwargs):
result = predicate(*args, **kwargs)
if isawaitable(result):

async def _negate():
return not await result

return _negate()
return not result

decorated.__name__ = f"not({predicate.__name__})"
unique_key = getattr(predicate, "unique_key", "")
Expand All @@ -43,17 +51,53 @@ def decorated(*args, **kwargs) -> bool:


def custom_and(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return]
def decorated(*args, **kwargs):
left_result = left(*args, **kwargs)
if isawaitable(left_result):

async def _async_and():
lr = await left_result
if not lr:
return lr
rr = right(*args, **kwargs)
if isawaitable(rr):
return await rr
return rr

return _async_and()
if not left_result:
return left_result
right_result = right(*args, **kwargs)
if isawaitable(right_result):
return right_result
return right_result

decorated.__name__ = f"({left.__name__} and {right.__name__})"
decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined]
return decorated


def custom_or(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return]
def decorated(*args, **kwargs):
left_result = left(*args, **kwargs)
if isawaitable(left_result):

async def _async_or():
lr = await left_result
if lr:
return lr
rr = right(*args, **kwargs)
if isawaitable(rr):
return await rr
return rr

return _async_or()
if left_result:
return left_result
right_result = right(*args, **kwargs)
if isawaitable(right_result):
return right_result
return right_result

decorated.__name__ = f"({left.__name__} or {right.__name__})"
decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined]
Expand All @@ -73,8 +117,18 @@ def build_custom_operator(operator) -> Callable:
operator_repr = comparison_repr[operator]

def custom_comparator(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return bool(operator(left(*args, **kwargs), right(*args, **kwargs)))
def decorated(*args, **kwargs):
left_result = left(*args, **kwargs)
right_result = right(*args, **kwargs)
if isawaitable(left_result) or isawaitable(right_result):

async def _async_compare():
lr = (await left_result) if isawaitable(left_result) else left_result
rr = (await right_result) if isawaitable(right_result) else right_result
return bool(operator(lr, rr))

return _async_compare()
return bool(operator(left_result, right_result))

decorated.__name__ = f"({left.__name__} {operator_repr} {right.__name__})"
decorated.unique_key = _unique_key(left, right, operator_repr) # type: ignore[attr-defined]
Expand Down
1 change: 1 addition & 0 deletions statemachine/statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __setstate__(self, state):
self._register_callbacks([])
self.add_listener(*listeners.keys())
self._engine = self._get_engine(rtc)
self._engine.start()

def _get_initial_state(self):
initial_state_value = self.start_value if self.start_value else self.initial_state.value
Expand Down
2 changes: 1 addition & 1 deletion tests/examples/user_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class UserStatusMachine(StateMachine):
def on_signup(self, token: str):
if token == "":
raise ValueError("Token is required")
self.model.verified = True
self.model.verified = True # type: ignore[union-attr]


class UserExperienceMachine(StateMachine):
Expand Down
80 changes: 80 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,86 @@ def test_async_state_from_sync_context(async_order_control_machine):
assert sm.completed.is_active


class AsyncConditionExpressionMachine(StateMachine):
"""Regression test for issue #535: async conditions in boolean expressions."""

s1 = State(initial=True)

go_not = s1.to.itself(cond="not cond_false")
go_and = s1.to.itself(cond="cond_true and cond_true")
go_or_false_first = s1.to.itself(cond="cond_false or cond_true")
go_or_true_first = s1.to.itself(cond="cond_true or cond_false")
go_blocked = s1.to.itself(cond="not cond_true")
go_and_blocked = s1.to.itself(cond="cond_true and cond_false")
go_or_both_false = s1.to.itself(cond="cond_false or cond_false")

async def cond_true(self):
return True

async def cond_false(self):
return False

async def on_enter_state(self, target):
"""Async callback to ensure the SM uses AsyncEngine."""


async def test_async_condition_not(recwarn):
"""Issue #535: 'not cond_false' should allow the transition."""
sm = AsyncConditionExpressionMachine()
await sm.activate_initial_state()
await sm.go_not()
assert sm.s1.is_active
assert not any("coroutine" in str(w.message) for w in recwarn.list)


async def test_async_condition_not_blocked():
"""Issue #535: 'not cond_true' should block the transition."""
sm = AsyncConditionExpressionMachine()
await sm.activate_initial_state()
with pytest.raises(sm.TransitionNotAllowed):
await sm.go_blocked()


async def test_async_condition_and():
"""Issue #535: 'cond_true and cond_true' should allow the transition."""
sm = AsyncConditionExpressionMachine()
await sm.activate_initial_state()
await sm.go_and()
assert sm.s1.is_active


async def test_async_condition_and_blocked():
"""Issue #535: 'cond_true and cond_false' should block the transition."""
sm = AsyncConditionExpressionMachine()
await sm.activate_initial_state()
with pytest.raises(sm.TransitionNotAllowed):
await sm.go_and_blocked()


async def test_async_condition_or_false_first():
"""Issue #535: 'cond_false or cond_true' should allow the transition."""
sm = AsyncConditionExpressionMachine()
await sm.activate_initial_state()
await sm.go_or_false_first()
assert sm.s1.is_active


async def test_async_condition_or_true_first():
"""'cond_true or cond_false' should allow the transition."""
sm = AsyncConditionExpressionMachine()
await sm.activate_initial_state()
await sm.go_or_true_first()
assert sm.s1.is_active


async def test_async_condition_or_both_false():
"""'cond_false or cond_false' should block the transition."""
sm = AsyncConditionExpressionMachine()
await sm.activate_initial_state()
with pytest.raises(sm.TransitionNotAllowed):
await sm.go_or_both_false()


async def test_async_state_should_be_initialized(async_order_control_machine):
"""Test that the state machine is initialized before any event is triggered

Expand Down
49 changes: 49 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import pickle
from copy import deepcopy
Expand Down Expand Up @@ -181,3 +182,51 @@ def test_copy_with_custom_init_and_vars(copy_method):
assert sm2.custom == 1
assert sm2.value == [1, 2, 3]
assert sm2.current_state == MyStateMachine.started


class AsyncTrafficLightMachine(StateMachine):
green = State(initial=True)
yellow = State()
red = State()

cycle = green.to(yellow) | yellow.to(red) | red.to(green)

async def on_enter_state(self, target):
"""Async callback to ensure the SM uses AsyncEngine."""


def test_copy_async_statemachine_before_activation(copy_method):
"""Regression test for issue #544: async SM fails after pickle/deepcopy.

When an async SM is copied before activation, the copy must still be
activatable because ``__setstate__`` re-enqueues the ``__initial__`` event.
"""
sm = AsyncTrafficLightMachine()
sm_copy = copy_method(sm)

async def verify():
await sm_copy.activate_initial_state()
assert sm_copy.current_state == AsyncTrafficLightMachine.green
await sm_copy.cycle()
assert sm_copy.current_state == AsyncTrafficLightMachine.yellow

asyncio.run(verify())


def test_copy_async_statemachine_after_activation(copy_method):
"""Copying an async SM that is already activated preserves its current state."""

async def setup_and_verify():
sm = AsyncTrafficLightMachine()
await sm.activate_initial_state()
await sm.cycle()
assert sm.current_state == AsyncTrafficLightMachine.yellow

sm_copy = copy_method(sm)

await sm_copy.activate_initial_state()
assert sm_copy.current_state == AsyncTrafficLightMachine.yellow
await sm_copy.cycle()
assert sm_copy.current_state == AsyncTrafficLightMachine.red

asyncio.run(setup_and_verify())
1 change: 0 additions & 1 deletion tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from functools import partial

import pytest

from statemachine.dispatcher import callable_method
from statemachine.signature import SignatureAdapter

Expand Down
96 changes: 96 additions & 0 deletions tests/test_spec_parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging

import pytest
Expand Down Expand Up @@ -247,6 +248,101 @@ def variable_hook(var_name):
("height > 1 and height < 2", True, ["height"]),
],
)
def async_variable_hook(var_name):
"""Variable hook that returns async callables, for testing issue #535."""
values = {
"cond_true": True,
"cond_false": False,
"val_10": 10,
"val_20": 20,
}

value = values.get(var_name, False)

async def decorated(*args, **kwargs):
await asyncio.sleep(0)
return value

decorated.__name__ = var_name
return decorated


@pytest.mark.parametrize(
("expression", "expected"),
[
("not cond_false", True),
("not cond_true", False),
("cond_true and cond_true", True),
("cond_true and cond_false", False),
("cond_false and cond_true", False),
("cond_false or cond_true", True),
("cond_true or cond_false", True),
("cond_false or cond_false", False),
("not cond_false and cond_true", True),
("not (cond_true and cond_false)", True),
("not (cond_false or cond_false)", True),
("cond_true and not cond_false", True),
("val_10 == 10", True),
("val_10 != 20", True),
("val_10 < val_20", True),
("val_20 > val_10", True),
("val_10 >= 10", True),
("val_10 <= val_20", True),
],
)
def test_async_expressions(expression, expected):
"""Issue #535: condition expressions with async predicates must await results."""
parsed_expr = parse_boolean_expr(expression, async_variable_hook, operator_mapping)
result = parsed_expr()
assert asyncio.iscoroutine(result), f"Expected coroutine for async expression: {expression}"
assert asyncio.run(result) is expected, expression


def mixed_variable_hook(var_name):
"""Variable hook where some vars are sync and some are async."""
sync_values = {"sync_true": True, "sync_false": False, "sync_10": 10}
async_values = {"async_true": True, "async_false": False, "async_20": 20}

if var_name in async_values:
value = async_values[var_name]

async def async_decorated(*args, **kwargs):
await asyncio.sleep(0)
return value

async_decorated.__name__ = var_name
return async_decorated

def sync_decorated(*args, **kwargs):
return sync_values.get(var_name, False)

sync_decorated.__name__ = var_name
return sync_decorated


@pytest.mark.parametrize(
("expression", "expected"),
[
# async left, sync right
("async_true and sync_true", True),
("async_false or sync_true", True),
# sync left, async right
("sync_true and async_true", True),
("sync_false or async_true", True),
("sync_true and async_false", False),
("sync_false or async_false", False),
],
)
def test_mixed_sync_async_expressions(expression, expected):
"""Expressions mixing sync and async predicates must handle both correctly."""
parsed_expr = parse_boolean_expr(expression, mixed_variable_hook, operator_mapping)
result = parsed_expr()
if asyncio.iscoroutine(result):
assert asyncio.run(result) is expected, expression
else:
assert result is expected, expression


@pytest.mark.xfail(reason="TODO: Optimize so that expressios are evaluated only once")
def test_should_evaluate_values_only_once(expression, expected, caplog, hooks_called):
caplog.set_level(logging.DEBUG, logger="tests")
Expand Down
Loading