From d4ef917ad529deb356443f6adbcb86e6563a5922 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Wed, 18 Feb 2026 12:47:45 -0300 Subject: [PATCH 1/2] feat: class-level listener declarations with setup() protocol Allow listeners to be declared at class definition time via a `listeners` attribute on StateChart/StateMachine. The list accepts callables (classes, partial, lambdas) as per-instance factories and pre-built instances as shared listeners. - Metaclass collects `_class_listeners` from attrs and MRO - `listeners_inherit = False` to replace instead of extend parent listeners - `setup(sm, **kwargs)` protocol for runtime dependency injection - `active_listeners` public property to inspect attached listeners - Serialization correctly preserves all listeners through pickle/copy --- docs/listeners.md | 188 ++++++++++++++ docs/releases/3.0.0.md | 15 ++ statemachine/factory.py | 22 ++ statemachine/statemachine.py | 42 ++- tests/test_class_listeners.py | 476 ++++++++++++++++++++++++++++++++++ 5 files changed, 740 insertions(+), 3 deletions(-) create mode 100644 tests/test_class_listeners.py diff --git a/docs/listeners.md b/docs/listeners.md index 226ee5d3..4da4378c 100644 --- a/docs/listeners.md +++ b/docs/listeners.md @@ -87,6 +87,194 @@ Paulista Avenue after: red--(cycle)-->green ``` +## Class-level listener declarations + +```{versionadded} 3.0.0 +``` + +You can declare listeners at the class level so they are automatically attached to every +instance of the state machine. This is useful for cross-cutting concerns like logging, +persistence, or telemetry that should always be present. + +The `listeners` class attribute accepts two forms: + +- **Callable** (class, `functools.partial`, lambda): acts as a factory — called once per + SM instance to produce a fresh listener. Use this for listeners that accumulate state. +- **Instance** (pre-built object): shared across all SM instances. Use this for stateless + listeners like a global logger. + +```py +>>> from statemachine import State, StateChart + +>>> class AuditListener: +... def __init__(self): +... self.log = [] +... +... def after_transition(self, event, source, target): +... self.log.append(f"{event}: {source.id} -> {target.id}") + +>>> class OrderMachine(StateChart): +... listeners = [AuditListener] +... +... draft = State(initial=True) +... confirmed = State(final=True) +... confirm = draft.to(confirmed) + +>>> sm = OrderMachine() +>>> sm.send("confirm") +>>> [type(l).__name__ for l in sm.active_listeners] +['AuditListener'] + +>>> sm.active_listeners[0].log +['confirm: draft -> confirmed'] + +``` + +### Listeners with configuration + +Use `functools.partial` to pass configuration to listener factories: + +```py +>>> from functools import partial + +>>> class HistoryListener: +... def __init__(self, max_size=50): +... self.max_size = max_size +... self.entries = [] +... +... def after_transition(self, event, source, target): +... self.entries.append(f"{source.id} -> {target.id}") +... if len(self.entries) > self.max_size: +... self.entries.pop(0) + +>>> class TrackedMachine(StateChart): +... listeners = [partial(HistoryListener, max_size=10)] +... +... s1 = State(initial=True) +... s2 = State(final=True) +... go = s1.to(s2) + +>>> sm = TrackedMachine() +>>> sm.send("go") +>>> sm.active_listeners[0].entries +['s1 -> s2'] + +``` + +### Runtime listeners merge with class-level + +Runtime listeners passed via the `listeners=` constructor parameter are appended after +class-level listeners: + +```py +>>> runtime_listener = AuditListener() +>>> sm = OrderMachine(listeners=[runtime_listener]) +>>> sm.send("confirm") +>>> [type(l).__name__ for l in sm.active_listeners] +['AuditListener', 'AuditListener'] + +>>> runtime_listener.log +['confirm: draft -> confirmed'] + +``` + +### Inheritance + +Child class listeners are appended after parent listeners. The full MRO chain is respected: + +```py +>>> class LogListener: +... pass + +>>> class BaseMachine(StateChart): +... listeners = [LogListener] +... +... s1 = State(initial=True) +... s2 = State(final=True) +... go = s1.to(s2) + +>>> class ChildMachine(BaseMachine): +... listeners = [AuditListener] + +>>> sm = ChildMachine() +>>> [type(l).__name__ for l in sm.active_listeners] +['LogListener', 'AuditListener'] + +``` + +To **replace** parent listeners instead of extending, set `listeners_inherit = False`: + +```py +>>> class ReplacedMachine(BaseMachine): +... listeners_inherit = False +... listeners = [AuditListener] + +>>> sm = ReplacedMachine() +>>> [type(l).__name__ for l in sm.active_listeners] +['AuditListener'] + +``` + +### Listener `setup()` protocol + +Listeners that need runtime dependencies (e.g., a database session, Redis client) can +define a `setup()` method. It is called during SM `__init__` with the SM instance and +any extra `**kwargs` passed to the constructor. The {ref}`dynamic-dispatch` mechanism +ensures each listener receives only the kwargs it declares: + +```py +>>> class DBListener: +... def __init__(self): +... self.session = None +... +... def setup(self, sm, session=None, **kwargs): +... self.session = session + +>>> class PersistentMachine(StateChart): +... listeners = [DBListener] +... +... s1 = State(initial=True) +... s2 = State(final=True) +... go = s1.to(s2) + +>>> sm = PersistentMachine(session="my_db_session") +>>> sm.active_listeners[0].session +'my_db_session' + +``` + +Multiple listeners with different dependencies compose naturally — each `setup()` picks +only the kwargs it needs: + +```py +>>> class CacheListener: +... def __init__(self): +... self.redis = None +... +... def setup(self, sm, redis=None, **kwargs): +... self.redis = redis + +>>> class FullMachine(StateChart): +... listeners = [DBListener, CacheListener] +... +... s1 = State(initial=True) +... s2 = State(final=True) +... go = s1.to(s2) + +>>> sm = FullMachine(session="db_conn", redis="redis_conn") +>>> sm.active_listeners[0].session +'db_conn' +>>> sm.active_listeners[1].redis +'redis_conn' + +``` + +```{note} +The `setup()` method is only called on **factory-created** instances (callable entries). +Shared instances (pre-built objects) do not receive `setup()` calls — they are assumed +to be already configured by whoever created them. +``` + ```{hint} The `StateChart` itself is registered as a listener, so by using `listeners` an external object can have the same level of functionalities provided to the built-in class. diff --git a/docs/releases/3.0.0.md b/docs/releases/3.0.0.md index 5fcb2728..54bc94fc 100644 --- a/docs/releases/3.0.0.md +++ b/docs/releases/3.0.0.md @@ -410,6 +410,21 @@ class GameCharacter(StateChart): See {ref}`weighted-transitions` for full documentation. +### Class-level listener declarations + +Listeners can now be declared at the class level using the `listeners` attribute, so they are +automatically attached to every instance. The list accepts callables (classes, `partial`, lambdas) +as factories that create a fresh listener per instance, or pre-built instances that are shared. + +A `setup()` protocol allows factory-created listeners to receive runtime dependencies +(DB sessions, Redis clients, etc.) via `**kwargs` forwarded from the SM constructor. + +Inheritance is supported: child listeners are appended after parent listeners, unless +`listeners_inherit = False` is set to replace them entirely. + +See {ref}`observers` for full documentation. + + ### Async concurrent event result routing When multiple coroutines send events concurrently via `asyncio.gather`, each diff --git a/statemachine/factory.py b/statemachine/factory.py index 30868059..37dd20d2 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -60,6 +60,7 @@ def __init__( ) cls.add_inherited(bases) cls.add_from_attributes(attrs) + cls._collect_class_listeners(attrs, bases) cls._unpack_builders_callbacks() cls._update_event_references() @@ -233,6 +234,27 @@ def _setup(cls): "send", } | {s.id for s in cls.states} + def _collect_class_listeners(cls, attrs: Dict[str, Any], bases: Tuple[type]): + """Collect class-level listener declarations from attrs and MRO. + + Listeners declared on parent classes are prepended (MRO order), + unless the child sets ``listeners_inherit = False``. + """ + class_listeners: List[Any] = [] + if attrs.get("listeners_inherit", True): + for base in reversed(bases): + class_listeners.extend(getattr(base, "_class_listeners", [])) + for entry in attrs.get("listeners", []): + if entry is None or isinstance(entry, (str, int, float, bool)): + raise InvalidDefinition( + _( + "Invalid entry in 'listeners': {!r}. " + "Expected a class, callable, or listener instance." + ).format(entry) + ) + class_listeners.append(entry) + cls._class_listeners: List[Any] = class_listeners + def add_inherited(cls, bases): for base in bases: for state in getattr(base, "states", []): diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index ff737988..1f1fb1b0 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -29,6 +29,7 @@ from .graph import iterate_states_and_transitions from .i18n import _ from .model import Model +from .signature import SignatureAdapter from .utils import run_async_from_sync if TYPE_CHECKING: @@ -129,6 +130,7 @@ class StateChart(Generic[TModel], metaclass=StateMachineMetaclass): _events: "Dict[Event, None]" _protected_attrs: set _specs: CallbackSpecList + _class_listeners: List[Any] prepare: SpecListGrouper def __init__( @@ -137,6 +139,7 @@ def __init__( state_field: str = "state", start_value: Any = None, listeners: "List[object] | None" = None, + **kwargs: Any, ): self.model: TModel = model if model is not None else Model() # type: ignore[assignment] self.history_values: Dict[ @@ -154,7 +157,9 @@ def __init__( if self._abstract: raise InvalidDefinition(_("There are no states or transitions.")) - self._register_callbacks(listeners or []) + self._class_listener_instances = self._resolve_class_listeners(**kwargs) + all_listeners = self._class_listener_instances + (listeners or []) + self._register_callbacks(all_listeners) # Activate the initial state, this only works if the outer scope is sync code. # for async code, the user should manually call `await sm.activate_initial_state()` @@ -168,6 +173,26 @@ def _get_engine(self): return SyncEngine(self) + def _resolve_class_listeners(self, **kwargs: Any) -> List[object]: + resolved: List[object] = [] + for entry in self._class_listeners: + if callable(entry): + instance = entry() + setup = getattr(instance, "setup", None) + if setup is not None: + sig = SignatureAdapter.from_callable(setup) + ba = sig.bind_expected(self, **kwargs) + try: + setup(*ba.args, **ba.kwargs) + except TypeError as err: + raise TypeError( + f"Error calling setup() on listener {type(instance).__name__}: {err}" + ) from err + else: + instance = entry + resolved.append(instance) + return resolved + def activate_initial_state(self) -> Any: result = self._engine.activate_initial_state() if not isawaitable(result): @@ -199,11 +224,13 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self.__dict__.update(state) # type: ignore[attr-defined] self._callbacks = CallbacksRegistry() self._states_for_instance = {} - self._listeners = {} + # _listeners already contained both class-level and runtime listeners + # when serialized, so just re-register them all. self._register_callbacks([]) - self.add_listener(*listeners.values()) + if listeners: + self.add_listener(*listeners.values()) self._engine = self._get_engine() self._engine.start() @@ -268,6 +295,15 @@ def _register_callbacks(self, listeners: List[object]): self._callbacks.async_or_sync() + @property + def active_listeners(self) -> List[object]: + """List of all active listeners attached to this instance. + + Includes class-level listeners (resolved from the ``listeners`` class attribute), + constructor ``listeners=`` parameter, and any added via :meth:`add_listener`. + """ + return list(self._listeners.values()) + def add_listener(self, *listeners): """Add a listener. diff --git a/tests/test_class_listeners.py b/tests/test_class_listeners.py new file mode 100644 index 00000000..fb257ea8 --- /dev/null +++ b/tests/test_class_listeners.py @@ -0,0 +1,476 @@ +import pickle +from functools import partial + +import pytest +from statemachine.exceptions import InvalidDefinition + +from statemachine import State +from statemachine import StateChart + + +class RecordingListener: + """Listener that records transitions for testing.""" + + def __init__(self): + self.transitions = [] + + def after_transition(self, event, source, target): + self.transitions.append((event, source.id, target.id)) + + +class SetupListener: + """Listener that uses setup() to receive runtime dependencies.""" + + def __init__(self): + self.session = None + self.transitions = [] + + def setup(self, sm, session=None, **kwargs): + self.session = session + + def after_transition(self, event, source, target): + self.transitions.append((event, source.id, target.id, self.session)) + + +class TestClassLevelListeners: + def test_class_level_listener_callable_creates_per_instance(self): + class MyChart(StateChart): + listeners = [RecordingListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm1 = MyChart() + sm2 = MyChart() + + sm1.send("go") + + # Each SM gets its own listener instance + assert len(sm1._class_listener_instances) == 1 + assert len(sm2._class_listener_instances) == 1 + assert sm1._class_listener_instances[0] is not sm2._class_listener_instances[0] + + # Only sm1 should have the transition recorded + assert sm1._class_listener_instances[0].transitions == [("go", "s1", "s2")] + assert sm2._class_listener_instances[0].transitions == [] + + def test_class_level_listener_shared_instance(self): + shared = RecordingListener() + + class MyChart(StateChart): + listeners = [shared] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm1 = MyChart() + sm2 = MyChart() + + sm1.send("go") + sm2.send("go") + + # Both SMs share the same listener instance + assert sm1._class_listener_instances[0] is shared + assert sm2._class_listener_instances[0] is shared + assert len(shared.transitions) == 2 + + def test_class_level_listener_partial(self): + class ConfigurableListener: + def __init__(self, prefix="default"): + self.prefix = prefix + self.messages = [] + + def after_transition(self, event, source, target): + self.messages.append(f"{self.prefix}: {source.id} -> {target.id}") + + class MyChart(StateChart): + listeners = [partial(ConfigurableListener, prefix="custom")] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart() + sm.send("go") + + listener = sm._class_listener_instances[0] + assert listener.prefix == "custom" + assert listener.messages == ["custom: s1 -> s2"] + + def test_class_level_listener_lambda(self): + class SimpleListener: + def __init__(self, tag): + self.tag = tag + + class MyChart(StateChart): + listeners = [lambda: SimpleListener("from_lambda")] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart() + assert sm._class_listener_instances[0].tag == "from_lambda" + + def test_runtime_listeners_merge_with_class_level(self): + class MyChart(StateChart): + listeners = [RecordingListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + runtime_listener = RecordingListener() + sm = MyChart(listeners=[runtime_listener]) + + sm.send("go") + + # Class-level listener should have recorded + class_listener = sm._class_listener_instances[0] + assert class_listener.transitions == [("go", "s1", "s2")] + + # Runtime listener should also have recorded + assert runtime_listener.transitions == [("go", "s1", "s2")] + + +class TestClassListenerInheritance: + def test_child_extends_parent_listeners(self): + class ParentListener: + pass + + class ChildListener: + pass + + class Parent(StateChart): + listeners = [ParentListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + class Child(Parent): + listeners = [ChildListener] + + sm = Child() + assert len(sm._class_listener_instances) == 2 + assert isinstance(sm._class_listener_instances[0], ParentListener) + assert isinstance(sm._class_listener_instances[1], ChildListener) + + def test_child_replaces_parent_listeners(self): + class ParentListener: + pass + + class ChildListener: + pass + + class Parent(StateChart): + listeners = [ParentListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + class Child(Parent): + listeners_inherit = False + listeners = [ChildListener] + + sm = Child() + assert len(sm._class_listener_instances) == 1 + assert isinstance(sm._class_listener_instances[0], ChildListener) + + def test_grandchild_inherits_full_chain(self): + class L1: + pass + + class L2: + pass + + class L3: + pass + + class Base(StateChart): + listeners = [L1] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + class Mid(Base): + listeners = [L2] + + class Leaf(Mid): + listeners = [L3] + + sm = Leaf() + assert len(sm._class_listener_instances) == 3 + assert isinstance(sm._class_listener_instances[0], L1) + assert isinstance(sm._class_listener_instances[1], L2) + assert isinstance(sm._class_listener_instances[2], L3) + + def test_no_listeners_declared_inherits_parent(self): + class ParentListener: + pass + + class Parent(StateChart): + listeners = [ParentListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + class Child(Parent): + pass + + sm = Child() + assert len(sm._class_listener_instances) == 1 + assert isinstance(sm._class_listener_instances[0], ParentListener) + + +class TestListenerSetupProtocol: + def test_setup_receives_kwargs(self): + class MyChart(StateChart): + listeners = [SetupListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart(session="my_db_session") + listener = sm._class_listener_instances[0] + assert listener.session == "my_db_session" + + def test_setup_ignores_unknown_kwargs(self): + class MyChart(StateChart): + listeners = [SetupListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart(session="db", unknown_arg="ignored") + listener = sm._class_listener_instances[0] + assert listener.session == "db" + + def test_setup_not_called_on_shared_instances(self): + shared = SetupListener() + + class MyChart(StateChart): + listeners = [shared] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + MyChart(session="db") + # Shared instance should NOT have setup() called + assert shared.session is None + + def test_multiple_listeners_with_different_deps(self): + class DBListener: + def __init__(self): + self.session = None + + def setup(self, sm, session=None, **kwargs): + self.session = session + + class CacheListener: + def __init__(self): + self.redis = None + + def setup(self, sm, redis=None, **kwargs): + self.redis = redis + + class MyChart(StateChart): + listeners = [DBListener, CacheListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart(session="db_conn", redis="redis_conn") + db = sm._class_listener_instances[0] + cache = sm._class_listener_instances[1] + assert db.session == "db_conn" + assert cache.redis == "redis_conn" + + def test_setup_receives_sm_instance(self): + class IntrospectiveListener: + def __init__(self): + self.sm = None + + def setup(self, sm, **kwargs): + self.sm = sm + + class MyChart(StateChart): + listeners = [IntrospectiveListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart() + listener = sm._class_listener_instances[0] + assert listener.sm is sm + + def test_setup_optional_kwargs_default_to_none(self): + class MyChart(StateChart): + listeners = [SetupListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart() # No session kwarg provided + listener = sm._class_listener_instances[0] + assert listener.session is None + + def test_setup_required_kwarg_missing_raises_error(self): + class StrictListener: + def setup(self, sm, session): + self.session = session + + class MyChart(StateChart): + listeners = [StrictListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + with pytest.raises(TypeError, match="Error calling setup.*StrictListener"): + MyChart() + + def test_setup_required_kwarg_provided(self): + class StrictListener: + def setup(self, sm, session): + self.session = session + + class MyChart(StateChart): + listeners = [StrictListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart(session="db_conn") + assert sm._class_listener_instances[0].session == "db_conn" + + +class TestListenerValidation: + def test_rejects_none_in_listeners(self): + with pytest.raises(InvalidDefinition, match="Invalid entry"): + + class MyChart(StateChart): + listeners = [None] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + def test_rejects_string_in_listeners(self): + with pytest.raises(InvalidDefinition, match="Invalid entry"): + + class MyChart(StateChart): + listeners = ["not_a_listener"] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + def test_rejects_number_in_listeners(self): + with pytest.raises(InvalidDefinition, match="Invalid entry"): + + class MyChart(StateChart): + listeners = [42] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + def test_rejects_bool_in_listeners(self): + with pytest.raises(InvalidDefinition, match="Invalid entry"): + + class MyChart(StateChart): + listeners = [True] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + +class _PickleChart(StateChart): + listeners = [RecordingListener] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + +class _PickleMultiStepChart(StateChart): + listeners = [RecordingListener] + + s1 = State(initial=True) + s2 = State() + s3 = State(final=True) + step1 = s1.to(s2) + step2 = s2.to(s3) + + +class TestListenerSerialization: + def test_pickle_with_class_listeners(self): + sm = _PickleChart() + sm.send("go") + + data = pickle.dumps(sm) + sm2 = pickle.loads(data) + + # Class listener instances are preserved through serialization + assert len(sm2._class_listener_instances) == 1 + assert sm2._class_listener_instances[0].transitions == [("go", "s1", "s2")] + assert "s2" in sm2.configuration_values + + def test_pickle_does_not_duplicate_class_listeners(self): + sm = _PickleChart() + assert len(sm.active_listeners) == 1 + + data = pickle.dumps(sm) + sm2 = pickle.loads(data) + + # Must not duplicate class listeners after deserialization + assert len(sm2.active_listeners) == 1 + + def test_pickle_with_runtime_listeners(self): + runtime = RecordingListener() + sm = _PickleMultiStepChart(listeners=[runtime]) + sm.send("step1") + + data = pickle.dumps(sm) + sm2 = pickle.loads(data) + + # After deserialization, both class and runtime listeners are re-registered + assert "s2" in sm2.configuration_values + sm2.send("step2") + assert "s3" in sm2.configuration_values + + +class TestEmptyClassListeners: + def test_no_listeners_attribute(self): + class MyChart(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart() + assert sm._class_listener_instances == [] + + def test_empty_listeners_list(self): + class MyChart(StateChart): + listeners = [] + + s1 = State(initial=True) + s2 = State(final=True) + go = s1.to(s2) + + sm = MyChart() + assert sm._class_listener_instances == [] From 8eda64a8e4cbcc5d526b05e05df55e387bcc5673 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Wed, 18 Feb 2026 14:23:44 -0300 Subject: [PATCH 2/2] refactor: remove _class_listener_instances, use active_listeners API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `_class_listener_instances` attribute was redundant — `_listeners` already tracks all listeners (class-level + runtime + add_listener). Remove the internal attribute and migrate all tests to use the public `active_listeners` property instead. --- statemachine/statemachine.py | 4 +- tests/test_class_listeners.py | 73 ++++++++++++++++++----------------- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index 1f1fb1b0..54fa662f 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -157,8 +157,8 @@ def __init__( if self._abstract: raise InvalidDefinition(_("There are no states or transitions.")) - self._class_listener_instances = self._resolve_class_listeners(**kwargs) - all_listeners = self._class_listener_instances + (listeners or []) + class_listener_instances = self._resolve_class_listeners(**kwargs) + all_listeners = class_listener_instances + (listeners or []) self._register_callbacks(all_listeners) # Activate the initial state, this only works if the outer scope is sync code. diff --git a/tests/test_class_listeners.py b/tests/test_class_listeners.py index fb257ea8..fc1aad08 100644 --- a/tests/test_class_listeners.py +++ b/tests/test_class_listeners.py @@ -47,13 +47,13 @@ class MyChart(StateChart): sm1.send("go") # Each SM gets its own listener instance - assert len(sm1._class_listener_instances) == 1 - assert len(sm2._class_listener_instances) == 1 - assert sm1._class_listener_instances[0] is not sm2._class_listener_instances[0] + assert len(sm1.active_listeners) == 1 + assert len(sm2.active_listeners) == 1 + assert sm1.active_listeners[0] is not sm2.active_listeners[0] # Only sm1 should have the transition recorded - assert sm1._class_listener_instances[0].transitions == [("go", "s1", "s2")] - assert sm2._class_listener_instances[0].transitions == [] + assert sm1.active_listeners[0].transitions == [("go", "s1", "s2")] + assert sm2.active_listeners[0].transitions == [] def test_class_level_listener_shared_instance(self): shared = RecordingListener() @@ -72,8 +72,8 @@ class MyChart(StateChart): sm2.send("go") # Both SMs share the same listener instance - assert sm1._class_listener_instances[0] is shared - assert sm2._class_listener_instances[0] is shared + assert sm1.active_listeners[0] is shared + assert sm2.active_listeners[0] is shared assert len(shared.transitions) == 2 def test_class_level_listener_partial(self): @@ -95,7 +95,7 @@ class MyChart(StateChart): sm = MyChart() sm.send("go") - listener = sm._class_listener_instances[0] + listener = sm.active_listeners[0] assert listener.prefix == "custom" assert listener.messages == ["custom: s1 -> s2"] @@ -112,7 +112,7 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart() - assert sm._class_listener_instances[0].tag == "from_lambda" + assert sm.active_listeners[0].tag == "from_lambda" def test_runtime_listeners_merge_with_class_level(self): class MyChart(StateChart): @@ -127,12 +127,14 @@ class MyChart(StateChart): sm.send("go") - # Class-level listener should have recorded - class_listener = sm._class_listener_instances[0] - assert class_listener.transitions == [("go", "s1", "s2")] + assert len(sm.active_listeners) == 2 + + # Both listeners should have recorded + for listener in sm.active_listeners: + assert listener.transitions == [("go", "s1", "s2")] - # Runtime listener should also have recorded - assert runtime_listener.transitions == [("go", "s1", "s2")] + # Runtime listener is the one we passed in + assert runtime_listener in sm.active_listeners class TestClassListenerInheritance: @@ -154,9 +156,9 @@ class Child(Parent): listeners = [ChildListener] sm = Child() - assert len(sm._class_listener_instances) == 2 - assert isinstance(sm._class_listener_instances[0], ParentListener) - assert isinstance(sm._class_listener_instances[1], ChildListener) + assert len(sm.active_listeners) == 2 + assert isinstance(sm.active_listeners[0], ParentListener) + assert isinstance(sm.active_listeners[1], ChildListener) def test_child_replaces_parent_listeners(self): class ParentListener: @@ -177,8 +179,8 @@ class Child(Parent): listeners = [ChildListener] sm = Child() - assert len(sm._class_listener_instances) == 1 - assert isinstance(sm._class_listener_instances[0], ChildListener) + assert len(sm.active_listeners) == 1 + assert isinstance(sm.active_listeners[0], ChildListener) def test_grandchild_inherits_full_chain(self): class L1: @@ -204,10 +206,10 @@ class Leaf(Mid): listeners = [L3] sm = Leaf() - assert len(sm._class_listener_instances) == 3 - assert isinstance(sm._class_listener_instances[0], L1) - assert isinstance(sm._class_listener_instances[1], L2) - assert isinstance(sm._class_listener_instances[2], L3) + assert len(sm.active_listeners) == 3 + assert isinstance(sm.active_listeners[0], L1) + assert isinstance(sm.active_listeners[1], L2) + assert isinstance(sm.active_listeners[2], L3) def test_no_listeners_declared_inherits_parent(self): class ParentListener: @@ -224,8 +226,8 @@ class Child(Parent): pass sm = Child() - assert len(sm._class_listener_instances) == 1 - assert isinstance(sm._class_listener_instances[0], ParentListener) + assert len(sm.active_listeners) == 1 + assert isinstance(sm.active_listeners[0], ParentListener) class TestListenerSetupProtocol: @@ -238,7 +240,7 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart(session="my_db_session") - listener = sm._class_listener_instances[0] + listener = sm.active_listeners[0] assert listener.session == "my_db_session" def test_setup_ignores_unknown_kwargs(self): @@ -250,7 +252,7 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart(session="db", unknown_arg="ignored") - listener = sm._class_listener_instances[0] + listener = sm.active_listeners[0] assert listener.session == "db" def test_setup_not_called_on_shared_instances(self): @@ -290,8 +292,7 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart(session="db_conn", redis="redis_conn") - db = sm._class_listener_instances[0] - cache = sm._class_listener_instances[1] + db, cache = sm.active_listeners assert db.session == "db_conn" assert cache.redis == "redis_conn" @@ -311,7 +312,7 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart() - listener = sm._class_listener_instances[0] + listener = sm.active_listeners[0] assert listener.sm is sm def test_setup_optional_kwargs_default_to_none(self): @@ -323,7 +324,7 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart() # No session kwarg provided - listener = sm._class_listener_instances[0] + listener = sm.active_listeners[0] assert listener.session is None def test_setup_required_kwarg_missing_raises_error(self): @@ -354,7 +355,7 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart(session="db_conn") - assert sm._class_listener_instances[0].session == "db_conn" + assert sm.active_listeners[0].session == "db_conn" class TestListenerValidation: @@ -426,8 +427,8 @@ def test_pickle_with_class_listeners(self): sm2 = pickle.loads(data) # Class listener instances are preserved through serialization - assert len(sm2._class_listener_instances) == 1 - assert sm2._class_listener_instances[0].transitions == [("go", "s1", "s2")] + assert len(sm2.active_listeners) == 1 + assert sm2.active_listeners[0].transitions == [("go", "s1", "s2")] assert "s2" in sm2.configuration_values def test_pickle_does_not_duplicate_class_listeners(self): @@ -462,7 +463,7 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart() - assert sm._class_listener_instances == [] + assert sm.active_listeners == [] def test_empty_listeners_list(self): class MyChart(StateChart): @@ -473,4 +474,4 @@ class MyChart(StateChart): go = s1.to(s2) sm = MyChart() - assert sm._class_listener_instances == [] + assert sm.active_listeners == []