diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a062e702..ecbebb33 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: pass_filenames: false - id: pytest name: Pytest - entry: uv run pytest -n auto + entry: uv run pytest -n auto --cov-fail-under=100 types: [python] language: system pass_filenames: false diff --git a/docs/index.md b/docs/index.md index 80cf75b0..4f3a9988 100644 --- a/docs/index.md +++ b/docs/index.md @@ -19,6 +19,7 @@ integrations diagram weighted_transitions processing_model +invoke statecharts api auto_examples/index diff --git a/docs/invoke.md b/docs/invoke.md new file mode 100644 index 00000000..8aa98c8c --- /dev/null +++ b/docs/invoke.md @@ -0,0 +1,443 @@ +(invoke)= +# Invoke + +Invoke lets a state spawn external work — API calls, file I/O, child state machines — +when it is entered, and automatically cancel that work when the state is exited. This +follows the [SCXML `` semantics](https://www.w3.org/TR/scxml/#invoke) and is +similar to the **do activity** (`do/`) concept in UML Statecharts — an ongoing behavior +that runs for the duration of a state and is cancelled when the state is exited. + +## Execution model + +Invoke handlers run **outside** the main state machine processing loop: + +- **Sync engine**: each invoke handler runs in a **daemon thread**. +- **Async engine**: each invoke handler runs in a **thread executor** + (`loop.run_in_executor`), wrapped in an `asyncio.Task`. The executor is used because + invoke handlers are expected to perform blocking I/O (network calls, file access, + subprocess communication) that would freeze the event loop if run directly. + +When a handler completes, a `done.invoke..` event is automatically sent back +to the machine. If the handler raises an exception, an `error.execution` event is sent +instead. If the owning state is exited before the handler finishes, the invocation is +**cancelled** — `ctx.cancelled` is set and `on_cancel()` is called on `IInvoke` handlers. + +## Callback group + +Invoke is a first-class callback group, just like `enter` and `exit`. This means +convention naming (`on_invoke_`), decorators (`@state.invoke`), inline callables, +and the full {ref}`SignatureAdapter ` dependency injection all work out of the box. + +## Quick start + +The simplest invoke is a plain callable passed to the `invoke` parameter. Here we read a +config file in a background thread and transition to `ready` when the data is available: + +```py +>>> import json +>>> import tempfile +>>> import time +>>> from pathlib import Path +>>> from statemachine import State, StateChart + +>>> config_file = Path(tempfile.mktemp(suffix=".json")) +>>> _ = config_file.write_text('{"db_host": "localhost", "db_port": 5432}') + +>>> def load_config(): +... return json.loads(config_file.read_text()) + +>>> class ConfigLoader(StateChart): +... loading = State(initial=True, invoke=load_config) +... ready = State(final=True) +... done_invoke_loading = loading.to(ready) +... +... def on_enter_ready(self, data=None, **kwargs): +... self.config = data + +>>> sm = ConfigLoader() +>>> time.sleep(0.2) + +>>> "ready" in sm.configuration_values +True +>>> sm.config +{'db_host': 'localhost', 'db_port': 5432} + +>>> config_file.unlink() + +``` + +When `loading` is entered, `load_config()` runs in a background thread. When it returns, +a `done.invoke.loading.` event is automatically sent to the machine, triggering +the `done_invoke_loading` transition. The return value is available as the `data` +keyword argument in callbacks on the target state. + +## Naming conventions + +Like `on_enter_` and `on_exit_`, invoke supports naming conventions: + +- `on_invoke_state` — generic, called for every state with invoke +- `on_invoke_` — specific to a state + +```py +>>> config_file = Path(tempfile.mktemp(suffix=".json")) +>>> _ = config_file.write_text('{"feature_flags": ["dark_mode", "beta_api"]}') + +>>> class FeatureLoader(StateChart): +... loading = State(initial=True) +... ready = State(final=True) +... done_invoke_loading = loading.to(ready) +... +... def on_invoke_loading(self, **kwargs): +... """Naming convention: on_invoke_.""" +... return json.loads(config_file.read_text()) +... +... def on_enter_ready(self, data=None, **kwargs): +... self.features = data + +>>> sm = FeatureLoader() +>>> time.sleep(0.2) + +>>> "ready" in sm.configuration_values +True +>>> sm.features["feature_flags"] +['dark_mode', 'beta_api'] + +>>> config_file.unlink() + +``` + +## Decorator syntax + +Use the `@state.invoke` decorator: + +```py +>>> config_file = Path(tempfile.mktemp(suffix=".txt")) +>>> _ = config_file.write_text("line 1\nline 2\nline 3\n") + +>>> class LineCounter(StateChart): +... counting = State(initial=True) +... done = State(final=True) +... done_invoke_counting = counting.to(done) +... +... @counting.invoke +... def count_lines(self, **kwargs): +... text = config_file.read_text() +... return len(text.splitlines()) +... +... def on_enter_done(self, data=None, **kwargs): +... self.total_lines = data + +>>> sm = LineCounter() +>>> time.sleep(0.2) + +>>> "done" in sm.configuration_values +True +>>> sm.total_lines +3 + +>>> config_file.unlink() + +``` + +## `done.invoke` transitions + +Use the `done_invoke_` naming convention to declare transitions that fire when +an invoke handler completes: + +```py +>>> config_file = Path(tempfile.mktemp(suffix=".json")) +>>> _ = config_file.write_text('{"version": "3.0.0"}') + +>>> class VersionChecker(StateChart): +... checking = State(initial=True, invoke=lambda: json.loads(config_file.read_text())) +... checked = State(final=True) +... done_invoke_checking = checking.to(checked) +... +... def on_enter_checked(self, data=None, **kwargs): +... self.version = data["version"] + +>>> sm = VersionChecker() +>>> time.sleep(0.2) + +>>> "checked" in sm.configuration_values +True +>>> sm.version +'3.0.0' + +>>> config_file.unlink() + +``` + +The `done_invoke_` prefix maps to the `done.invoke.` event family, +matching any invoke completion for that state regardless of the specific invoke ID. + +## IInvoke protocol + +For advanced use cases, implement the `IInvoke` protocol. This gives you access to +the `InvokeContext` — with the invoke ID, cancellation signal, event kwargs, and a +reference to the parent machine: + +```py +>>> from statemachine.invoke import IInvoke, InvokeContext + +>>> class FileReader: +... """Reads a file and returns its content. Supports cancellation.""" +... def run(self, ctx: InvokeContext): +... # ctx.invokeid — unique ID for this invocation +... # ctx.state_id — the state that triggered invoke +... # ctx.cancelled — threading.Event, set when state exits +... # ctx.send — send events to parent machine +... # ctx.machine — reference to parent machine +... # ctx.kwargs — keyword arguments from the triggering event +... path = ctx.machine.file_path +... return Path(path).read_text() +... +... def on_cancel(self): +... pass # cleanup resources if needed + +>>> isinstance(FileReader(), IInvoke) +True + +``` + +Pass a class to the `invoke` parameter — each state machine instance gets a fresh handler: + +```py +>>> config_file = Path(tempfile.mktemp(suffix=".csv")) +>>> _ = config_file.write_text("name,age\nAlice,30\nBob,25\n") + +>>> class CSVLoader(StateChart): +... loading = State(initial=True, invoke=FileReader) +... ready = State(final=True) +... done_invoke_loading = loading.to(ready) +... +... def __init__(self, file_path, **kwargs): +... self.file_path = file_path +... super().__init__(**kwargs) +... +... def on_enter_ready(self, data=None, **kwargs): +... self.content = data + +>>> sm = CSVLoader(file_path=str(config_file)) +>>> time.sleep(0.2) + +>>> "ready" in sm.configuration_values +True +>>> sm.content +'name,age\nAlice,30\nBob,25\n' + +>>> config_file.unlink() + +``` + +## Cancellation + +When a state with active invoke handlers is exited: + +1. `ctx.cancelled` is set (a `threading.Event`) — handlers should poll this +2. `on_cancel()` is called on `IInvoke` handlers (if defined) +3. For the async engine, the asyncio Task is cancelled + +Events from cancelled invocations are silently ignored. + +```py +>>> cancel_called = [] + +>>> class SlowFileReader: +... def run(self, ctx: InvokeContext): +... ctx.cancelled.wait(timeout=5.0) +... +... def on_cancel(self): +... cancel_called.append(True) + +>>> class CancelMachine(StateChart): +... loading = State(initial=True, invoke=SlowFileReader) +... stopped = State(final=True) +... cancel = loading.to(stopped) + +>>> sm = CancelMachine() +>>> time.sleep(0.05) +>>> sm.send("cancel") +>>> time.sleep(0.05) +>>> cancel_called +[True] + +``` + +## Event data propagation + +When a state with invoke handlers is entered via an event, the keyword arguments from +that event are forwarded to the invoke handlers. Plain callables receive them via +{ref}`SignatureAdapter ` dependency injection; `IInvoke` handlers receive them +via `ctx.kwargs`: + +```py +>>> config_file = Path(tempfile.mktemp(suffix=".json")) +>>> _ = config_file.write_text('{"debug": true}') + +>>> class ConfigByName(StateChart): +... idle = State(initial=True) +... loading = State() +... ready = State(final=True) +... start = idle.to(loading) +... done_invoke_loading = loading.to(ready) +... +... def on_invoke_loading(self, file_name=None, **kwargs): +... """file_name comes from send('start', file_name=...).""" +... return json.loads(Path(file_name).read_text()) +... +... def on_enter_ready(self, data=None, **kwargs): +... self.config = data + +>>> sm = ConfigByName() +>>> sm.send("start", file_name=str(config_file)) +>>> time.sleep(0.2) + +>>> "ready" in sm.configuration_values +True +>>> sm.config +{'debug': True} + +>>> config_file.unlink() + +``` + +For initial states (entered automatically, not via an event), `kwargs` is empty. + +## Error handling + +If an invoke handler raises an exception, `error.execution` is sent to the machine's +internal queue (when `error_on_execution=True`, the default for `StateChart`). You can +handle it with a transition for `error.execution`: + +```py +>>> class MissingFileLoader(StateChart): +... loading = State( +... initial=True, +... invoke=lambda: Path("/tmp/nonexistent_file_12345.json").read_text(), +... ) +... error_state = State(final=True) +... error_execution = loading.to(error_state) +... +... def on_enter_error_state(self, error=None, **kwargs): +... self.error_type = type(error).__name__ + +>>> sm = MissingFileLoader() +>>> time.sleep(0.2) + +>>> "error_state" in sm.configuration_values +True +>>> sm.error_type +'FileNotFoundError' + +``` + +## Multiple invokes + +### Independent invokes (one event each) + +Pass a list to run multiple handlers concurrently. Each handler gets its own +`done.invoke..` event — the **first** one to complete triggers the +`done_invoke_` transition (the remaining events are ignored if the state +was already exited): + +```py +>>> file_a = Path(tempfile.mktemp(suffix=".txt")) +>>> file_b = Path(tempfile.mktemp(suffix=".txt")) +>>> _ = file_a.write_text("hello") +>>> _ = file_b.write_text("world") + +>>> class MultiLoader(StateChart): +... loading = State( +... initial=True, +... invoke=[lambda: file_a.read_text(), lambda: file_b.read_text()], +... ) +... ready = State(final=True) +... done_invoke_loading = loading.to(ready) + +>>> sm = MultiLoader() +>>> time.sleep(0.2) + +>>> "ready" in sm.configuration_values +True + +>>> file_a.unlink() +>>> file_b.unlink() + +``` + +This follows the [SCXML spec](https://www.w3.org/TR/scxml/#invoke): each `` +is independent and generates its own completion event. Use this when you only need +**any one** of the handlers to complete, or when each invoke is handled by a +separate transition. + +### Grouped invokes (wait for all) + +Use {func}`~statemachine.invoke.invoke_group` to run multiple callables concurrently +and wait for **all** of them to complete before sending a single `done.invoke` event. +The `data` is a list of results in the same order as the input callables: + +```py +>>> from statemachine.invoke import invoke_group + +>>> file_a = Path(tempfile.mktemp(suffix=".txt")) +>>> file_b = Path(tempfile.mktemp(suffix=".txt")) +>>> _ = file_a.write_text("hello") +>>> _ = file_b.write_text("world") + +>>> class BatchLoader(StateChart): +... loading = State( +... initial=True, +... invoke=invoke_group( +... lambda: file_a.read_text(), +... lambda: file_b.read_text(), +... ), +... ) +... ready = State(final=True) +... done_invoke_loading = loading.to(ready) +... +... def on_enter_ready(self, data=None, **kwargs): +... self.results = data + +>>> sm = BatchLoader() +>>> time.sleep(0.2) + +>>> "ready" in sm.configuration_values +True +>>> sm.results +['hello', 'world'] + +>>> file_a.unlink() +>>> file_b.unlink() + +``` + +If any callable raises, the remaining ones are cancelled and an `error.execution` +event is sent. If the owning state is exited before all callables finish, the group +is cancelled. + +## Child state machines + +Pass a `StateChart` subclass to spawn a child machine: + +```python +from statemachine import State, StateChart + +class ChildMachine(StateChart): + start = State(initial=True) + end = State(final=True) + go = start.to(end) + + def on_enter_start(self, **kwargs): + self.send("go") + +class ParentMachine(StateChart): + loading = State(initial=True, invoke=ChildMachine) + ready = State(final=True) + done_invoke_loading = loading.to(ready) +``` + +The child machine is instantiated and run when the parent's `loading` state is entered. +When the child terminates (reaches a final state), a `done.invoke` event is sent to the +parent, triggering the `done_invoke_loading` transition. See +`tests/test_invoke.py::TestInvokeStateChartChild` for a working example. diff --git a/docs/releases/3.0.0.md b/docs/releases/3.0.0.md index 54bc94fc..256f0685 100644 --- a/docs/releases/3.0.0.md +++ b/docs/releases/3.0.0.md @@ -15,11 +15,58 @@ Statecharts are a powerful extension to state machines, in a way to organize com The support for statecharts in this release follows the [SCXML specification](https://www.w3.org/TR/scxml/)*, which is a W3C standard for statecharts notation. Adhering as much as possible to this specification ensures compatibility with other tools and platforms that also implement SCXML, but more important, sets a standard on the expected behaviour that the library should assume on various edge cases, enabling easier integration and interoperability in complex systems. -To verify the standard adoption, now the automated tests suite includes several `.scxml` testcases provided by the W3C group. Many thanks for this amazing work! Some of the tests are still failing, and some of the tags are still not implemented like `` , in such cases, we've added an `xfail` mark by including a `test.scxml.md` markdown file with details of the execution output. +To verify the standard adoption, now the automated tests suite includes several `.scxml` testcases provided by the W3C group. Many thanks for this amazing work! Some of the tests are still failing, in such cases, we've added an `xfail` mark by including a `test.scxml.md` markdown file with details of the execution output. While these are exiting news for the library and our community, it also introduces several backwards incompatible changes. Due to the major version release, the new behaviour is assumed by default, but we put a lot of effort to minimize the changes needed in your codebase, and also introduced a few configuration options that you can enable to restore the old behaviour when possible. The following sections navigate to the new features and includes a migration guide. +### Invoke + +States can now spawn external work when entered and cancel it when exited, following the +SCXML `` semantics (similar to UML's `do/` activity). Handlers run in a daemon +thread (sync engine) or a thread executor wrapped in an asyncio Task (async engine). +Invoke is a first-class callback group — convention naming (`on_invoke_`), +decorators (`@state.invoke`), inline callables, and the full `SignatureAdapter` dependency +injection all work out of the box. + +```py +>>> from statemachine import State, StateChart + +>>> class FetchMachine(StateChart): +... loading = State(initial=True, invoke=lambda: {"status": "ok"}) +... ready = State(final=True) +... done_invoke_loading = loading.to(ready) + +>>> sm = FetchMachine() +>>> import time; time.sleep(0.1) # wait for background invoke to complete +>>> "ready" in sm.configuration_values +True + +``` + +Use {func}`~statemachine.invoke.invoke_group` to run multiple callables concurrently +and wait for all results: + +```py +>>> from statemachine.invoke import invoke_group + +>>> class BatchFetch(StateChart): +... loading = State(initial=True, invoke=invoke_group(lambda: "a", lambda: "b")) +... ready = State(final=True) +... done_invoke_loading = loading.to(ready) +... +... def on_enter_ready(self, data=None, **kwargs): +... self.results = data + +>>> sm = BatchFetch() +>>> import time; time.sleep(0.2) +>>> sm.results +['a', 'b'] + +``` + +See {ref}`invoke` for full documentation. + ### Compound states **Compound states** have inner child states. Use `State.Compound` to define them diff --git a/statemachine/callbacks.py b/statemachine/callbacks.py index 22965fae..3da2d9a1 100644 --- a/statemachine/callbacks.py +++ b/statemachine/callbacks.py @@ -46,6 +46,7 @@ class CallbackGroup(IntEnum): PREPARE = auto() ENTER = auto() EXIT = auto() + INVOKE = auto() VALIDATOR = auto() BEFORE = auto() ON = auto() @@ -362,6 +363,20 @@ def all(self, *args, on_error: "Callable[[Exception], None] | None" = None, **kw raise return True + def visit(self, visitor_fn, *args, **kwargs): + """Like call() but delegates execution to visitor_fn for each matching callback.""" + for callback in self: + if callback.condition(*args, **kwargs): + visitor_fn(callback, *args, **kwargs) + + async def async_visit(self, visitor_fn, *args, **kwargs): + """Async variant of visit().""" + for callback in self: + if callback.condition(*args, **kwargs): + result = visitor_fn(callback, *args, **kwargs) + if isawaitable(result): + await result + class CallbacksRegistry: def __init__(self) -> None: @@ -371,6 +386,9 @@ def __init__(self) -> None: def __getitem__(self, key: str) -> CallbacksExecutor: return self._registry[key] + def __contains__(self, key: str) -> bool: + return key in self._registry + def check(self, specs: CallbackSpecList): for meta in specs: if meta.is_convention: @@ -440,6 +458,16 @@ async def async_all( return True return await self._registry[key].async_all(*args, on_error=on_error, **kwargs) + def visit(self, key: str, visitor_fn, *args, **kwargs): + if key not in self._registry: + return + self._registry[key].visit(visitor_fn, *args, **kwargs) + + async def async_visit(self, key: str, visitor_fn, *args, **kwargs): + if key not in self._registry: + return + await self._registry[key].async_visit(visitor_fn, *args, **kwargs) + def str(self, key: str) -> str: if key not in self._registry: return "" diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 40322794..e1049d46 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -171,6 +171,10 @@ async def _exit_states( # type: ignore[override] on_error = self._on_error_handler() for info in ordered_states: + # Cancel invocations for this state before executing exit handlers. + if info.state is not None: # pragma: no branch + self._invoke_manager.cancel_for_state(info.state) + args, kwargs = await self._get_args_kwargs(info.transition, trigger_data) if info.state is not None: # pragma: no branch @@ -242,6 +246,10 @@ async def _enter_states( # noqa: C901 new_configuration=new_configuration, ) + # Mark state for invocation if it has invoke callbacks registered + if target.invoke.key in self.sm._callbacks: + self._invoke_manager.mark_for_invoke(target, trigger_data.kwargs) + # Handle final states if target.final: self._handle_final_state(target, on_entry_result) @@ -358,6 +366,9 @@ async def processing_loop( # noqa: C901 took_events = True await self._run_microstep(enabled_transitions, internal_event) + # Spawn invoke handlers for states entered during this macrostep. + await self._invoke_manager.spawn_pending_async() + # Phase 2: remaining internal events while not self.internal_queue.is_empty(): # pragma: no cover internal_event = self.internal_queue.pop() diff --git a/statemachine/engines/base.py b/statemachine/engines/base.py index c55f51a8..012797bf 100644 --- a/statemachine/engines/base.py +++ b/statemachine/engines/base.py @@ -20,6 +20,7 @@ from ..event_data import TriggerData from ..exceptions import InvalidDefinition from ..exceptions import TransitionNotAllowed +from ..invoke import InvokeManager from ..orderedset import OrderedSet from ..state import HistoryState from ..state import State @@ -94,6 +95,7 @@ def __init__(self, sm: "StateChart"): self.running = True self._processing = Lock() self._cache: Dict = {} # Cache for _get_args_kwargs results + self._invoke_manager = InvokeManager(self) def empty(self): # pragma: no cover return self.external_queue.is_empty() @@ -483,6 +485,10 @@ def _exit_states( on_error = self._on_error_handler() for info in ordered_states: + # Cancel invocations for this state before executing exit handlers. + if info.state is not None: # pragma: no branch + self._invoke_manager.cancel_for_state(info.state) + args, kwargs = self._get_args_kwargs(info.transition, trigger_data) # Execute `onexit` handlers — same per-block error isolation as onentry. @@ -645,6 +651,10 @@ def _enter_states( # noqa: C901 new_configuration=new_configuration, ) + # Mark state for invocation if it has invoke callbacks registered + if target.invoke.key in self.sm._callbacks: + self._invoke_manager.mark_for_invoke(target, trigger_data.kwargs) + # Handle final states if target.final: self._handle_final_state(target, on_entry_result) diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index ce71f807..f1cc52f3 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -107,11 +107,8 @@ def processing_loop(self, caller_future=None): # noqa: C901 took_events = True self._run_microstep(enabled_transitions, internal_event) - # TODO: Invoke platform-specific logic - # for state in sorted(self.states_to_invoke, key=self.entry_order): - # for inv in sorted(state.invoke, key=self.document_order): - # self.invoke(inv) - # self.states_to_invoke.clear() + # Spawn invoke handlers for states entered during this macrostep. + self._invoke_manager.spawn_pending_sync() # Process remaining internal events before external events. # Note: the macrostep loop above already drains the internal queue, diff --git a/statemachine/factory.py b/statemachine/factory.py index 37dd20d2..b7f71ba3 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -274,6 +274,9 @@ def add_from_attributes(cls, attrs): # noqa: C901 event_id = key if key.startswith("error_"): event_id = f"{key} {key.replace('_', '.')}" + elif key.startswith("done_invoke_"): + suffix = key[len("done_invoke_") :] + event_id = f"{key} done.invoke.{suffix}" elif key.startswith("done_state_"): suffix = key[len("done_state_") :] event_id = f"{key} done.state.{suffix}" @@ -283,6 +286,9 @@ def add_from_attributes(cls, attrs): # noqa: C901 event_id = value.id elif key.startswith("error_"): event_id = f"{key} {key.replace('_', '.')}" + elif key.startswith("done_invoke_"): + suffix = key[len("done_invoke_") :] + event_id = f"{key} done.invoke.{suffix}" elif key.startswith("done_state_"): suffix = key[len("done_state_") :] event_id = f"{key} done.state.{suffix}" diff --git a/statemachine/invoke.py b/statemachine/invoke.py new file mode 100644 index 00000000..3ac34fb8 --- /dev/null +++ b/statemachine/invoke.py @@ -0,0 +1,461 @@ +"""Invoke support for StateCharts. + +Invoke lets a state spawn external work (API calls, file I/O, child state machines) +when entered, and cancel it when exited. Invoke is modelled as a callback group +(``CallbackGroup.INVOKE``) so that convention naming (``on_invoke_``), +decorators (``@state.invoke``), and inline callables all work out of the box. +""" + +import asyncio +import logging +import threading +import uuid +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Tuple +from typing import runtime_checkable + +try: + from typing import Protocol +except ImportError: # pragma: no cover + from typing_extensions import Protocol # type: ignore[assignment] + +if TYPE_CHECKING: + from .callbacks import CallbackWrapper + from .engines.base import BaseEngine + from .state import State + from .statemachine import StateChart + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class IInvoke(Protocol): + """Protocol for advanced invoke handlers. + + Implement ``run(ctx)`` to execute work when a state is entered. + Optionally implement ``on_cancel()`` for cleanup when the state is exited. + """ + + def run(self, ctx: "InvokeContext") -> Any: ... # pragma: no branch + + +class _InvokeCallableWrapper: + """Wraps an IInvoke class/instance or StateChart class for the callback system. + + The callback resolution system expects plain callables or strings. This wrapper + makes IInvoke classes, IInvoke instances, and StateChart classes look like regular + callables while preserving the original object for the InvokeManager to detect. + + When ``_invoke_handler`` is a **class**, ``run()`` instantiates it on each call + so that each StateChart instance gets its own handler — avoiding shared mutable + state between machines. + """ + + def __init__(self, handler: Any): + self._invoke_handler = handler + self._is_class = isinstance(handler, type) + self._instance: Any = None + name = getattr(handler, "__name__", type(handler).__name__) + self.__name__ = name + self.__qualname__ = getattr(handler, "__qualname__", name) + # The callback system inspects __code__ for caching (signature.py) + self.__code__ = self.__call__.__code__ + + def __call__(self, **kwargs): + return self._invoke_handler + + def run(self, ctx: "InvokeContext") -> Any: + """Create a fresh instance (if class) and delegate to its ``run()``.""" + handler = self._invoke_handler + if self._is_class: + handler = handler() + self._instance = handler + return handler.run(ctx) + + def on_cancel(self): + """Delegate to the live instance's ``on_cancel()`` if available.""" + if self._instance is not None: + target = self._instance + elif self._is_class: + return # Handler hasn't been instantiated yet — nothing to cancel + else: + target = self._invoke_handler + if hasattr(target, "on_cancel"): + target.on_cancel() + + +def normalize_invoke_callbacks(invoke: Any) -> Any: + """Wrap IInvoke instances and StateChart classes so the callback system can handle them. + + Plain callables and strings pass through unchanged. + """ + if invoke is None: + return None + + from .utils import ensure_iterable + + items = ensure_iterable(invoke) + result = [] + for item in items: + if _needs_wrapping(item): + result.append(_InvokeCallableWrapper(item)) + else: + result.append(item) + return result + + +def _needs_wrapping(item: Any) -> bool: + """Check if an item needs wrapping for the callback system.""" + if isinstance(item, str): + return False + if isinstance(item, _InvokeCallableWrapper): + return False + # IInvoke instance (already instantiated — kept for advanced use / SCXML adapter) + if isinstance(item, IInvoke): + return True + if isinstance(item, type): + from .statemachine import StateChart + + # StateChart subclass → child machine invoker + if issubclass(item, StateChart): + return True + return False + + +@dataclass +class InvokeContext: + """Context passed to invoke handlers.""" + + invokeid: str + """Unique identifier for this invocation.""" + + state_id: str + """The id of the state that triggered this invocation.""" + + send: "Callable[..., None]" + """``send(event, **data)`` — enqueue an event on the parent machine's external queue.""" + + machine: "StateChart" + """Reference to the parent state machine.""" + + cancelled: threading.Event = field(default_factory=threading.Event) + """Set when the owning state is exited; handlers should check this to stop early.""" + + kwargs: dict = field(default_factory=dict) + """Keyword arguments from the event that triggered the state entry.""" + + +@dataclass +class Invocation: + """Tracks a single active invocation.""" + + invokeid: str + state_id: str + ctx: InvokeContext + thread: "threading.Thread | None" = None + task: "asyncio.Task[Any] | None" = None + terminated: bool = False + _handler: Any = None + + +class StateChartInvoker: + """Wraps a :class:`StateChart` subclass as an :class:`IInvoke` handler. + + When ``run(ctx)`` is called, it instantiates and runs the child machine + synchronously. The child machine's final result (if any) becomes the + return value. + """ + + def __init__(self, child_class: "type[StateChart]"): + self._child_class = child_class + self._child: "StateChart | None" = None + + def run(self, _ctx: "InvokeContext") -> Any: + self._child = self._child_class() + # The child machine starts automatically in its constructor. + # If it has final states, it will terminate on its own. + return None + + def on_cancel(self): + # Child machine cleanup — currently a no-op since sync machines + # run to completion in the constructor. + self._child = None + + +class InvokeGroup: + """Runs multiple callables concurrently and returns their results as a list. + + All callables are submitted to a :class:`~concurrent.futures.ThreadPoolExecutor`. + The handler blocks until every callable completes, then returns a list of results + in the same order as the input callables. + + If the owning state is exited before all callables finish, the remaining futures + are cancelled. If any callable raises, the remaining futures are cancelled and + the exception propagates (which causes an ``error.execution`` event). + """ + + def __init__(self, callables: "List[Callable[..., Any]]"): + self._callables = list(callables) + self._futures: "List[Future[Any]]" = [] + self._executor: "ThreadPoolExecutor | None" = None + + def run(self, ctx: "InvokeContext") -> "List[Any]": + results: "List[Any]" = [None] * len(self._callables) + self._executor = ThreadPoolExecutor(max_workers=len(self._callables)) + try: + self._futures = [self._executor.submit(fn) for fn in self._callables] + for idx, future in enumerate(self._futures): + # Poll so we can react to cancellation promptly. + while not future.done(): + if ctx.cancelled.is_set(): + self._cancel_remaining() + return [] + ctx.cancelled.wait(timeout=0.05) + results[idx] = future.result() # re-raises if the callable failed + except Exception: + self._cancel_remaining() + raise + finally: + self._executor.shutdown(wait=False) + return results + + def on_cancel(self): + self._cancel_remaining() + if self._executor is not None: + self._executor.shutdown(wait=False) + + def _cancel_remaining(self): + for future in self._futures: + if not future.done(): + future.cancel() + + +def invoke_group(*callables: "Callable[..., Any]") -> InvokeGroup: + """Group multiple callables into a single invoke that runs them concurrently. + + Returns an :class:`InvokeGroup` instance (implements :class:`IInvoke`). + When all callables complete, a single ``done.invoke`` event is sent with + ``data`` set to a list of results in the same order as the input callables. + + Example:: + + loading = State(initial=True, invoke=invoke_group(fetch_users, fetch_config)) + + def on_enter_ready(self, data=None, **kwargs): + users, config = data + """ + return InvokeGroup(list(callables)) + + +class InvokeManager: + """Manages the lifecycle of invoke handlers for a state machine engine. + + Tracks which states need invocation after entry, spawns handlers + (in threads for sync, as tasks for async), and cancels them on exit. + """ + + def __init__(self, engine: "BaseEngine"): + self._engine = engine + self._active: Dict[str, Invocation] = {} + self._pending: "List[Tuple[State, dict]]" = [] + + @property + def sm(self) -> "StateChart": + return self._engine.sm + + # --- Engine hooks --- + + def mark_for_invoke(self, state: "State", event_kwargs: "dict | None" = None): + """Called by ``_enter_states()`` after entering a state with invoke callbacks. + + Args: + state: The state that was entered. + event_kwargs: Keyword arguments from the event that triggered the + state entry. These are forwarded to invoke handlers via + dependency injection (plain callables) and ``InvokeContext.kwargs`` + (IInvoke handlers). + """ + self._pending.append((state, event_kwargs or {})) + + def cancel_for_state(self, state: "State"): + """Called by ``_exit_states()`` before exiting a state.""" + for inv_id, inv in list(self._active.items()): + if inv.state_id == state.id and not inv.terminated: + self._cancel(inv_id) + self._pending = [(s, kw) for s, kw in self._pending if s is not state] + + def cancel_all(self): + """Cancel all active invocations.""" + for inv_id in list(self._active.keys()): + self._cancel(inv_id) + + # --- Sync spawning --- + + def spawn_pending_sync(self): + """Spawn invoke handlers for all states marked for invocation (sync engine).""" + pending = sorted(self._pending, key=lambda p: p[0].document_order) + self._pending.clear() + for state, event_kwargs in pending: + self.sm._callbacks.visit( + state.invoke.key, + self._spawn_one_sync, + state=state, + event_kwargs=event_kwargs, + ) + + def _spawn_one_sync(self, callback: "CallbackWrapper", **kwargs): + state: "State" = kwargs["state"] + event_kwargs: dict = kwargs.get("event_kwargs", {}) + ctx = self._make_context(state, event_kwargs) + invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx) + + # Use meta.func to find the original (unwrapped) handler; the callback + # system wraps everything in a signature_adapter closure. + handler = self._resolve_handler(callback.meta.func) + invocation._handler = handler + self._active[ctx.invokeid] = invocation + + thread = threading.Thread( + target=self._run_sync_handler, + args=(callback, handler, ctx, invocation), + daemon=True, + ) + invocation.thread = thread + thread.start() + + def _run_sync_handler( + self, + callback: "CallbackWrapper", + handler: "Any | None", + ctx: InvokeContext, + invocation: Invocation, + ): + try: + if handler is not None: + result = handler.run(ctx) + else: + result = callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs) + if not ctx.cancelled.is_set(): + self.sm.send( + f"done.invoke.{ctx.invokeid}", + data=result, + internal=True, + ) + except Exception as e: + if not ctx.cancelled.is_set(): + self.sm.send("error.execution", error=e, internal=True) + finally: + invocation.terminated = True + + # --- Async spawning --- + + async def spawn_pending_async(self): + """Spawn invoke handlers for all states marked for invocation (async engine).""" + pending = sorted(self._pending, key=lambda p: p[0].document_order) + self._pending.clear() + for state, event_kwargs in pending: + await self.sm._callbacks.async_visit( + state.invoke.key, + self._spawn_one_async, + state=state, + event_kwargs=event_kwargs, + ) + + def _spawn_one_async(self, callback: "CallbackWrapper", **kwargs): + state: "State" = kwargs["state"] + event_kwargs: dict = kwargs.get("event_kwargs", {}) + ctx = self._make_context(state, event_kwargs) + invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx) + + handler = self._resolve_handler(callback.meta.func) + invocation._handler = handler + self._active[ctx.invokeid] = invocation + + loop = asyncio.get_running_loop() + task = loop.create_task(self._run_async_handler(callback, handler, ctx, invocation)) + invocation.task = task + + async def _run_async_handler( + self, + callback: "CallbackWrapper", + handler: "Any | None", + ctx: InvokeContext, + invocation: Invocation, + ): + try: + loop = asyncio.get_running_loop() + if handler is not None: + # Run handler.run(ctx) in a thread executor so blocking I/O + # doesn't freeze the event loop. + result = await loop.run_in_executor(None, handler.run, ctx) + else: + result = await loop.run_in_executor( + None, lambda: callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs) + ) + if not ctx.cancelled.is_set(): + self.sm.send( + f"done.invoke.{ctx.invokeid}", + data=result, + internal=True, + ) + except asyncio.CancelledError: + # Intentionally swallowed: the owning state was exited, so this + # invocation was cancelled — there is nothing to propagate. + return + except Exception as e: + if not ctx.cancelled.is_set(): + self.sm.send("error.execution", error=e, internal=True) + finally: + invocation.terminated = True + + # --- Cancel --- + + def _cancel(self, invokeid: str): + invocation = self._active.get(invokeid) + if not invocation or invocation.terminated: + return + invocation.ctx.cancelled.set() + handler = invocation._handler + if handler is not None and hasattr(handler, "on_cancel"): + try: + handler.on_cancel() + except Exception: + logger.debug("Error in on_cancel for %s", invokeid, exc_info=True) + if invocation.task is not None and not invocation.task.done(): + invocation.task.cancel() + + # --- Helpers --- + + def _make_context(self, state: "State", event_kwargs: "dict | None" = None) -> InvokeContext: + invokeid = f"{state.id}.{uuid.uuid4().hex[:8]}" + return InvokeContext( + invokeid=invokeid, + state_id=state.id, + send=self.sm.send, + machine=self.sm, + kwargs=event_kwargs or {}, + ) + + @staticmethod + def _resolve_handler(underlying: Any) -> "Any | None": + """Determine the handler type from the resolved callable.""" + from .statemachine import StateChart + + if isinstance(underlying, _InvokeCallableWrapper): + inner = underlying._invoke_handler + if isinstance(inner, type) and issubclass(inner, StateChart): + return StateChartInvoker(inner) + return underlying + if isinstance(underlying, IInvoke): + return underlying + if isinstance(underlying, type) and issubclass(underlying, StateChart): + return StateChartInvoker(underlying) + return None diff --git a/statemachine/state.py b/statemachine/state.py index 5337434f..b7eaa20d 100644 --- a/statemachine/state.py +++ b/statemachine/state.py @@ -11,6 +11,7 @@ from .exceptions import InvalidDefinition from .exceptions import StateMachineError from .i18n import _ +from .invoke import normalize_invoke_callbacks from .transition import Transition from .transition_list import TransitionList @@ -205,6 +206,7 @@ def __init__( history: "List[HistoryState] | None" = None, enter: Any = None, exit: Any = None, + invoke: Any = None, donedata: Any = None, _callbacks: Any = None, ): @@ -228,6 +230,9 @@ def __init__( self.exit = self._specs.grouper(CallbackGroup.EXIT).add( exit, priority=CallbackPriority.INLINE ) + self.invoke = self._specs.grouper(CallbackGroup.INVOKE).add( + normalize_invoke_callbacks(invoke), priority=CallbackPriority.INLINE + ) if donedata is not None: if not final: raise InvalidDefinition(_("'donedata' can only be specified on final states.")) @@ -261,6 +266,10 @@ def _setup(self): self.enter.add(f"on_enter_{self.id}", priority=CallbackPriority.NAMING, is_convention=True) self.exit.add("on_exit_state", priority=CallbackPriority.GENERIC, is_convention=True) self.exit.add(f"on_exit_{self.id}", priority=CallbackPriority.NAMING, is_convention=True) + self.invoke.add("on_invoke_state", priority=CallbackPriority.GENERIC, is_convention=True) + self.invoke.add( + f"on_invoke_{self.id}", priority=CallbackPriority.NAMING, is_convention=True + ) def _on_event_defined(self, event: str, transition: Transition, states: List["State"]): """Called by statemachine factory when an event is defined having a transition @@ -386,6 +395,10 @@ def enter(self): def exit(self): return self._ref().exit + @property + def invoke(self): + return self._ref().invoke + def __eq__(self, other): return self._ref() == other diff --git a/tests/conftest.py b/tests/conftest.py index 24802f30..8f1cbdbe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import asyncio +import time from datetime import datetime import pytest @@ -283,6 +285,13 @@ async def processing_loop(self, sm): return await result return result + async def sleep(self, seconds: float): + """Sleep that works for both sync and async engines.""" + if self.is_async: + await asyncio.sleep(seconds) + else: + time.sleep(seconds) + @pytest.fixture(params=["sync", "async"]) def sm_runner(request): diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index e2a3dd24..c00bb4bc 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -2,6 +2,7 @@ import pytest from statemachine.callbacks import CallbackGroup +from statemachine.callbacks import CallbacksExecutor from statemachine.callbacks import CallbackSpec from statemachine.callbacks import CallbackSpecList from statemachine.callbacks import CallbacksRegistry @@ -351,3 +352,35 @@ class ExampleStateMachine(StateChart): match="Error on transition start from Created to Started when resolving callbacks", ): ExampleStateMachine() + + +class TestVisitConditionFalse: + """visit/async_visit skip callbacks whose condition returns False.""" + + def test_visit_skips_when_condition_is_false(self): + visited = [] + spec = CallbackSpec( + "never_called", + group=CallbackGroup.INVOKE, + is_convention=True, + cond=lambda *a, **kw: False, + ) + executor = CallbacksExecutor() + executor.add("test_key", spec, lambda: lambda **kw: True) + + executor.visit(lambda cb, *a, **kw: visited.append(str(cb))) + assert visited == [] + + async def test_async_visit_skips_when_condition_is_false(self): + visited = [] + spec = CallbackSpec( + "never_called", + group=CallbackGroup.INVOKE, + is_convention=True, + cond=lambda *a, **kw: False, + ) + executor = CallbacksExecutor() + executor.add("test_key", spec, lambda: lambda **kw: True) + + await executor.async_visit(lambda cb, *a, **kw: visited.append(str(cb))) + assert visited == [] diff --git a/tests/test_invoke.py b/tests/test_invoke.py new file mode 100644 index 00000000..fb2fce31 --- /dev/null +++ b/tests/test_invoke.py @@ -0,0 +1,989 @@ +"""Tests for the invoke callback group.""" + +import threading +import time + +from statemachine.invoke import IInvoke +from statemachine.invoke import InvokeContext +from statemachine.invoke import invoke_group + +from statemachine import Event +from statemachine import State +from statemachine import StateChart + + +class TestInvokeSimpleCallable: + """Simple callable invoke — function runs in background, done.invoke fires.""" + + async def test_simple_callable_invoke(self, sm_runner): + results = [] + + class SM(StateChart): + loading = State(initial=True, invoke=lambda: 42) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + results.append(data) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert results == [42] + + async def test_invoke_return_value_in_done_event(self, sm_runner): + results = [] + + class SM(StateChart): + loading = State(initial=True, invoke=lambda: {"key": "value"}) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + results.append(data) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert results == [{"key": "value"}] + + +class TestInvokeNamingConvention: + """Naming convention — on_invoke_() method is discovered and invoked.""" + + async def test_naming_convention(self, sm_runner): + invoked = [] + + class SM(StateChart): + loading = State(initial=True) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_invoke_loading(self, **kwargs): + invoked.append(True) + return "done" + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert invoked == [True] + assert "ready" in sm.configuration_values + + +class TestInvokeDecorator: + """Decorator — @state.invoke handler.""" + + async def test_decorator_invoke(self, sm_runner): + invoked = [] + + class SM(StateChart): + loading = State(initial=True) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + @loading.invoke + def do_work(self, **kwargs): + invoked.append(True) + return "result" + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert invoked == [True] + assert "ready" in sm.configuration_values + + +class TestInvokeIInvokeProtocol: + """IInvoke protocol — class with run(ctx) method.""" + + async def test_iinvoke_class(self, sm_runner): + """Pass an IInvoke class — engine instantiates per SM instance.""" + results = [] + + class MyInvoker: + def run(self, ctx: InvokeContext): + results.append(ctx.state_id) + return "invoker_result" + + def on_cancel(self): + pass # no-op: only verifying the protocol is satisfied + + assert isinstance(MyInvoker(), IInvoke) + + class SM(StateChart): + loading = State(initial=True, invoke=MyInvoker) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + results.append(data) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "loading" in results + assert "invoker_result" in results + assert "ready" in sm.configuration_values + + async def test_each_sm_instance_gets_own_handler(self, sm_runner): + """Each StateChart instance must get a fresh IInvoke instance.""" + handler_ids = [] + + class TrackingInvoker: + def run(self, ctx: InvokeContext): + handler_ids.append(id(self)) + return None + + class SM(StateChart): + loading = State(initial=True, invoke=TrackingInvoker) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + sm1 = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm1) + + sm2 = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm2) + + assert len(handler_ids) == 2 + assert handler_ids[0] != handler_ids[1], "Each SM must get its own handler instance" + + +class TestInvokeCancelOnExit: + """Cancel on exit — ctx.cancelled is set when state is exited.""" + + async def test_cancel_on_exit_sync(self): + """Test cancel in sync mode only — uses threading.Event.wait().""" + from tests.conftest import SMRunner + + sm_runner = SMRunner(is_async=False) + cancel_observed = [] + + class SM(StateChart): + loading = State(initial=True) + cancelled_state = State(final=True) + cancel = loading.to(cancelled_state) + + def on_invoke_loading(self, ctx=None, **kwargs): + if ctx is None: + return + ctx.cancelled.wait(timeout=5.0) + cancel_observed.append(ctx.cancelled.is_set()) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.05) + await sm_runner.send(sm, "cancel") + await sm_runner.sleep(0.1) + + assert cancel_observed == [True] + assert "cancelled_state" in sm.configuration_values + + async def test_cancel_on_exit_with_on_cancel(self, sm_runner): + """Test that on_cancel() is called when state is exited.""" + cancel_called = [] + + class CancelTracker: + def run(self, ctx): + while not ctx.cancelled.is_set(): + ctx.cancelled.wait(0.01) + + def on_cancel(self): + cancel_called.append(True) + + class SM(StateChart): + loading = State(initial=True, invoke=CancelTracker) + cancelled_state = State(final=True) + cancel = loading.to(cancelled_state) + + sm = await sm_runner.start(SM) + # Give the invoke handler time to start in its background thread + await sm_runner.sleep(0.15) + await sm_runner.send(sm, "cancel") + await sm_runner.sleep(0.15) + + assert cancel_called == [True] + assert "cancelled_state" in sm.configuration_values + + +class TestInvokeErrorHandling: + """Error in invoker → error.execution event.""" + + async def test_error_in_invoke(self, sm_runner): + errors = [] + + class SM(StateChart): + loading = State(initial=True, invoke=lambda: 1 / 0) + error_state = State(final=True) + error_execution = loading.to(error_state) + + def on_enter_error_state(self, **kwargs): + errors.append(True) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert errors == [True] + assert "error_state" in sm.configuration_values + + +class TestInvokeMultiple: + """Multiple invokes per state — all run concurrently.""" + + async def test_multiple_invokes(self, sm_runner): + results = [] + lock = threading.Lock() + + def task_a(): + with lock: + results.append("a") + return "a" + + def task_b(): + with lock: + results.append("b") + return "b" + + class SM(StateChart): + loading = State(initial=True, invoke=[task_a, task_b]) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert sorted(results) == ["a", "b"] + + +class TestInvokeStateChartChild: + """StateChart as invoker — child machine runs, completion fires done event.""" + + async def test_statechart_invoker(self, sm_runner): + class ChildMachine(StateChart): + start = State(initial=True) + end = State(final=True) + go = start.to(end) + + def on_enter_start(self, **kwargs): + self.send("go") + + class SM(StateChart): + loading = State(initial=True, invoke=ChildMachine) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + + +class TestDoneInvokeTransition: + """done_invoke_ transition — naming convention works.""" + + async def test_done_invoke_transition(self, sm_runner): + class SM(StateChart): + loading = State(initial=True, invoke=lambda: "hello") + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + + +class TestDoneInvokeEventFormat: + """done.invoke event name must be done.invoke.. (no duplication).""" + + async def test_done_invoke_event_has_no_duplicate_state_id(self, sm_runner): + received_events = [] + + class SM(StateChart): + loading = State(initial=True, invoke=lambda: "ok") + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, event=None, **kwargs): + if event is not None: + received_events.append(str(event)) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert len(received_events) == 1 + event_name = received_events[0] + # Must be "done.invoke.loading." — NOT "done.invoke.loading.loading." + assert event_name.startswith("done.invoke.loading.") + parts = event_name.split(".") + # ["done", "invoke", "loading", ""] — exactly 4 parts + assert len(parts) == 4, f"Expected 4 parts, got {parts}" + + +class TestInvokeGroup: + """invoke_group() — runs multiple callables concurrently, returns list of results.""" + + async def test_group_returns_ordered_results(self, sm_runner): + """Results are returned in the same order as the input callables.""" + results = [] + + def slow(): + time.sleep(0.05) + return "slow" + + def fast(): + return "fast" + + class SM(StateChart): + loading = State(initial=True, invoke=invoke_group(slow, fast)) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + results.append(data) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert results == [["slow", "fast"]] + + async def test_group_with_file_io(self, sm_runner, tmp_path): + """Real I/O: read two files concurrently and get both results.""" + file_a = tmp_path / "a.txt" + file_b = tmp_path / "b.txt" + file_a.write_text("hello") + file_b.write_text("world") + + results = [] + + class SM(StateChart): + loading = State( + initial=True, + invoke=invoke_group( + lambda: file_a.read_text(), + lambda: file_b.read_text(), + ), + ) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + results.append(data) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert results == [["hello", "world"]] + + async def test_group_error_cancels_remaining(self, sm_runner): + """If one callable raises, error.execution is sent.""" + errors = [] + + def ok(): + time.sleep(0.1) + return "ok" + + def fail(): + raise ValueError("boom") + + class SM(StateChart): + loading = State(initial=True, invoke=invoke_group(ok, fail)) + error_state = State(final=True) + error_execution = loading.to(error_state) + + def on_enter_error_state(self, **kwargs): + errors.append(True) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.3) + await sm_runner.processing_loop(sm) + + assert "error_state" in sm.configuration_values + assert errors == [True] + + async def test_group_cancel_on_exit(self, sm_runner): + """Cancellation propagates: exiting state stops the group.""" + + def slow_task(): + time.sleep(5.0) + return "should not complete" + + class SM(StateChart): + loading = State(initial=True, invoke=invoke_group(slow_task)) + stopped = State(final=True) + cancel = loading.to(stopped) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.05) + await sm_runner.send(sm, "cancel") + await sm_runner.sleep(0.1) + + assert "stopped" in sm.configuration_values + + async def test_group_single_callable(self, sm_runner): + """Edge case: group with a single callable still returns a list.""" + results = [] + + class SM(StateChart): + loading = State(initial=True, invoke=invoke_group(lambda: 42)) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + results.append(data) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert results == [[42]] + + async def test_each_sm_instance_gets_own_group(self, sm_runner): + """Each SM instance must get its own InvokeGroup — no shared state.""" + all_results = [] + + counter = {"value": 0} + lock = threading.Lock() + + def counting_task(): + with lock: + counter["value"] += 1 + return counter["value"] + + class SM(StateChart): + loading = State(initial=True, invoke=invoke_group(counting_task)) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + all_results.append(data) + + sm1 = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm1) + + sm2 = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm2) + + assert len(all_results) == 2 + assert all_results[0] == [1] + assert all_results[1] == [2] + + +class TestInvokeEventKwargs: + """Event kwargs from send() are forwarded to invoke handlers.""" + + async def test_plain_callable_receives_event_kwargs(self, sm_runner): + """Plain callable invoke handler receives event kwargs via SignatureAdapter.""" + received = [] + + class SM(StateChart): + idle = State(initial=True) + loading = State() + ready = State(final=True) + start = idle.to(loading) + done_invoke_loading = loading.to(ready) + + def on_invoke_loading(self, file_name=None, **kwargs): + received.append(file_name) + return f"loaded:{file_name}" + + def on_enter_ready(self, data=None, **kwargs): + received.append(data) + + sm = await sm_runner.start(SM) + await sm_runner.send(sm, "start", file_name="config.json") + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert received == ["config.json", "loaded:config.json"] + + async def test_iinvoke_handler_receives_event_kwargs_via_ctx(self, sm_runner): + """IInvoke handler receives event kwargs via ctx.kwargs.""" + received = [] + + class FileLoader: + def run(self, ctx: InvokeContext): + received.append(ctx.kwargs.get("file_name")) + return f"loaded:{ctx.kwargs['file_name']}" + + class SM(StateChart): + idle = State(initial=True) + loading = State(invoke=FileLoader) + ready = State(final=True) + start = idle.to(loading) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + received.append(data) + + sm = await sm_runner.start(SM) + await sm_runner.send(sm, "start", file_name="data.csv") + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert received == ["data.csv", "loaded:data.csv"] + + async def test_initial_state_invoke_has_empty_kwargs(self, sm_runner): + """Invoke on initial state gets empty kwargs (no triggering event).""" + + class SM(StateChart): + loading = State(initial=True, invoke=lambda: 42) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + assert "ready" in sm.configuration_values + + +class TestInvokeNotTriggeredOnNonInvokeState: + """States without invoke handlers should not be affected.""" + + async def test_no_invoke_on_plain_state(self, sm_runner): + class SM(StateChart): + idle = State(initial=True) + active = State() + done = State(final=True) + + go = idle.to(active) + finish = active.to(done) + + sm = await sm_runner.start(SM) + await sm_runner.send(sm, "go") + assert "active" in sm.configuration_values + await sm_runner.send(sm, "finish") + assert "done" in sm.configuration_values + + +class TestInvokeManagerCancelAll: + """InvokeManager.cancel_all() cancels every active invocation.""" + + async def test_cancel_all(self, sm_runner): + class SlowHandler: + def run(self, ctx): + ctx.cancelled.wait(timeout=5.0) + + class SM(StateChart): + loading = State(initial=True, invoke=SlowHandler) + stopped = State(final=True) + cancel = loading.to(stopped) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + sm._engine._invoke_manager.cancel_all() + await sm_runner.sleep(0.15) + + # All invocations should be terminated + for inv in sm._engine._invoke_manager._active.values(): + assert inv.terminated + + +class TestInvokeCancelAlreadyTerminated: + """Cancelling an already-terminated invocation is a no-op.""" + + async def test_cancel_terminated_invocation(self, sm_runner): + class SM(StateChart): + loading = State(initial=True, invoke=lambda: 42) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + # All invocations should be terminated by now + manager = sm._engine._invoke_manager + for inv in manager._active.values(): + assert inv.terminated + # Calling cancel on terminated invocations should be a safe no-op + for inv_id in list(manager._active.keys()): + manager._cancel(inv_id) + + +class TestInvokeOnCancelException: + """Exception in on_cancel() is caught and logged, not propagated.""" + + async def test_on_cancel_exception_is_suppressed(self, sm_runner): + class BadCancelHandler: + def run(self, ctx): + ctx.cancelled.wait(timeout=5.0) + + def on_cancel(self): + raise RuntimeError("on_cancel exploded") + + class SM(StateChart): + loading = State(initial=True, invoke=BadCancelHandler) + stopped = State(final=True) + cancel = loading.to(stopped) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + # This should NOT raise even though on_cancel() raises + await sm_runner.send(sm, "cancel") + await sm_runner.sleep(0.15) + + assert "stopped" in sm.configuration_values + + +class TestStateChartInvokerOnCancel: + """StateChartInvoker.on_cancel() cleans up the child reference.""" + + def test_on_cancel_clears_child(self): + from statemachine.invoke import StateChartInvoker + + class ChildMachine(StateChart): + start = State(initial=True, final=True) + + invoker = StateChartInvoker(ChildMachine) + ctx = InvokeContext( + invokeid="test.123", + state_id="test", + send=lambda *a, **kw: None, + machine=None, + ) + invoker.run(ctx) + assert invoker._child is not None + invoker.on_cancel() + assert invoker._child is None + + +class TestNormalizeInvokeCallbacks: + """normalize_invoke_callbacks handles edge cases.""" + + def test_string_passes_through(self): + from statemachine.invoke import normalize_invoke_callbacks + + result = normalize_invoke_callbacks("some_method_name") + assert result == ["some_method_name"] + + def test_already_wrapped_passes_through(self): + from statemachine.invoke import _InvokeCallableWrapper + from statemachine.invoke import normalize_invoke_callbacks + + class MyHandler: + def run(self, ctx): + pass + + wrapper = _InvokeCallableWrapper(MyHandler) + result = normalize_invoke_callbacks(wrapper) + assert len(result) == 1 + assert result[0] is wrapper + + def test_iinvoke_class_with_run_method(self): + """IInvoke-compatible class gets wrapped.""" + from statemachine.invoke import _InvokeCallableWrapper + from statemachine.invoke import normalize_invoke_callbacks + + class CustomHandler: + def run(self, ctx): + return "result" + + # CustomHandler satisfies IInvoke protocol (has run method) + assert isinstance(CustomHandler(), IInvoke) + result = normalize_invoke_callbacks(CustomHandler) + assert len(result) == 1 + assert isinstance(result[0], _InvokeCallableWrapper) + + def test_plain_callable_passes_through(self): + from statemachine.invoke import _InvokeCallableWrapper + from statemachine.invoke import normalize_invoke_callbacks + + def my_func(): + return 42 + + result = normalize_invoke_callbacks(my_func) + assert len(result) == 1 + assert result[0] is my_func + assert not isinstance(result[0], _InvokeCallableWrapper) + + def test_non_invoke_class_passes_through(self): + """A class without run() (not IInvoke, not StateChart) passes through unwrapped.""" + from statemachine.invoke import _InvokeCallableWrapper + from statemachine.invoke import normalize_invoke_callbacks + + class PlainClass: + pass + + result = normalize_invoke_callbacks(PlainClass) + assert len(result) == 1 + assert result[0] is PlainClass + assert not isinstance(result[0], _InvokeCallableWrapper) + + +class TestResolveHandler: + """InvokeManager._resolve_handler edge cases.""" + + def test_bare_iinvoke_instance(self): + from statemachine.invoke import InvokeManager + + class MyHandler: + def run(self, ctx): + return "result" + + handler = MyHandler() + assert isinstance(handler, IInvoke) + resolved = InvokeManager._resolve_handler(handler) + assert resolved is handler + + def test_bare_statechart_class(self): + from statemachine.invoke import InvokeManager + from statemachine.invoke import StateChartInvoker + + class ChildMachine(StateChart): + start = State(initial=True, final=True) + + resolved = InvokeManager._resolve_handler(ChildMachine) + assert isinstance(resolved, StateChartInvoker) + + def test_plain_callable_returns_none(self): + from statemachine.invoke import InvokeManager + + def my_func(): + return 42 + + assert InvokeManager._resolve_handler(my_func) is None + + +class TestInvokeCallableWrapperOnCancel: + """_InvokeCallableWrapper.on_cancel() edge cases.""" + + def test_on_cancel_non_class_instance_with_on_cancel(self): + """Non-class handler (already instantiated) delegates on_cancel.""" + from statemachine.invoke import _InvokeCallableWrapper + + cancel_called = [] + + class MyHandler: + def run(self, ctx): + return "result" + + def on_cancel(self): + cancel_called.append(True) + + handler = MyHandler() + wrapper = _InvokeCallableWrapper(handler) + # _instance is None, _is_class is False → falls through to _invoke_handler + wrapper.on_cancel() + assert cancel_called == [True] + + def test_on_cancel_class_not_yet_instantiated(self): + """Class handler not yet instantiated — on_cancel is a no-op.""" + from statemachine.invoke import _InvokeCallableWrapper + + class MyHandler: + def run(self, ctx): + return "result" + + def on_cancel(self): + raise RuntimeError("should not be called") + + wrapper = _InvokeCallableWrapper(MyHandler) + # _instance is None, _is_class is True → early return + wrapper.on_cancel() # should not raise + + def test_callable_wrapper_call_returns_handler(self): + """__call__ returns the original handler (used by callback system for resolution).""" + from statemachine.invoke import _InvokeCallableWrapper + + class MyHandler: + def run(self, ctx): + return "result" + + wrapper = _InvokeCallableWrapper(MyHandler) + assert wrapper() is MyHandler + + +class TestInvokeGroupOnCancelBeforeRun: + """InvokeGroup.on_cancel() before run() is a safe no-op.""" + + def test_on_cancel_before_run(self): + group = invoke_group(lambda: 1) + # on_cancel before run — executor is None, no futures + group.on_cancel() + + +class TestDoneInvokeEventFactory: + """done_invoke_ prefix works with both TransitionList and Event.""" + + async def test_done_invoke_with_event_object(self, sm_runner): + """Event() object with done_invoke_ prefix should match done.invoke events.""" + + class SM(StateChart): + loading = State(initial=True, invoke=lambda: "result") + ready = State(final=True) + done_invoke_loading = Event(loading.to(ready)) + + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + + +class TestVisitNoCallbacks: + """visit/async_visit with no registered callbacks is a no-op.""" + + def test_visit_missing_key(self): + from statemachine.callbacks import CallbacksRegistry + + registry = CallbacksRegistry() + # Should not raise — just returns + registry.visit("nonexistent_key", lambda cb, **kw: None) + + async def test_async_visit_missing_key(self): + from statemachine.callbacks import CallbacksRegistry + + registry = CallbacksRegistry() + await registry.async_visit("nonexistent_key", lambda cb, **kw: None) + + +class TestAsyncVisitAwaitable: + """async_visit should await the visitor_fn result when it is awaitable.""" + + async def test_async_visitor_fn_is_awaited(self): + from statemachine.callbacks import CallbackGroup + from statemachine.callbacks import CallbacksExecutor + from statemachine.callbacks import CallbackSpec + + visited = [] + + async def async_visitor(callback, **kwargs): + visited.append(str(callback)) + + executor = CallbacksExecutor() + spec = CallbackSpec("dummy", group=CallbackGroup.INVOKE, is_convention=True) + executor.add("test_key", spec, lambda: lambda **kw: True) + + await executor.async_visit(async_visitor) + assert visited == ["dummy"] + + +class TestIInvokeProtocolRun: + """IInvoke.run() protocol method can be called on a concrete implementation.""" + + def test_protocol_run_is_callable(self): + """Verify that calling run() on a concrete IInvoke instance works.""" + + class ConcreteInvoker: + def run(self, ctx): + return "concrete_result" + + invoker: IInvoke = ConcreteInvoker() + result = invoker.run(None) + assert result == "concrete_result" + + +class TestSpawnPendingAsyncEmpty: + """spawn_pending_async with nothing pending is a no-op.""" + + async def test_spawn_pending_async_no_pending(self, sm_runner): + class SM(StateChart): + idle = State(initial=True) + active = State(final=True) + go = idle.to(active) + + sm = await sm_runner.start(SM) + # Directly call spawn_pending_async with empty pending list + await sm._engine._invoke_manager.spawn_pending_async() + + +class TestInvokeAsyncCancelledDuringExecution: + """Async handler completes or errors after state was already exited.""" + + async def test_success_after_cancel(self): + """Handler returns successfully but ctx.cancelled is already set.""" + from tests.conftest import SMRunner + + class SM(StateChart): + loading = State(initial=True) + stopped = State(final=True) + cancel = loading.to(stopped) + + def on_invoke_loading(self, ctx=None, **kwargs): + if ctx is None: + return + # Simulate: cancelled is set during execution but we still return + ctx.cancelled.set() + return "should_be_ignored" + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + # The done.invoke event should NOT have been sent (cancelled) + assert "loading" in sm.configuration_values + + async def test_error_after_cancel(self): + """Handler raises but ctx.cancelled is already set — error is swallowed.""" + from tests.conftest import SMRunner + + class SM(StateChart): + loading = State(initial=True) + error_state = State(final=True) + error_execution = loading.to(error_state) + + def on_invoke_loading(self, ctx=None, **kwargs): + if ctx is None: + return + # Simulate: cancelled during execution, then error + ctx.cancelled.set() + raise ValueError("should be ignored") + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + # The error.execution event should NOT have been sent (cancelled) + assert "loading" in sm.configuration_values + + +class TestSyncInvokeErrorAfterCancel: + """Sync handler errors after state was already exited.""" + + async def test_sync_error_after_cancel(self): + """Sync handler raises but ctx.cancelled is set — error.execution not sent.""" + from tests.conftest import SMRunner + + class SM(StateChart): + loading = State(initial=True) + error_state = State(final=True) + error_execution = loading.to(error_state) + + def on_invoke_loading(self, ctx=None, **kwargs): + if ctx is None: + return + ctx.cancelled.set() + raise ValueError("should be ignored") + + sm_runner = SMRunner(is_async=False) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "loading" in sm.configuration_values diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py index e02ac8d7..c670bcd0 100644 --- a/tests/test_statemachine.py +++ b/tests/test_statemachine.py @@ -709,3 +709,23 @@ def is_blocked(self): sm = MyMachine() assert [e.id for e in sm.enabled_events()] == ["go"] + + +class TestInvalidStateValueNonNone: + """current_state raises InvalidStateValue when state value is non-None but invalid.""" + + def test_invalid_non_none_state_value(self): + import warnings + + class SM(StateChart): + idle = State(initial=True) + active = State(final=True) + go = idle.to(active) + + sm = SM() + # Bypass setter validation by writing directly to the model attribute + setattr(sm.model, sm.state_field, "nonexistent_state") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + with pytest.raises(exceptions.InvalidStateValue): + _ = sm.current_state