diff --git a/changelog.d/+saga-effects-expansion.added.md b/changelog.d/+saga-effects-expansion.added.md new file mode 100644 index 0000000..0a91d4b --- /dev/null +++ b/changelog.d/+saga-effects-expansion.added.md @@ -0,0 +1 @@ +Added Race (first-wins with loser cancellation), All (wait-all with fail-fast), Take (pause until action dispatched), and Debounce (cancel-and-restart timer) saga effects. Fixed Python 2 exception syntax bug in pipeline handler introspection. Added gateway test suite covering namespacing, routing, proxying, idle reaping, and error handling. diff --git a/src/milo/__init__.py b/src/milo/__init__.py index 0276b19..c97eecd 100644 --- a/src/milo/__init__.py +++ b/src/milo/__init__.py @@ -28,6 +28,10 @@ def __getattr__(name: str): "Retry": "_types", "Timeout": "_types", "TryCall": "_types", + "Race": "_types", + "All": "_types", + "Take": "_types", + "Debounce": "_types", "Cmd": "_types", "Batch": "_types", "Sequence": "_types", @@ -154,6 +158,7 @@ def _Py_mod_gil() -> int: # noqa: N802 "CLI", "DEFAULT_THEME", "Action", + "All", "App", "AppError", "AppStatus", @@ -168,6 +173,7 @@ def _Py_mod_gil() -> int: # noqa: N802 "ConfigSpec", "Context", "CycleError", + "Debounce", "Delay", "Description", "DevServer", @@ -213,6 +219,7 @@ def _Py_mod_gil() -> int: # noqa: N802 "PromptDef", "Put", "Quit", + "Race", "ReducerResult", "RenderTarget", "RequestLog", @@ -225,6 +232,7 @@ def _Py_mod_gil() -> int: # noqa: N802 "SpecialKey", "StateError", "Store", + "Take", "ThemeProxy", "ThemeStyle", "TickCmd", diff --git a/src/milo/_types.py b/src/milo/_types.py index b329689..916614d 100644 --- a/src/milo/_types.py +++ b/src/milo/_types.py @@ -269,6 +269,78 @@ class TryCall: kwargs: dict = field(default_factory=dict) +@dataclass(frozen=True, slots=True) +class Race: + """Run multiple sagas concurrently, return the first result. + + Losers are cancelled via their cancel events as soon as a winner + completes. 
If all racers fail, the first error is thrown into + the parent saga:: + + winner = yield Race(sagas=(fetch_primary(), fetch_fallback())) + + Raises ``StateError`` if *sagas* is empty. + """ + + sagas: tuple + + +@dataclass(frozen=True, slots=True) +class All: + """Run multiple sagas concurrently, wait for all to complete. + + Returns a tuple of results in the same order as the input sagas. + Fail-fast: if any saga raises, remaining sagas are cancelled and + the error is thrown into the parent:: + + a, b = yield All(sagas=(fetch_users(), fetch_roles())) + + An empty tuple returns ``()`` immediately. + """ + + sagas: tuple + + +@dataclass(frozen=True, slots=True) +class Take: + """Pause the saga until a matching action is dispatched. + + Waits for *future* actions only — actions dispatched before the + Take is yielded are not matched. Returns the full ``Action`` + object so the saga can inspect both type and payload:: + + action = yield Take("USER_CONFIRMED") + name = action.payload["name"] + + An optional *timeout* (in seconds) raises ``TimeoutError`` if the + action is not dispatched in time. + """ + + action_type: str + timeout: float | None = None + + +@dataclass(frozen=True, slots=True) +class Debounce: + """Delay-then-fork: start a timer, fork *saga* when it expires. + + If the parent saga yields another ``Debounce`` before the timer + fires, the previous timer is cancelled and restarted. The parent + continues immediately (non-blocking):: + + # In a keystroke handler saga: + while True: + key = yield Take("@@KEY") + yield Debounce(seconds=0.3, saga=search_saga) + + The debounced saga runs independently; use ``Take`` if the parent + needs the result. 
+ """ + + seconds: float + saga: Callable + + # --------------------------------------------------------------------------- # Commands (lightweight alternative to sagas) # --------------------------------------------------------------------------- diff --git a/src/milo/state.py b/src/milo/state.py index 7578485..37d1196 100644 --- a/src/milo/state.py +++ b/src/milo/state.py @@ -14,17 +14,21 @@ from milo._errors import ErrorCode, StateError from milo._types import ( Action, + All, Batch, Call, Cmd, + Debounce, Delay, Fork, Put, Quit, + Race, ReducerResult, Retry, Select, Sequence, + Take, TickCmd, Timeout, TryCall, @@ -61,6 +65,9 @@ def __init__( self._quit = threading.Event() self._exit_code = 0 self._view_state = None + # Take effect: waiters keyed by action_type + # Each entry: list of (Event, result_box_list) tuples + self._action_waiters: dict[str, list[tuple[threading.Event, list]]] = {} # Build middleware chain self._dispatch_fn = self._base_dispatch @@ -128,6 +135,13 @@ def _base_dispatch(self, action: Action) -> None: } ) + # Notify Take waiters (inside lock to avoid missed actions) + waiters = self._action_waiters.pop(action.type, None) + if waiters: + for event, result_box in waiters: + result_box.append(action) + event.set() + # Store latest view state for renderer to pick up if view is not None: self._view_state = view @@ -174,6 +188,7 @@ def _run_saga(self, saga: Any, cancel: threading.Event | None = None) -> None: """ if cancel is None: cancel = threading.Event() + pending_debounce: list = [] # [(timer, child_cancel)] — at most one entry try: effect = next(saga) while True: @@ -226,6 +241,100 @@ def _run_saga(self, saga: Any, cancel: threading.Event | None = None) -> None: effect = saga.send((result, None)) except Exception as call_err: effect = saga.send((None, call_err)) + case Race(child_sagas): + if not child_sagas: + raise StateError(ErrorCode.STA_SAGA, "Race requires at least one saga") + try: + result = self._execute_race(child_sagas, cancel) + 
except Exception as race_err: + effect = saga.throw(race_err) + else: + effect = saga.send(result) + case All(child_sagas): + if not child_sagas: + effect = saga.send(()) + else: + try: + results = self._execute_all(child_sagas, cancel) + except Exception as all_err: + effect = saga.throw(all_err) + else: + effect = saga.send(results) + case Take(action_type, timeout): + waiter_event = threading.Event() + result_box: list = [] + with self._lock: + self._action_waiters.setdefault(action_type, []).append( + (waiter_event, result_box) + ) + # Wait outside the lock in short intervals so cancellation + # can be checked promptly while still honoring timeout. + wait_interval = 0.1 + deadline = None if timeout is None else time.monotonic() + timeout + while not waiter_event.is_set(): + if cancel.is_set(): + break + if deadline is None: + current_timeout = wait_interval + else: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + current_timeout = min(wait_interval, remaining) + waiter_event.wait(timeout=current_timeout) + if cancel.is_set(): + # Clean up waiter if not consumed + with self._lock: + entries = self._action_waiters.get(action_type, []) + for i, (ev, _) in enumerate(entries): + if ev is waiter_event: + entries.pop(i) + break + if not entries and action_type in self._action_waiters: + del self._action_waiters[action_type] + continue # Loop back to cancellation check + if result_box: + effect = saga.send(result_box[0]) + else: + # Timeout expired — clean up waiter + with self._lock: + entries = self._action_waiters.get(action_type, []) + for i, (ev, _) in enumerate(entries): + if ev is waiter_event: + entries.pop(i) + break + if not entries and action_type in self._action_waiters: + del self._action_waiters[action_type] + try: + effect = saga.throw( + TimeoutError( + f"Take('{action_type}') timed out after {timeout}s" + ) + ) + except StopIteration: + return + case Debounce(seconds, inner_saga): + # Cancel any pending debounce timer from a 
previous yield + if pending_debounce: + old_timer, old_cancel = pending_debounce[0] + old_timer.cancel() + old_cancel.set() + pending_debounce.clear() + child_cancel = threading.Event() + + def _debounce_fire( + s=inner_saga, + cc=child_cancel, + store=self, + ): + if not cc.is_set(): + store._executor.submit(store._run_saga, s(), cc) + + timer = threading.Timer(seconds, _debounce_fire) + timer.daemon = True + timer.start() + pending_debounce.append((timer, child_cancel)) + effect = next(saga) case _: raise StateError( ErrorCode.STA_SAGA, @@ -244,6 +353,12 @@ def _run_saga(self, saga: Any, cancel: threading.Event | None = None) -> None: ) except Exception: _logger.debug("Failed to dispatch @@SAGA_ERROR", exc_info=True) + finally: + # Cancel any pending debounce timer on saga exit + if pending_debounce: + old_timer, old_cancel = pending_debounce[0] + old_timer.cancel() + old_cancel.set() def _execute_timeout(self, effect: Call | Retry, seconds: float) -> Any: """Execute a blocking effect with a timeout deadline. @@ -285,6 +400,142 @@ def _execute_effect(effect: Call | Retry) -> Any: f"Cannot execute effect type: {type(effect).__name__}", ) + def _run_saga_capturing( + self, + saga: Any, + cancel: threading.Event, + result_box: list, + error_box: list, + done: threading.Event, + ) -> None: + """Step through a saga via ``_run_saga``, capturing the return value. + + Wraps *saga* in a thin ``yield from`` generator so that + ``_run_saga`` handles **all** effect types (including nested + Race/All/Take/Debounce). On success the return value is + appended to *result_box*; on error the exception goes into + *error_box*. *done* is set in all cases. + """ + + def _wrapper(): + try: + result = yield from saga + result_box.append(result) + except Exception as e: + error_box.append(e) + + self._run_saga(_wrapper(), cancel) + done.set() + + def _execute_race(self, child_sagas: tuple, parent_cancel: threading.Event) -> Any: + """Run sagas concurrently, return the first result. 
Cancel losers.""" + condition = threading.Condition() + child_cancels: list[threading.Event] = [] + child_dones: list[threading.Event] = [] + child_results: list[list] = [] + child_errors: list[list] = [] + + for child_saga in child_sagas: + child_cancel = threading.Event() + child_done = threading.Event() + result_box: list[Any] = [] + error_box: list[Exception] = [] + child_cancels.append(child_cancel) + child_dones.append(child_done) + child_results.append(result_box) + child_errors.append(error_box) + + def _notify_wrapper( + saga=child_saga, + cancel=child_cancel, + rb=result_box, + eb=error_box, + done=child_done, + ): + self._run_saga_capturing(saga, cancel, rb, eb, done) + with condition: + condition.notify_all() + + self._executor.submit(_notify_wrapper) + + # Wait for first completion or parent cancellation + with condition: + while True: + if parent_cancel.is_set(): + for cc in child_cancels: + cc.set() + raise StateError(ErrorCode.STA_SAGA, "Race cancelled") + for i, done in enumerate(child_dones): + if done.is_set(): + # Cancel all others + for cc in child_cancels: + cc.set() + if child_results[i]: + return child_results[i][0] + if child_errors[i]: + raise child_errors[i][0] + # All done — re-check results (a child may have finished + # between the per-child is_set() check and here). + if all(d.is_set() for d in child_dones): + for i2 in range(len(child_dones)): + if child_results[i2]: + for cc in child_cancels: + cc.set() + return child_results[i2][0] + if child_errors[i2]: + raise child_errors[i2][0] + return None + condition.wait(timeout=0.05) + + def _execute_all(self, child_sagas: tuple, parent_cancel: threading.Event) -> tuple: + """Run sagas concurrently, wait for all. 
Fail-fast on first error.""" + condition = threading.Condition() + child_cancels: list[threading.Event] = [] + child_dones: list[threading.Event] = [] + child_results: list[list] = [] + child_errors: list[list] = [] + + for child_saga in child_sagas: + child_cancel = threading.Event() + child_done = threading.Event() + result_box: list[Any] = [] + error_box: list[Exception] = [] + child_cancels.append(child_cancel) + child_dones.append(child_done) + child_results.append(result_box) + child_errors.append(error_box) + + def _notify_wrapper( + saga=child_saga, + cancel=child_cancel, + rb=result_box, + eb=error_box, + done=child_done, + ): + self._run_saga_capturing(saga, cancel, rb, eb, done) + with condition: + condition.notify_all() + + self._executor.submit(_notify_wrapper) + + # Wait for all to complete or first failure + with condition: + while True: + if parent_cancel.is_set(): + for cc in child_cancels: + cc.set() + raise StateError(ErrorCode.STA_SAGA, "All cancelled") + # Check for errors (fail-fast) + for i, done in enumerate(child_dones): + if done.is_set() and child_errors[i]: + for cc in child_cancels: + cc.set() + raise child_errors[i][0] + # Check if all done + if all(d.is_set() for d in child_dones): + return tuple(rb[0] if rb else None for rb in child_results) + condition.wait(timeout=0.05) + def _exec_cmd(self, cmd: Any) -> None: """Execute a Cmd, Batch, Sequence, or TickCmd.""" match cmd: diff --git a/tests/test_dev.py b/tests/test_dev.py index 8419571..bc633dc 100644 --- a/tests/test_dev.py +++ b/tests/test_dev.py @@ -158,9 +158,9 @@ def test_separate_batches_after_debounce(self): batcher.set_callback(lambda paths: flushed.append(list(paths))) batcher.add([Path("first.txt")]) - time.sleep(0.15) + time.sleep(0.3) batcher.add([Path("second.txt")]) - time.sleep(0.15) + time.sleep(0.3) assert len(flushed) == 2 diff --git a/tests/test_effects.py b/tests/test_effects.py index f299cce..1979c2e 100644 --- a/tests/test_effects.py +++ b/tests/test_effects.py 
@@ -7,7 +7,21 @@ import pytest -from milo._types import Action, Call, Delay, Fork, Put, Retry, Select, Timeout, TryCall +from milo._types import ( + Action, + All, + Call, + Debounce, + Delay, + Fork, + Put, + Race, + Retry, + Select, + Take, + Timeout, + TryCall, +) class TestSagaStepping: @@ -500,3 +514,523 @@ def reducer(state, action): assert "@@SAGA_CANCELLED" in actions assert not done.is_set() + + +class TestRaceEffect: + def test_race_dataclass(self): + r = Race(sagas=(lambda: None, lambda: None)) + assert len(r.sagas) == 2 + + def test_race_dataclass_frozen(self): + r = Race(sagas=()) + with pytest.raises(AttributeError): + r.sagas = () # type: ignore[misc] + + def test_race_basic(self): + """First saga to complete wins.""" + from milo.state import Store + + results = [] + + def fast(): + yield Delay(seconds=0.05) + return "fast" + + def slow(): + yield Delay(seconds=5.0) + return "slow" + + def parent(): + winner = yield Race(sagas=(fast(), slow())) + results.append(winner) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert results == ["fast"] + + def test_race_loser_cancelled(self): + """Losing sagas should be cancelled.""" + from milo.state import Store + + actions = [] + + def fast(): + yield Delay(seconds=0.05) + return "fast" + + def slow(): + yield Delay(seconds=0.05) + yield Delay(seconds=10.0) # Should be cancelled before this completes + return "slow" + + def parent(): + yield Race(sagas=(fast(), slow())) + + def reducer(state, action): + actions.append(action.type) + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert "@@SAGA_CANCELLED" in actions + + def test_race_with_failing_saga(self): + """If the first to finish raises, that error propagates.""" + from milo.state import Store + + errors = [] + + def fail_fast(): + yield 
Delay(seconds=0.05) + raise ValueError("boom") + + def slow(): + yield Delay(seconds=5.0) + return "slow" + + def parent(): + try: + yield Race(sagas=(fail_fast(), slow())) + except ValueError as e: + errors.append(str(e)) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + # The error from fail_fast propagates as @@SAGA_ERROR since + # it's thrown into the parent which catches it + # But the capturing wrapper catches and stores the error + # The parent saga gets the error thrown into it + assert errors == ["boom"] + + def test_race_single_saga(self): + """Race with one saga just returns its result.""" + from milo.state import Store + + results = [] + + def only(): + yield Delay(seconds=0.05) + return "only" + + def parent(): + winner = yield Race(sagas=(only(),)) + results.append(winner) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.3) + store._executor.shutdown(wait=True) + + assert results == ["only"] + + def test_race_empty_raises(self): + """Race with no sagas raises StateError.""" + from milo.state import Store + + errors = [] + + def parent(): + yield Race(sagas=()) + + def reducer(state, action): + if action.type == "@@SAGA_ERROR": + errors.append(action.payload) + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + store._executor.shutdown(wait=True) + + assert len(errors) == 1 + assert "at least one saga" in errors[0]["error"] + + +class TestAllEffect: + def test_all_dataclass(self): + a = All(sagas=(lambda: None, lambda: None)) + assert len(a.sagas) == 2 + + def test_all_dataclass_frozen(self): + a = All(sagas=()) + with pytest.raises(AttributeError): + a.sagas = () # type: ignore[misc] + + def test_all_effect_basic(self): + """All waits for all sagas and returns results in order.""" + from milo.state import Store + + results = 
[] + + def saga_a(): + yield Delay(seconds=0.05) + return "a" + + def saga_b(): + yield Delay(seconds=0.1) + return "b" + + def parent(): + result = yield All(sagas=(saga_a(), saga_b())) + results.append(result) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert results == [("a", "b")] + + def test_all_effect_one_failure_cancels_rest(self): + """If one saga fails, the rest are cancelled and error propagates.""" + from milo.state import Store + + errors = [] + actions = [] + + def good(): + yield Delay(seconds=5.0) + return "good" + + def bad(): + yield Delay(seconds=0.05) + raise ValueError("bad") + + def parent(): + try: + yield All(sagas=(good(), bad())) + except ValueError as e: + errors.append(str(e)) + + def reducer(state, action): + actions.append(action.type) + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert errors == ["bad"] + assert "@@SAGA_CANCELLED" in actions + + def test_all_effect_empty(self): + """All with empty tuple returns empty tuple immediately.""" + from milo.state import Store + + results = [] + + def parent(): + result = yield All(sagas=()) + results.append(result) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + store._executor.shutdown(wait=True) + + assert results == [()] + + def test_all_effect_single_saga(self): + """All with one saga wraps result in 1-tuple.""" + from milo.state import Store + + results = [] + + def only(): + yield Delay(seconds=0.05) + return "only" + + def parent(): + result = yield All(sagas=(only(),)) + results.append(result) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.3) + store._executor.shutdown(wait=True) + + assert results == [("only",)] + + def 
test_all_effect_preserves_order(self): + """Results are ordered by input position, not completion time.""" + from milo.state import Store + + results = [] + + def slow(): + yield Delay(seconds=0.15) + return "slow" + + def fast(): + yield Delay(seconds=0.05) + return "fast" + + def parent(): + result = yield All(sagas=(slow(), fast())) + results.append(result) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert results == [("slow", "fast")] + + +class TestTakeEffect: + def test_take_dataclass(self): + t = Take(action_type="@@KEY") + assert t.action_type == "@@KEY" + assert t.timeout is None + + def test_take_dataclass_with_timeout(self): + t = Take(action_type="@@KEY", timeout=5.0) + assert t.timeout == 5.0 + + def test_take_dataclass_frozen(self): + t = Take(action_type="@@KEY") + with pytest.raises(AttributeError): + t.action_type = "other" # type: ignore[misc] + + def test_take_basic(self): + """Take pauses saga until matching action is dispatched.""" + from milo.state import Store + + results = [] + + def saga(): + action = yield Take("USER_CONFIRMED") + results.append(action.payload) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(saga()) + time.sleep(0.1) + # Saga should be blocked + assert results == [] + + store.dispatch(Action("USER_CONFIRMED", payload="yes")) + time.sleep(0.1) + store._executor.shutdown(wait=True) + + assert results == ["yes"] + + def test_take_with_timeout(self): + """Take with timeout returns action if dispatched in time.""" + from milo.state import Store + + results = [] + + def saga(): + action = yield Take("FAST_ACTION", timeout=2.0) + results.append(action.payload) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(saga()) + time.sleep(0.05) + store.dispatch(Action("FAST_ACTION", payload="got it")) + 
time.sleep(0.1) + store._executor.shutdown(wait=True) + + assert results == ["got it"] + + def test_take_timeout_fires(self): + """Take raises TimeoutError when action isn't dispatched in time.""" + from milo.state import Store + + errors = [] + + def saga(): + try: + yield Take("NEVER_COMES", timeout=0.1) + except TimeoutError as e: + errors.append(str(e)) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(saga()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert len(errors) == 1 + assert "NEVER_COMES" in errors[0] + assert "timed out" in errors[0] + + def test_take_ignores_past_actions(self): + """Take waits for future actions only, not already-dispatched ones.""" + from milo.state import Store + + results = [] + + def saga(): + # Dispatch happens before Take + yield Put(Action("EARLY_ACTION", payload="early")) + yield Delay(seconds=0.05) + # Now take — should NOT match the already-dispatched action + try: + yield Take("EARLY_ACTION", timeout=0.15) + results.append("matched") + except TimeoutError: + results.append("timed_out") + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(saga()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert results == ["timed_out"] + + def test_take_multiple_waiters(self): + """Multiple sagas can Take the same action type.""" + from milo.state import Store + + results = [] + + def waiter(name): + action = yield Take("SHARED_EVENT") + results.append((name, action.payload)) + + def reducer(state, action): + return state or 0 + + store = Store(reducer, None) + store.run_saga(waiter("a")) + store.run_saga(waiter("b")) + time.sleep(0.1) + store.dispatch(Action("SHARED_EVENT", payload="hello")) + time.sleep(0.2) + store._executor.shutdown(wait=True) + + assert len(results) == 2 + assert ("a", "hello") in results + assert ("b", "hello") in results + + +class TestDebounceEffect: + def 
test_debounce_dataclass(self): + d = Debounce(seconds=0.3, saga=lambda: None) + assert d.seconds == 0.3 + + def test_debounce_dataclass_frozen(self): + d = Debounce(seconds=0.3, saga=lambda: None) + with pytest.raises(AttributeError): + d.seconds = 1.0 # type: ignore[misc] + + def test_debounce_basic(self): + """Debounce fires the inner saga after the delay.""" + from milo.state import Store + + actions = [] + + def inner_saga(): + yield Put(Action("DEBOUNCED_FIRE")) + + def parent(): + yield Debounce(seconds=0.1, saga=inner_saga) + yield Delay(seconds=0.3) # Wait for debounce to fire + + def reducer(state, action): + actions.append(action.type) + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert "DEBOUNCED_FIRE" in actions + + def test_debounce_retrigger_resets_timer(self): + """Re-yielding Debounce cancels previous timer and starts new one.""" + from milo.state import Store + + actions = [] + + def inner_saga(): + yield Put(Action("DEBOUNCED_FIRE")) + + def parent(): + # First debounce — 0.15s + yield Debounce(seconds=0.15, saga=inner_saga) + # Wait a bit, then retrigger before first fires + yield Delay(seconds=0.05) + # Second debounce — resets the timer + yield Debounce(seconds=0.15, saga=inner_saga) + # Wait long enough for second to fire + yield Delay(seconds=0.3) + + def reducer(state, action): + actions.append(action.type) + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.6) + store._executor.shutdown(wait=True) + + # Should fire exactly once (first timer cancelled, second fires) + fire_count = actions.count("DEBOUNCED_FIRE") + assert fire_count == 1 + + def test_debounce_cancelled_on_saga_exit(self): + """Pending debounce is cancelled when parent saga ends.""" + from milo.state import Store + + actions = [] + + def inner_saga(): + yield Put(Action("SHOULD_NOT_FIRE")) + + def parent(): + yield Debounce(seconds=0.3, 
saga=inner_saga) + # Parent exits immediately — debounce should be cancelled + + def reducer(state, action): + actions.append(action.type) + return state or 0 + + store = Store(reducer, None) + store.run_saga(parent()) + time.sleep(0.5) + store._executor.shutdown(wait=True) + + assert "SHOULD_NOT_FIRE" not in actions diff --git a/tests/test_gateway.py b/tests/test_gateway.py new file mode 100644 index 0000000..fbaf138 --- /dev/null +++ b/tests/test_gateway.py @@ -0,0 +1,467 @@ +"""Tests for the MCP gateway — namespacing, routing, proxying, error handling.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from milo.gateway import ( + GatewayState, + _discover_all, + _GatewayHandler, + _idle_reaper, + _proxy_call, + _proxy_prompt, + _proxy_resource, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_child( + name: str, + tools: list[dict] | None = None, + resources: list[dict] | None = None, + prompts: list[dict] | None = None, +) -> MagicMock: + """Build a mock ChildProcess with canned discovery responses.""" + child = MagicMock() + child.name = name + child.idle_timeout = 300.0 + child.is_idle.return_value = False + + tools = tools or [] + resources = resources or [] + prompts = prompts or [] + + child.fetch_tools.return_value = tools + + def _send_call(method: str, params: dict[str, Any], **kw: Any) -> dict[str, Any]: + if method == "resources/list": + return {"resources": resources} + if method == "prompts/list": + return {"prompts": prompts} + if method == "tools/call": + return {"content": [{"type": "text", "text": f"called {params['name']}"}]} + if method == "resources/read": + return {"contents": [{"uri": params["uri"], "text": "data"}]} + if method == "prompts/get": + return {"messages": [{"role": "user", "content": {"type": "text", "text": 
"hi"}}]} + return {} + + child.send_call.side_effect = _send_call + return child + + +def _make_gateway_state( + tools: list[dict] | None = None, + tool_routing: dict | None = None, + resources: list[dict] | None = None, + resource_routing: dict | None = None, + prompts: list[dict] | None = None, + prompt_routing: dict | None = None, +) -> GatewayState: + return GatewayState( + tools=tools or [], + tool_routing=tool_routing or {}, + resources=resources or [], + resource_routing=resource_routing or {}, + prompts=prompts or [], + prompt_routing=prompt_routing or {}, + ) + + +# --------------------------------------------------------------------------- +# Discovery & Namespacing +# --------------------------------------------------------------------------- + + +class TestNamespacing: + def test_tool_namespace_prefixed(self): + """Tools are namespaced as cli_name.tool_name.""" + clis = {"taskman": {"command": ["python", "-m", "taskman", "--mcp"]}} + children = { + "taskman": _make_child( + "taskman", + tools=[{"name": "add", "description": "Add task", "inputSchema": {}}], + ), + } + + state = _discover_all(clis, children) + + assert len(state.tools) == 1 + assert state.tools[0]["name"] == "taskman.add" + assert state.tool_routing["taskman.add"] == ("taskman", "add") + + def test_resource_namespace_prefixed(self): + """Resources URIs are prefixed with cli_name/.""" + clis = {"deployer": {"command": ["deployer", "--mcp"]}} + children = { + "deployer": _make_child( + "deployer", + resources=[{"uri": "milo://stats", "name": "stats"}], + ), + } + + state = _discover_all(clis, children) + + assert len(state.resources) == 1 + assert state.resources[0]["uri"] == "deployer/milo://stats" + assert state.resource_routing["deployer/milo://stats"] == ("deployer", "milo://stats") + + def test_prompt_namespace_prefixed(self): + """Prompts are namespaced as cli_name.prompt_name.""" + clis = {"ghub": {"command": ["ghub", "--mcp"]}} + children = { + "ghub": _make_child( + "ghub", + 
prompts=[{"name": "review", "description": "Review PR"}], + ), + } + + state = _discover_all(clis, children) + + assert len(state.prompts) == 1 + assert state.prompts[0]["name"] == "ghub.review" + assert state.prompt_routing["ghub.review"] == ("ghub", "review") + + def test_multiple_clis_namespaced(self): + """Tools from multiple CLIs get distinct namespaces.""" + clis = { + "taskman": {"command": ["taskman", "--mcp"]}, + "deployer": {"command": ["deployer", "--mcp"]}, + } + children = { + "taskman": _make_child("taskman", tools=[{"name": "add", "inputSchema": {}}]), + "deployer": _make_child("deployer", tools=[{"name": "deploy", "inputSchema": {}}]), + } + + state = _discover_all(clis, children) + + names = [t["name"] for t in state.tools] + assert "taskman.add" in names + assert "deployer.deploy" in names + + def test_tool_title_auto_generated(self): + """Tools without a title get one from cli_name and description.""" + clis = {"myapp": {"command": ["myapp", "--mcp"]}} + children = { + "myapp": _make_child( + "myapp", + tools=[{"name": "build", "description": "Build project", "inputSchema": {}}], + ), + } + + state = _discover_all(clis, children) + assert state.tools[0]["title"] == "myapp: Build project" + + def test_empty_registry(self): + """Empty CLI registry produces empty state.""" + state = _discover_all({}, {}) + assert state.tools == [] + assert state.resources == [] + assert state.prompts == [] + + def test_discovery_order_deterministic(self): + """Tools appear in CLI registration order, not completion order.""" + clis = { + "aaa": {"command": ["aaa"]}, + "zzz": {"command": ["zzz"]}, + } + children = { + "aaa": _make_child("aaa", tools=[{"name": "x", "inputSchema": {}}]), + "zzz": _make_child("zzz", tools=[{"name": "y", "inputSchema": {}}]), + } + + state = _discover_all(clis, children) + names = [t["name"] for t in state.tools] + assert names == ["aaa.x", "zzz.y"] + + +# --------------------------------------------------------------------------- +# Tool 
# call proxying
# ---------------------------------------------------------------------------


class TestProxyCall:
    """Checks on ``_proxy_call`` for namespaced tool calls."""

    def test_proxy_call_routes_correctly(self):
        """A namespaced tool call reaches its child under the bare tool name."""
        taskman = _make_child("taskman")
        request = {"name": "taskman.add", "arguments": {"title": "hi"}}

        response = _proxy_call(
            {"taskman": taskman},
            {"taskman.add": ("taskman", "add")},
            request,
        )

        taskman.send_call.assert_called_once_with(
            "tools/call", {"name": "add", "arguments": {"title": "hi"}}
        )
        first_item = response["content"][0]
        assert first_item["text"] == "called add"

    def test_proxy_call_unknown_tool(self):
        """A name absent from the routing table yields an isError result."""
        request = {"name": "nonexistent.tool", "arguments": {}}

        response = _proxy_call({}, {}, request)

        assert response["isError"] is True
        assert "Unknown tool" in response["content"][0]["text"]

    def test_proxy_call_child_unavailable(self):
        """A routed tool whose child is not in the pool yields an isError result."""
        routes = {"taskman.add": ("taskman", "add")}
        request = {"name": "taskman.add", "arguments": {}}

        response = _proxy_call({}, routes, request)

        assert response["isError"] is True
        assert "not available" in response["content"][0]["text"]

    def test_proxy_call_child_error(self):
        """An ``error`` payload from the child is reported with its message."""
        # Configure the nested return value through constructor kwargs.
        broken = MagicMock(
            **{"send_call.return_value": {"error": {"code": -1, "message": "broken"}}}
        )

        response = _proxy_call(
            {"taskman": broken},
            {"taskman.add": ("taskman", "add")},
            {"name": "taskman.add", "arguments": {}},
        )

        assert response["isError"] is True
        assert "broken" in response["content"][0]["text"]


# ---------------------------------------------------------------------------
# Resource proxying
# ---------------------------------------------------------------------------


class TestProxyResource:
    """Checks on ``_proxy_resource`` for namespaced resource reads."""

    def test_proxy_resource_routes_correctly(self):
        """A namespaced URI is forwarded to the child as its original URI."""
        deployer = _make_child("deployer")
        routes = {"deployer/milo://stats": ("deployer", "milo://stats")}

        response = _proxy_resource(
            {"deployer": deployer}, routes, {"uri": "deployer/milo://stats"}
        )

        deployer.send_call.assert_called_once_with("resources/read", {"uri": "milo://stats"})
        assert response["contents"][0]["text"] == "data"

    def test_proxy_resource_unknown_uri(self):
        """A URI absent from the routing table yields empty contents."""
        response = _proxy_resource({}, {}, {"uri": "unknown://x"})

        assert response["contents"] == []

    def test_proxy_resource_child_unavailable(self):
        """A routed URI whose child is not in the pool yields empty contents."""
        routes = {"deployer/milo://stats": ("deployer", "milo://stats")}

        response = _proxy_resource({}, routes, {"uri": "deployer/milo://stats"})

        assert response["contents"] == []


# ---------------------------------------------------------------------------
# Prompt proxying
# ---------------------------------------------------------------------------


class TestProxyPrompt:
    """Checks on ``_proxy_prompt`` for namespaced prompt lookups."""

    def test_proxy_prompt_routes_correctly(self):
        """A namespaced prompt get reaches its child under the bare prompt name."""
        ghub = _make_child("ghub")
        request = {"name": "ghub.review", "arguments": {"pr": "123"}}

        response = _proxy_prompt(
            {"ghub": ghub},
            {"ghub.review": ("ghub", "review")},
            request,
        )

        ghub.send_call.assert_called_once_with(
            "prompts/get", {"name": "review", "arguments": {"pr": "123"}}
        )
        assert response["messages"][0]["role"] == "user"

    def test_proxy_prompt_unknown(self):
        """A prompt name absent from the routing table yields empty messages."""
        response = _proxy_prompt({}, {}, {"name": "nope.x"})

        assert response["messages"] == []

    def test_proxy_prompt_child_unavailable(self):
        """A routed prompt whose child is not in the pool yields empty messages."""
        routes = {"ghub.review": ("ghub", "review")}

        response = _proxy_prompt({}, routes, {"name": "ghub.review"})

        assert response["messages"] == []


# ---------------------------------------------------------------------------
# GatewayHandler
#
# ---------------------------------------------------------------------------


class TestGatewayHandler:
    """Checks on the ``_GatewayHandler`` request methods."""

    def _make_handler(self) -> tuple[_GatewayHandler, dict]:
        """Return a handler over one ``taskman`` CLI with an ``add`` tool, plus its child map."""
        registry = {"taskman": {"command": ["taskman", "--mcp"]}}
        add_tool = {"name": "add", "description": "Add task", "inputSchema": {}}
        pool = {"taskman": _make_child("taskman", tools=[add_tool])}
        discovered = _discover_all(registry, pool)
        return _GatewayHandler(registry, discovered, pool), pool

    def test_initialize(self):
        """initialize reports the gateway server name and names the CLI in instructions."""
        gateway, _ = self._make_handler()

        info = gateway.initialize({})

        assert info["serverInfo"]["name"] == "milo-gateway"
        assert "taskman" in info["instructions"]

    def test_list_tools(self):
        """list_tools includes the namespaced tool name."""
        gateway, _ = self._make_handler()

        listing = gateway.list_tools({})

        tool_names = [entry["name"] for entry in listing["tools"]]
        assert "taskman.add" in tool_names

    def test_call_tool(self):
        """call_tool returns the child's response for a namespaced tool."""
        gateway, _ = self._make_handler()
        request = {"name": "taskman.add", "arguments": {"title": "test"}}

        response = gateway.call_tool(request)

        assert "called add" in response["content"][0]["text"]

    def test_list_resources(self):
        """list_resources returns the single discovered resource."""
        registry = {"myapp": {"command": ["myapp"]}}
        stats = {"uri": "milo://stats", "name": "stats"}
        pool = {"myapp": _make_child("myapp", resources=[stats])}
        gateway = _GatewayHandler(registry, _discover_all(registry, pool), pool)

        listing = gateway.list_resources({})

        assert len(listing["resources"]) == 1

    def test_list_prompts(self):
        """list_prompts returns the single discovered prompt."""
        registry = {"myapp": {"command": ["myapp"]}}
        help_prompt = {"name": "help", "description": "Help"}
        pool = {"myapp": _make_child("myapp", prompts=[help_prompt])}
        gateway = _GatewayHandler(registry, _discover_all(registry, pool), pool)

        listing = gateway.list_prompts({})

        assert len(listing["prompts"]) == 1


# ---------------------------------------------------------------------------
# Idle reaping
# ---------------------------------------------------------------------------


class TestIdleReaper:
    """Checks on ``_idle_reaper`` with ``time.sleep`` patched to end the run."""

    @staticmethod
    def _one_pass_sleep():
        """Build a sleep stand-in: the first call returns, later calls raise StopIteration."""
        calls = 0

        def fake_sleep(_seconds: float) -> None:
            nonlocal calls
            calls += 1
            if calls > 1:
                raise StopIteration

        return fake_sleep

    def test_reap_idle_child(self):
        """Idle children get killed."""
        idle = _make_child("taskman")
        idle.is_idle.return_value = True

        # First sleep passes (the reaper sleeps before checking); the second
        # raises, which is how the test gets out of the reaper.
        stub_sleep = patch("milo.gateway.time.sleep", side_effect=self._one_pass_sleep())
        with stub_sleep, pytest.raises(StopIteration):
            _idle_reaper({"taskman": idle})

        idle.kill.assert_called_once()

    def test_keep_active_child(self):
        """Active children are not reaped."""
        busy = _make_child("taskman")
        busy.is_idle.return_value = False

        stub_sleep = patch("milo.gateway.time.sleep", side_effect=self._one_pass_sleep())
        with stub_sleep, pytest.raises(StopIteration):
            _idle_reaper({"taskman": busy})

        busy.kill.assert_not_called()


# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------


class TestErrorHandling:
    """Failure-path checks for discovery and for requests that omit ``arguments``."""

    def test_discovery_tool_failure_continues(self):
        """If one CLI fails tool discovery, others still work."""
        healthy = _make_child("good", tools=[{"name": "ok", "inputSchema": {}}])
        failing = _make_child("bad")
        failing.fetch_tools.side_effect = RuntimeError("connection refused")
        registry = {
            "good": {"command": ["good"]},
            "bad": {"command": ["bad"]},
        }

        discovered = _discover_all(registry, {"good": healthy, "bad": failing})

        # Only the healthy CLI's tool is expected to be present.
        assert len(discovered.tools) == 1
        assert discovered.tools[0]["name"] == "good.ok"

    def test_discovery_resource_failure_continues(self):
        """If resource discovery fails, tools and prompts still work."""
        app = _make_child(
            "myapp",
            tools=[{"name": "run", "inputSchema": {}}],
            prompts=[{"name": "help", "description": "Help"}],
        )
        # Wrap the existing send_call behaviour so only resources/list fails.
        delegate = app.send_call.side_effect

        def fail_on_resource_listing(method, params, **extra):
            if method != "resources/list":
                return delegate(method, params, **extra)
            raise ConnectionError("dead")

        app.send_call.side_effect = fail_on_resource_listing
        registry = {"myapp": {"command": ["myapp"]}}

        discovered = _discover_all(registry, {"myapp": app})

        assert len(discovered.tools) == 1
        assert len(discovered.prompts) == 1
        assert len(discovered.resources) == 0

    def test_proxy_call_empty_arguments(self):
        """Tool call with missing arguments key defaults to empty dict."""
        taskman = _make_child("taskman")
        routes = {"taskman.add": ("taskman", "add")}

        _proxy_call({"taskman": taskman}, routes, {"name": "taskman.add"})

        taskman.send_call.assert_called_once_with(
            "tools/call", {"name": "add", "arguments": {}}
        )

    def test_proxy_prompt_missing_arguments(self):
        """Prompt get with no arguments key defaults to empty dict."""
        ghub = _make_child("ghub")
        routes = {"ghub.review": ("ghub", "review")}

        _proxy_prompt({"ghub": ghub}, routes, {"name": "ghub.review"})

        ghub.send_call.assert_called_once_with(
            "prompts/get", {"name": "review", "arguments": {}}
        )