Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
exclude: docs/auto_examples
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: v0.8.1
rev: v0.15.0
hooks:
# Run the linter.
- id: ruff
Expand Down
75 changes: 74 additions & 1 deletion docs/guards.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ To control the evaluation order, declare transitions in the desired order:
```python
# Declare in the order you want them checked:
first = state_a.to(state_b, cond="check1") # Checked FIRST
second = state_a.to(state_c, cond="check2") # Checked SECOND
second = state_a.to(state_c, cond="check2") # Checked SECOND
third = state_a.to(state_d, cond="check3") # Checked THIRD

my_event = first | second | third # Order matches declaration
Expand Down Expand Up @@ -159,6 +159,79 @@ So, a condition `s1.to(s2, cond=lambda: [])` will evaluate as `False`, as an emp
**falsy** value.
```

### Checking enabled events

The {ref}`StateMachine.allowed_events` property returns events reachable from the current state,
but it does **not** evaluate `cond`/`unless` guards. To check which events actually have their
conditions satisfied, use {ref}`StateMachine.enabled_events`.

```{testsetup}

>>> from statemachine import StateMachine, State

```

```py
>>> class ApprovalMachine(StateMachine):
... pending = State(initial=True)
... approved = State(final=True)
... rejected = State(final=True)
...
... approve = pending.to(approved, cond="is_manager")
... reject = pending.to(rejected)
...
... is_manager = False

>>> sm = ApprovalMachine()

>>> [e.id for e in sm.allowed_events]
['approve', 'reject']

>>> [e.id for e in sm.enabled_events()]
['reject']

>>> sm.is_manager = True

>>> [e.id for e in sm.enabled_events()]
['approve', 'reject']

```

`enabled_events` is a method (not a property) because conditions may depend on runtime
arguments. Any `*args`/`**kwargs` passed to `enabled_events()` are forwarded to the
condition callbacks, just like when triggering an event:

```py
>>> class TaskMachine(StateMachine):
... idle = State(initial=True)
... running = State(final=True)
...
... start = idle.to(running, cond="has_enough_resources")
...
... def has_enough_resources(self, cpu=0):
... return cpu >= 4

>>> sm = TaskMachine()

>>> sm.enabled_events()
[]

>>> [e.id for e in sm.enabled_events(cpu=8)]
['start']

```

```{tip}
This is useful for UI scenarios where you want to show or hide buttons based on whether
an event's conditions are currently satisfied.
```

```{note}
An event is considered **enabled** if at least one of its transitions from the current state
has all conditions satisfied. If a condition raises an exception, the event is treated as
enabled (permissive behavior).
```

## Validators


Expand Down
28 changes: 28 additions & 0 deletions statemachine/engines/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,34 @@ async def _trigger(self, trigger_data: TriggerData):

return result if executed else None

async def enabled_events(self, *args, **kwargs):
sm = self.sm
enabled = {}
for transition in sm.current_state.transitions:
for event in transition.events:
if event in enabled:
continue
extended_kwargs = kwargs.copy()
extended_kwargs.update(
{
"machine": sm,
"model": sm.model,
"event": getattr(sm, event),
"source": transition.source,
"target": transition.target,
"state": sm.current_state,
"transition": transition,
}
)
try:
if await sm._callbacks.async_all(
transition.cond.key, *args, **extended_kwargs
):
enabled[event] = getattr(sm, event)
except Exception:
enabled[event] = getattr(sm, event)
return list(enabled.values())

async def _activate(self, trigger_data: TriggerData, transition: "Transition"):
event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs
Expand Down
26 changes: 26 additions & 0 deletions statemachine/engines/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,32 @@ def _trigger(self, trigger_data: TriggerData):

return result if executed else None

def enabled_events(self, *args, **kwargs):
sm = self.sm
enabled = {}
for transition in sm.current_state.transitions:
for event in transition.events:
if event in enabled:
continue
extended_kwargs = kwargs.copy()
extended_kwargs.update(
{
"machine": sm,
"model": sm.model,
"event": getattr(sm, event),
"source": transition.source,
"target": transition.target,
"state": sm.current_state,
"transition": transition,
}
)
try:
if sm._callbacks.all(transition.cond.key, *args, **extended_kwargs):
enabled[event] = getattr(sm, event)
except Exception:
enabled[event] = getattr(sm, event)
return list(enabled.values())

def _activate(self, trigger_data: TriggerData, transition: "Transition"):
event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs
Expand Down
70 changes: 62 additions & 8 deletions statemachine/spec_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import operator
import re
from functools import reduce
from inspect import isawaitable
from typing import Callable

replacements = {"!": "not ", "^": " and ", "v": " or "}
Expand Down Expand Up @@ -33,8 +34,15 @@ def match_func(match):


def custom_not(predicate: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return not predicate(*args, **kwargs)
def decorated(*args, **kwargs):
result = predicate(*args, **kwargs)
if isawaitable(result):

async def _negate():
return not await result

return _negate()
return not result

decorated.__name__ = f"not({predicate.__name__})"
unique_key = getattr(predicate, "unique_key", "")
Expand All @@ -43,17 +51,53 @@ def decorated(*args, **kwargs) -> bool:


def custom_and(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return]
def decorated(*args, **kwargs):
left_result = left(*args, **kwargs)
if isawaitable(left_result):

async def _async_and():
lr = await left_result
if not lr:
return lr
rr = right(*args, **kwargs)
if isawaitable(rr):
return await rr
return rr

return _async_and()
if not left_result:
return left_result
right_result = right(*args, **kwargs)
if isawaitable(right_result):
return right_result
return right_result

decorated.__name__ = f"({left.__name__} and {right.__name__})"
decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined]
return decorated


def custom_or(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return]
def decorated(*args, **kwargs):
left_result = left(*args, **kwargs)
if isawaitable(left_result):

async def _async_or():
lr = await left_result
if lr:
return lr
rr = right(*args, **kwargs)
if isawaitable(rr):
return await rr
return rr

return _async_or()
if left_result:
return left_result
right_result = right(*args, **kwargs)
if isawaitable(right_result):
return right_result
return right_result

decorated.__name__ = f"({left.__name__} or {right.__name__})"
decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined]
Expand All @@ -73,8 +117,18 @@ def build_custom_operator(operator) -> Callable:
operator_repr = comparison_repr[operator]

def custom_comparator(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return bool(operator(left(*args, **kwargs), right(*args, **kwargs)))
def decorated(*args, **kwargs):
left_result = left(*args, **kwargs)
right_result = right(*args, **kwargs)
if isawaitable(left_result) or isawaitable(right_result):

async def _async_compare():
lr = (await left_result) if isawaitable(left_result) else left_result
rr = (await right_result) if isawaitable(right_result) else right_result
return bool(operator(lr, rr))

return _async_compare()
return bool(operator(left_result, right_result))

decorated.__name__ = f"({left.__name__} {operator_repr} {right.__name__})"
decorated.unique_key = _unique_key(left, right, operator_repr) # type: ignore[attr-defined]
Expand Down
19 changes: 19 additions & 0 deletions statemachine/statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __setstate__(self, state):
self._register_callbacks([])
self.add_listener(*listeners.keys())
self._engine = self._get_engine(rtc)
self._engine.start()

def _get_initial_state(self):
initial_state_value = self.start_value if self.start_value else self.initial_state.value
Expand Down Expand Up @@ -294,6 +295,24 @@ def allowed_events(self) -> "List[Event]":
"""List of the current allowed events."""
return [getattr(self, event) for event in self.current_state.transitions.unique_events]

def enabled_events(self, *args, **kwargs):
"""List of the current enabled events, considering guard conditions.

An event is **enabled** if at least one of its transitions from the current
state has all ``cond``/``unless`` guards satisfied.

Args:
*args: Positional arguments forwarded to condition callbacks.
**kwargs: Keyword arguments forwarded to condition callbacks.

Returns:
A list of enabled :ref:`Event` instances.
"""
result = self._engine.enabled_events(*args, **kwargs)
if not isawaitable(result):
return result
return run_async_from_sync(result)

def _put_nonblocking(self, trigger_data: TriggerData):
"""Put the trigger on the queue without blocking the caller."""
self._engine.put(trigger_data)
Expand Down
2 changes: 1 addition & 1 deletion tests/examples/user_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class UserStatusMachine(StateMachine):
def on_signup(self, token: str):
if token == "":
raise ValueError("Token is required")
self.model.verified = True
self.model.verified = True # type: ignore[union-attr]


class UserExperienceMachine(StateMachine):
Expand Down
Loading
Loading