From f6b30b3b1ab6530f528365d8f3695a712e5ef418 Mon Sep 17 00:00:00 2001 From: Micah Villmow <4211002+mvillmow@users.noreply.github.com> Date: Sat, 25 Apr 2026 18:11:02 -0700 Subject: [PATCH] =?UTF-8?q?feat(python):=20fresh=20agent=20list=20in=20adv?= =?UTF-8?q?ance=5Fdag,=20background=20DAG=20scan;=20chore(deps):=20bump=20?= =?UTF-8?q?cnats=203.9.3=E2=86=923.12.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #98 Closes #179 Closes #196 Note: Python layer (#196, #98) may be removed by #432 once that cross-repo work resolves. Co-Authored-By: Claude Sonnet 4.6 --- conanfile.py | 2 +- src/keystone/dag_walker.py | 61 ++++++++++++++ tests/test_dag_walker.py | 166 +++++++++++++++++++++++++++++++++++++ 3 files changed, 228 insertions(+), 1 deletion(-) diff --git a/conanfile.py b/conanfile.py index 2c0c5851..7b5b6090 100644 --- a/conanfile.py +++ b/conanfile.py @@ -12,7 +12,7 @@ class ProjectKeystoneConan(ConanFile): def requirements(self) -> None: self.requires("spdlog/1.12.0") self.requires("concurrentqueue/1.0.4") - self.requires("cnats/3.9.3") + self.requires("cnats/3.12.0") if self.options.with_grpc: self.requires("yaml-cpp/0.8.0") diff --git a/src/keystone/dag_walker.py b/src/keystone/dag_walker.py index 0f273f10..6bb6f9f5 100644 --- a/src/keystone/dag_walker.py +++ b/src/keystone/dag_walker.py @@ -1,9 +1,13 @@ from __future__ import annotations +import asyncio +import logging from typing import Any, Optional from .models import Agent, Task, TERMINAL_STATUSES +logger = logging.getLogger(__name__) + class DAGWalker: """Walks a task DAG and assigns ready tasks to available agents.""" @@ -13,10 +17,13 @@ def __init__( tasks: list[Task], agents: list[Agent], client: Optional[Any] = None, + scan_interval: float = 60.0, ) -> None: self.tasks = tasks self.agents = agents self.client = client + self.scan_interval = scan_interval + self._scan_task: Optional[asyncio.Task[None]] = None def get_available_agents(self) -> 
list[Agent]: """Return agents that are active, online, and not currently assigned a task.""" @@ -125,6 +132,10 @@ def validate_no_cycles(self) -> bool: async def advance_dag(self) -> list[tuple[Task, Agent]]: """Assign ready tasks to available agents, returning the assignments made. + When a ``client`` is configured and it exposes a ``get_agents()`` coroutine, + the method fetches a fresh agent list before evaluating availability (issue + #196). This prevents double-assignment races caused by a stale cached list. + Raises :exc:`ValueError` if the task DAG contains a cycle or if any task references an unknown dependency ID. Agents are marked busy immediately upon selection so a single call cannot double-assign the same agent even if the @@ -134,6 +145,19 @@ async def advance_dag(self) -> list[tuple[Task, Agent]]: if cycle_path: raise ValueError(f"Cycle detected in task DAG: {cycle_path}") + # Issue #196: refresh the agent list from the client before assigning so + # that stale in-memory state doesn't cause double-assignment. + if self.client is not None and hasattr(self.client, "get_agents"): + try: + fresh_agents = await self.client.get_agents() + if fresh_agents is not None: + self.agents = fresh_agents + except Exception: # noqa: BLE001 + logger.warning( + "advance_dag: get_agents() failed — falling back to cached list", + exc_info=True, + ) + assignments: list[tuple[Task, Agent]] = [] available = self.get_available_agents() @@ -153,3 +177,40 @@ async def advance_dag(self) -> list[tuple[Task, Agent]]: assignments.append((task, agent)) return assignments + + async def _background_scan_loop(self, stop_event: asyncio.Event) -> None: + """Periodically call :meth:`advance_dag` as a safety net (issue #98). + + Uses ``asyncio.wait_for`` on *stop_event* so the loop wakes up + immediately on shutdown rather than sleeping for the full interval. 
+ """ + logger.info( + "DAGWalker background scan started (interval=%.1fs)", self.scan_interval + ) + while not stop_event.is_set(): + try: + await asyncio.wait_for( + stop_event.wait(), timeout=self.scan_interval + ) + # stop_event was set — exit cleanly. + break + except asyncio.TimeoutError: + pass + try: + await self.advance_dag() + except Exception: # noqa: BLE001 + logger.exception("DAGWalker background scan failed — continuing") + logger.info("DAGWalker background scan stopped") + + def start_background_scan( + self, stop_event: asyncio.Event + ) -> "asyncio.Task[None]": + """Schedule a background scan loop and return the created task (issue #98). + + The caller is responsible for cancelling or awaiting the returned task on + shutdown. Setting *stop_event* triggers a graceful, prompt exit. + """ + self._scan_task = asyncio.create_task( + self._background_scan_loop(stop_event), name="dag-background-scan" + ) + return self._scan_task diff --git a/tests/test_dag_walker.py b/tests/test_dag_walker.py index 7d2e3065..a073cae1 100644 --- a/tests/test_dag_walker.py +++ b/tests/test_dag_walker.py @@ -1,6 +1,7 @@ """Tests for DAGWalker — cycle detection, ready tasks, available agents, and advance_dag.""" from __future__ import annotations +import asyncio from unittest.mock import AsyncMock import pytest @@ -206,6 +207,8 @@ async def test_advance_dag_calls_client_assign_task(self) -> None: task = make_task() agent = make_agent() mock_client = AsyncMock() + # Return the same agent list so the fresh-agent refresh (issue #196) is a no-op. + mock_client.get_agents = AsyncMock(return_value=[agent]) walker = DAGWalker(tasks=[task], agents=[agent], client=mock_client) assignments = await walker.advance_dag() @@ -259,6 +262,8 @@ async def test_advance_dag_client_called_for_each_assignment(self) -> None: agent1 = make_agent(id="a1") agent2 = make_agent(id="a2") mock_client = AsyncMock() + # Return the same agents so the fresh-agent refresh (issue #196) is a no-op. 
+ mock_client.get_agents = AsyncMock(return_value=[agent1, agent2]) walker = DAGWalker( tasks=[task1, task2], agents=[agent1, agent2], client=mock_client ) @@ -295,3 +300,164 @@ async def test_advance_dag_raises_on_unknown_dependency(self) -> None: walker = DAGWalker(tasks=[task], agents=[agent]) with pytest.raises(ValueError, match="Unknown dependency"): await walker.advance_dag() + + +class TestAdvanceDagFreshAgents: + """Tests for issue #196 — advance_dag() calls get_agents() for a fresh agent list.""" + + async def test_advance_dag_calls_get_agents_when_available(self) -> None: + """advance_dag() calls client.get_agents() and uses the returned list.""" + task = make_task() + stale_agent = make_agent(id="stale", current_task_id="busy") + fresh_agent = make_agent(id="fresh") + mock_client = AsyncMock() + mock_client.get_agents = AsyncMock(return_value=[fresh_agent]) + walker = DAGWalker(tasks=[task], agents=[stale_agent], client=mock_client) + + assignments = await walker.advance_dag() + + mock_client.get_agents.assert_awaited_once() + assert len(assignments) == 1 + assert assignments[0][1].id == "fresh" + + async def test_advance_dag_skips_get_agents_when_no_client(self) -> None: + """advance_dag() works without a client — uses self.agents unchanged.""" + task = make_task() + agent = make_agent() + walker = DAGWalker(tasks=[task], agents=[agent], client=None) + + assignments = await walker.advance_dag() + + assert len(assignments) == 1 + assert assignments[0][1] is agent + + async def test_advance_dag_skips_get_agents_when_client_lacks_method(self) -> None: + """advance_dag() falls back to self.agents when client has no get_agents().""" + from unittest.mock import MagicMock + + task = make_task() + agent = make_agent() + # Build a client that only has assign_task (an async callable) — no get_agents. 
+ mock_client = MagicMock() + mock_client.assign_task = AsyncMock() + del mock_client.get_agents # ensure hasattr returns False + walker = DAGWalker(tasks=[task], agents=[agent], client=mock_client) + + assignments = await walker.advance_dag() + + assert len(assignments) == 1 + assert assignments[0][1] is agent + + async def test_advance_dag_falls_back_on_get_agents_exception(self) -> None: + """If get_agents() raises, advance_dag() falls back to the cached list.""" + task = make_task() + cached_agent = make_agent(id="cached") + mock_client = AsyncMock() + mock_client.get_agents = AsyncMock(side_effect=RuntimeError("service down")) + walker = DAGWalker(tasks=[task], agents=[cached_agent], client=mock_client) + + assignments = await walker.advance_dag() + + # Falls back to cached list — the cached agent was available + assert len(assignments) == 1 + assert assignments[0][1].id == "cached" + + async def test_advance_dag_ignores_none_from_get_agents(self) -> None: + """If get_agents() returns None, advance_dag() keeps the existing agent list.""" + task = make_task() + cached_agent = make_agent(id="cached") + mock_client = AsyncMock() + mock_client.get_agents = AsyncMock(return_value=None) + walker = DAGWalker(tasks=[task], agents=[cached_agent], client=mock_client) + + assignments = await walker.advance_dag() + + assert len(assignments) == 1 + assert assignments[0][1].id == "cached" + + +class TestBackgroundScan: + """Tests for issue #98 — periodic background DAG scan safety net.""" + + async def test_start_background_scan_returns_task(self) -> None: + """start_background_scan() returns an asyncio.Task.""" + walker = DAGWalker(tasks=[], agents=[], scan_interval=0.05) + stop_event = asyncio.Event() + task = walker.start_background_scan(stop_event) + assert isinstance(task, asyncio.Task) + stop_event.set() + await task + + async def test_background_scan_calls_advance_dag(self) -> None: + """Background scan calls advance_dag() at least once while it is running.""" + ready_task 
= make_task() + agent = make_agent() + walker = DAGWalker(tasks=[ready_task], agents=[agent], scan_interval=0.02) + + call_count = 0 + original_advance = walker.advance_dag + + async def counting_advance() -> list: + nonlocal call_count + call_count += 1 + return await original_advance() + + walker.advance_dag = counting_advance # type: ignore[method-assign] + + stop_event = asyncio.Event() + scan_task = walker.start_background_scan(stop_event) + await asyncio.sleep(0.07) # allow ~3 intervals to fire + stop_event.set() + await scan_task + + assert call_count >= 1 + + async def test_background_scan_survives_advance_dag_exception(self) -> None: + """A raised exception in advance_dag() must not terminate the scan loop.""" + walker = DAGWalker(tasks=[], agents=[], scan_interval=0.02) + call_count = 0 + + async def failing_advance() -> list: + nonlocal call_count + call_count += 1 + if call_count < 2: + raise RuntimeError("transient failure") + return [] + + walker.advance_dag = failing_advance # type: ignore[method-assign] + + stop_event = asyncio.Event() + scan_task = walker.start_background_scan(stop_event) + await asyncio.sleep(0.07) + stop_event.set() + await scan_task + + assert call_count >= 2 + + async def test_background_scan_exits_immediately_on_stop(self) -> None: + """If stop_event is pre-set, the loop exits before calling advance_dag().""" + walker = DAGWalker(tasks=[], agents=[], scan_interval=60.0) + call_count = 0 + + async def counting_advance() -> list: + nonlocal call_count + call_count += 1 + return [] + + walker.advance_dag = counting_advance # type: ignore[method-assign] + + stop_event = asyncio.Event() + stop_event.set() # pre-set before starting + scan_task = walker.start_background_scan(stop_event) + await scan_task + + assert call_count == 0 + + async def test_start_background_scan_stores_task_reference(self) -> None: + """start_background_scan() stores the task in self._scan_task.""" + walker = DAGWalker(tasks=[], agents=[], 
scan_interval=0.05) + stop_event = asyncio.Event() + task = walker.start_background_scan(stop_event) + assert walker._scan_task is task + stop_event.set() + await task