Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class ProjectKeystoneConan(ConanFile):
def requirements(self) -> None:
self.requires("spdlog/1.12.0")
self.requires("concurrentqueue/1.0.4")
self.requires("cnats/3.9.3")
self.requires("cnats/3.12.0")
if self.options.with_grpc:
self.requires("yaml-cpp/0.8.0")

Expand Down
61 changes: 61 additions & 0 deletions src/keystone/dag_walker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import asyncio
import logging
from typing import Any, Optional

from .models import Agent, Task, TERMINAL_STATUSES

logger = logging.getLogger(__name__)


class DAGWalker:
"""Walks a task DAG and assigns ready tasks to available agents."""
Expand All @@ -13,10 +17,13 @@ def __init__(
tasks: list[Task],
agents: list[Agent],
client: Optional[Any] = None,
scan_interval: float = 60.0,
) -> None:
self.tasks = tasks
self.agents = agents
self.client = client
self.scan_interval = scan_interval
self._scan_task: Optional[asyncio.Task[None]] = None

def get_available_agents(self) -> list[Agent]:
"""Return agents that are active, online, and not currently assigned a task."""
Expand Down Expand Up @@ -125,6 +132,10 @@ def validate_no_cycles(self) -> bool:
async def advance_dag(self) -> list[tuple[Task, Agent]]:
"""Assign ready tasks to available agents, returning the assignments made.

When a ``client`` is configured and it exposes a ``get_agents()`` coroutine,
the method fetches a fresh agent list before evaluating availability (issue
#196). This prevents double-assignment races caused by a stale cached list.

Raises :exc:`ValueError` if the task DAG contains a cycle or if any task
references an unknown dependency ID. Agents are marked busy immediately upon
selection so a single call cannot double-assign the same agent even if the
Expand All @@ -134,6 +145,19 @@ async def advance_dag(self) -> list[tuple[Task, Agent]]:
if cycle_path:
raise ValueError(f"Cycle detected in task DAG: {cycle_path}")

# Issue #196: refresh the agent list from the client before assigning so
# that stale in-memory state doesn't cause double-assignment.
if self.client is not None and hasattr(self.client, "get_agents"):
try:
fresh_agents = await self.client.get_agents()
if fresh_agents is not None:
self.agents = fresh_agents
except Exception: # noqa: BLE001
logger.warning(
"advance_dag: get_agents() failed — falling back to cached list",
exc_info=True,
)

assignments: list[tuple[Task, Agent]] = []
available = self.get_available_agents()

Expand All @@ -153,3 +177,40 @@ async def advance_dag(self) -> list[tuple[Task, Agent]]:
assignments.append((task, agent))

return assignments

async def _background_scan_loop(self, stop_event: asyncio.Event) -> None:
"""Periodically call :meth:`advance_dag` as a safety net (issue #98).

Uses ``asyncio.wait_for`` on *stop_event* so the loop wakes up
immediately on shutdown rather than sleeping for the full interval.
"""
logger.info(
"DAGWalker background scan started (interval=%.1fs)", self.scan_interval
)
while not stop_event.is_set():
try:
await asyncio.wait_for(
stop_event.wait(), timeout=self.scan_interval
)
# stop_event was set — exit cleanly.
break
except asyncio.TimeoutError:
pass
try:
await self.advance_dag()
except Exception: # noqa: BLE001
logger.exception("DAGWalker background scan failed — continuing")
logger.info("DAGWalker background scan stopped")

def start_background_scan(
self, stop_event: asyncio.Event
) -> "asyncio.Task[None]":
"""Schedule a background scan loop and return the created task (issue #98).

The caller is responsible for cancelling or awaiting the returned task on
shutdown. Passing *stop_event* allows a graceful, prompt exit.
"""
self._scan_task = asyncio.create_task(
self._background_scan_loop(stop_event), name="dag-background-scan"
)
return self._scan_task
166 changes: 166 additions & 0 deletions tests/test_dag_walker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for DAGWalker — cycle detection, ready tasks, available agents, and advance_dag."""
from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock

import pytest
Expand Down Expand Up @@ -206,6 +207,8 @@ async def test_advance_dag_calls_client_assign_task(self) -> None:
task = make_task()
agent = make_agent()
mock_client = AsyncMock()
# Return the same agent list so the fresh-agent refresh (issue #196) is a no-op.
mock_client.get_agents = AsyncMock(return_value=[agent])
walker = DAGWalker(tasks=[task], agents=[agent], client=mock_client)

assignments = await walker.advance_dag()
Expand Down Expand Up @@ -259,6 +262,8 @@ async def test_advance_dag_client_called_for_each_assignment(self) -> None:
agent1 = make_agent(id="a1")
agent2 = make_agent(id="a2")
mock_client = AsyncMock()
# Return the same agents so the fresh-agent refresh (issue #196) is a no-op.
mock_client.get_agents = AsyncMock(return_value=[agent1, agent2])
walker = DAGWalker(
tasks=[task1, task2], agents=[agent1, agent2], client=mock_client
)
Expand Down Expand Up @@ -295,3 +300,164 @@ async def test_advance_dag_raises_on_unknown_dependency(self) -> None:
walker = DAGWalker(tasks=[task], agents=[agent])
with pytest.raises(ValueError, match="Unknown dependency"):
await walker.advance_dag()


class TestAdvanceDagFreshAgents:
"""Tests for issue #196 — advance_dag() calls get_agents() for a fresh agent list."""

async def test_advance_dag_calls_get_agents_when_available(self) -> None:
"""advance_dag() calls client.get_agents() and uses the returned list."""
task = make_task()
stale_agent = make_agent(id="stale", current_task_id="busy")
fresh_agent = make_agent(id="fresh")
mock_client = AsyncMock()
mock_client.get_agents = AsyncMock(return_value=[fresh_agent])
walker = DAGWalker(tasks=[task], agents=[stale_agent], client=mock_client)

assignments = await walker.advance_dag()

mock_client.get_agents.assert_awaited_once()
assert len(assignments) == 1
assert assignments[0][1].id == "fresh"

async def test_advance_dag_skips_get_agents_when_no_client(self) -> None:
"""advance_dag() works without a client — uses self.agents unchanged."""
task = make_task()
agent = make_agent()
walker = DAGWalker(tasks=[task], agents=[agent], client=None)

assignments = await walker.advance_dag()

assert len(assignments) == 1
assert assignments[0][1] is agent

async def test_advance_dag_skips_get_agents_when_client_lacks_method(self) -> None:
"""advance_dag() falls back to self.agents when client has no get_agents()."""
from unittest.mock import MagicMock

task = make_task()
agent = make_agent()
# Build a client that only has assign_task (an async callable) — no get_agents.
mock_client = MagicMock()
mock_client.assign_task = AsyncMock()
del mock_client.get_agents # ensure hasattr returns False
walker = DAGWalker(tasks=[task], agents=[agent], client=mock_client)

assignments = await walker.advance_dag()

assert len(assignments) == 1
assert assignments[0][1] is agent

async def test_advance_dag_falls_back_on_get_agents_exception(self) -> None:
"""If get_agents() raises, advance_dag() falls back to the cached list."""
task = make_task()
cached_agent = make_agent(id="cached")
mock_client = AsyncMock()
mock_client.get_agents = AsyncMock(side_effect=RuntimeError("service down"))
walker = DAGWalker(tasks=[task], agents=[cached_agent], client=mock_client)

assignments = await walker.advance_dag()

# Falls back to cached list — the cached agent was available
assert len(assignments) == 1
assert assignments[0][1].id == "cached"

async def test_advance_dag_ignores_none_from_get_agents(self) -> None:
"""If get_agents() returns None, advance_dag() keeps the existing agent list."""
task = make_task()
cached_agent = make_agent(id="cached")
mock_client = AsyncMock()
mock_client.get_agents = AsyncMock(return_value=None)
walker = DAGWalker(tasks=[task], agents=[cached_agent], client=mock_client)

assignments = await walker.advance_dag()

assert len(assignments) == 1
assert assignments[0][1].id == "cached"


class TestBackgroundScan:
"""Tests for issue #98 — periodic background DAG scan safety net."""

async def test_start_background_scan_returns_task(self) -> None:
"""start_background_scan() returns an asyncio.Task."""
walker = DAGWalker(tasks=[], agents=[], scan_interval=0.05)
stop_event = asyncio.Event()
task = walker.start_background_scan(stop_event)
assert isinstance(task, asyncio.Task)
stop_event.set()
await task

async def test_background_scan_calls_advance_dag(self) -> None:
"""Background scan calls advance_dag() at least once per interval."""
ready_task = make_task()
agent = make_agent()
walker = DAGWalker(tasks=[ready_task], agents=[agent], scan_interval=0.02)

call_count = 0
original_advance = walker.advance_dag

async def counting_advance() -> list:
nonlocal call_count
call_count += 1
return await original_advance()

walker.advance_dag = counting_advance # type: ignore[method-assign]

stop_event = asyncio.Event()
scan_task = walker.start_background_scan(stop_event)
await asyncio.sleep(0.07) # allow ~3 intervals to fire
stop_event.set()
await scan_task

assert call_count >= 1

async def test_background_scan_survives_advance_dag_exception(self) -> None:
"""A raised exception in advance_dag() must not terminate the scan loop."""
walker = DAGWalker(tasks=[], agents=[], scan_interval=0.02)
call_count = 0

async def failing_advance() -> list:
nonlocal call_count
call_count += 1
if call_count < 2:
raise RuntimeError("transient failure")
return []

walker.advance_dag = failing_advance # type: ignore[method-assign]

stop_event = asyncio.Event()
scan_task = walker.start_background_scan(stop_event)
await asyncio.sleep(0.07)
stop_event.set()
await scan_task

assert call_count >= 2

async def test_background_scan_exits_immediately_on_stop(self) -> None:
"""If stop_event is pre-set, the loop exits before calling advance_dag()."""
walker = DAGWalker(tasks=[], agents=[], scan_interval=60.0)
call_count = 0

async def counting_advance() -> list:
nonlocal call_count
call_count += 1
return []

walker.advance_dag = counting_advance # type: ignore[method-assign]

stop_event = asyncio.Event()
stop_event.set() # pre-set before starting
scan_task = walker.start_background_scan(stop_event)
await scan_task

assert call_count == 0

async def test_start_background_scan_stores_task_reference(self) -> None:
"""start_background_scan() stores the task in self._scan_task."""
walker = DAGWalker(tasks=[], agents=[], scan_interval=0.05)
stop_event = asyncio.Event()
task = walker.start_background_scan(stop_event)
assert walker._scan_task is task
stop_event.set()
await task
Loading