|
4 | 4 | import asyncio |
5 | 5 | from collections import defaultdict |
6 | 6 | from typing import Dict |
| 7 | +from typing import Optional |
| 8 | +from typing import Set |
| 9 | + |
| 10 | +from loguru import logger |
7 | 11 |
|
8 | 12 |
|
class AsyncNodeTracker:
    """
    Central tracker for workflow activity and quiescence detection.

    Design: All tracking calls come from the ORCHESTRATOR layer,
    not from TopicBase. This keeps topics as pure message queues.

    Quiescence = (no active nodes) AND (no uncommitted messages) AND (work done)

    Once a force stop has been requested, the quiescence event stays set
    until ``reset()`` is called, so waiters never block on a stopped workflow.

    Usage in workflow:
        # In publish_events():
        await tracker.on_messages_published(len(published_events))

        # In _commit_events():
        await tracker.on_messages_committed(len(events))

        # In node processing:
        await tracker.enter(node_name)
        ... process ...
        await tracker.leave(node_name)
    """

    def __init__(self) -> None:
        """Create a tracker with no active nodes and no recorded messages."""
        # Node activity tracking
        self._active: Set[str] = set()
        self._processing_count: Dict[str, int] = defaultdict(int)

        # Message tracking (uncommitted = published but not yet committed)
        self._uncommitted_messages: int = 0

        # Synchronization
        self._cond = asyncio.Condition()
        # Starts unset: quiescence requires that some work was committed first.
        self._quiescence_event = asyncio.Event()

        # Work tracking (prevents premature quiescence before any work)
        self._total_committed: int = 0
        self._has_started: bool = False

        # Force stop flag (for explicit workflow stop)
        self._force_stopped: bool = False

    def reset(self) -> None:
        """Reset for a new workflow run.

        All counters and flags return to their initial values, and the
        condition and event objects are replaced with fresh instances.
        """
        self._active.clear()
        self._processing_count.clear()
        self._uncommitted_messages = 0
        self._cond = asyncio.Condition()
        self._quiescence_event = asyncio.Event()
        self._total_committed = 0
        self._has_started = False
        self._force_stopped = False

    # ─────────────────────────────────────────────────────────────────────────
    # Node Lifecycle (called from _invoke_node)
    # ─────────────────────────────────────────────────────────────────────────

    async def enter(self, node_name: str) -> None:
        """Called when a node begins processing.

        Marks the node active and increments its processing count.
        """
        async with self._cond:
            self._has_started = True
            # A force stop must stay observable to waiters; only natural
            # quiescence is invalidated by new activity.
            if not self._force_stopped:
                self._quiescence_event.clear()
            self._active.add(node_name)
            self._processing_count[node_name] += 1

    async def leave(self, node_name: str) -> None:
        """Called when a node finishes processing.

        Removes the node from the active set (a no-op if it is not active),
        re-evaluates quiescence and wakes tasks waiting on the condition.
        """
        async with self._cond:
            self._active.discard(node_name)
            self._check_quiescence_unlocked()
            self._cond.notify_all()

    # ─────────────────────────────────────────────────────────────────────────
    # Message Tracking (called from orchestrator utilities)
    # ─────────────────────────────────────────────────────────────────────────

    async def on_messages_published(self, count: int = 1, source: str = "") -> None:
        """
        Called when messages are published to topics.

        Call site: publish_events() in utils.py

        Args:
            count: Number of messages published; values <= 0 are ignored.
            source: Label used only in the debug log line.
        """
        if count <= 0:
            return
        async with self._cond:
            self._has_started = True
            # Keep the event set after a force stop (see enter()).
            if not self._force_stopped:
                self._quiescence_event.clear()
            self._uncommitted_messages += count

            logger.debug(
                f"Tracker: {count} messages published from {source} (uncommitted={self._uncommitted_messages})"
            )

    async def on_messages_committed(self, count: int = 1, source: str = "") -> None:
        """
        Called when messages are committed (consumed and acknowledged).

        Call site: _commit_events() in EventDrivenWorkflow

        Args:
            count: Number of messages committed; values <= 0 are ignored.
            source: Label used only in the debug log line.
        """
        if count <= 0:
            return
        async with self._cond:
            # Clamp at zero so an over-reported commit cannot drive the
            # uncommitted counter negative.
            self._uncommitted_messages = max(0, self._uncommitted_messages - count)
            self._total_committed += count
            self._check_quiescence_unlocked()

            logger.debug(
                f"Tracker: {count} messages committed from {source} "
                f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})"
            )
            # notify_all() requires the condition's lock to be held.
            self._cond.notify_all()

    # Aliases for clarity
    async def on_message_published(self) -> None:
        """Single message version."""
        await self.on_messages_published(1)

    async def on_message_committed(self) -> None:
        """Single message version."""
        await self.on_messages_committed(1)

    # ─────────────────────────────────────────────────────────────────────────
    # Quiescence Detection
    # ─────────────────────────────────────────────────────────────────────────

    def _check_quiescence_unlocked(self) -> None:
        """
        Check and signal quiescence if all conditions met.

        MUST be called with self._cond lock held.
        """
        is_quiescent = self._is_quiescent_unlocked()
        logger.debug(
            f"Tracker: checking quiescence - active={list(self._active)}, "
            f"uncommitted={self._uncommitted_messages}, "
            f"has_started={self._has_started}, "
            f"total_committed={self._total_committed}, "
            f"is_quiescent={is_quiescent}"
        )
        if is_quiescent:
            logger.info(
                f"Tracker: quiescence detected (committed={self._total_committed})"
            )
            self._quiescence_event.set()

    def _is_quiescent_unlocked(self) -> bool:
        """
        Internal quiescence check without lock.

        MUST be called with self._cond lock held.

        True when workflow is truly idle:
        - No nodes actively processing
        - No messages waiting to be committed
        - At least some work was done
        """
        return (
            not self._active
            and self._uncommitted_messages == 0
            and self._has_started
            and self._total_committed > 0
        )

    async def is_quiescent(self) -> bool:
        """
        True when workflow is truly idle:
        - No nodes actively processing
        - No messages waiting to be committed
        - At least some work was done

        This method acquires the lock to ensure consistent reads.
        """
        async with self._cond:
            return self._is_quiescent_unlocked()

    def _should_terminate_unlocked(self) -> bool:
        """
        Internal termination check without lock.

        MUST be called with self._cond lock held.
        """
        return self._is_quiescent_unlocked() or self._force_stopped

    async def should_terminate(self) -> bool:
        """
        True when workflow should stop iteration.
        Either natural quiescence or explicit force stop.

        This method acquires the lock to ensure consistent reads.
        """
        async with self._cond:
            return self._should_terminate_unlocked()

    async def force_stop(self) -> None:
        """
        Force the workflow to stop immediately (async version with lock).
        Called when workflow.stop() is invoked from async context.
        """
        async with self._cond:
            logger.info("Tracker: force stop requested")
            self._force_stopped = True
            self._quiescence_event.set()
            self._cond.notify_all()

    def force_stop_sync(self) -> None:
        """
        Force the workflow to stop immediately (sync version).

        This is a synchronous version for use from sync contexts (e.g., stop() method).
        It sets the stop flag and event without acquiring the async lock, so
        tasks waiting on the condition are not notified; tasks waiting on the
        quiescence event are woken.

        NOTE: asyncio primitives, including asyncio.Event, are NOT thread-safe.
        Call this only from the thread running the event loop. From another
        thread, schedule it with ``loop.call_soon_threadsafe(tracker.force_stop_sync)``.
        """
        logger.info("Tracker: force stop requested (sync)")
        self._force_stopped = True
        self._quiescence_event.set()

    async def is_idle(self) -> bool:
        """Legacy: just checks if no active nodes."""
        async with self._cond:
            return not self._active

    async def wait_for_quiescence(self, timeout: Optional[float] = None) -> bool:
        """Wait until quiescent (or force-stopped). Returns False on timeout.

        Args:
            timeout: Maximum seconds to wait. ``None`` waits indefinitely;
                ``0`` polls the current state without blocking.
        """
        # Fast path: already signalled. This also keeps the zero-timeout
        # result independent of how asyncio.wait_for treats timeout <= 0.
        if self._quiescence_event.is_set():
            return True
        # `is None` rather than truthiness: a timeout of 0 means "poll",
        # not "wait forever".
        if timeout is None:
            await self._quiescence_event.wait()
            return True
        try:
            await asyncio.wait_for(self._quiescence_event.wait(), timeout)
        except asyncio.TimeoutError:
            return False
        return True

    async def wait_idle_event(self) -> None:
        """Legacy compatibility."""
        await self._quiescence_event.wait()

    # ─────────────────────────────────────────────────────────────────────────
    # Metrics
    # ─────────────────────────────────────────────────────────────────────────

    async def get_activity_count(self) -> int:
        """Total processing count across all nodes."""
        async with self._cond:
            return sum(self._processing_count.values())

    async def get_metrics(self) -> Dict:
        """Detailed metrics for debugging.

        Returns a dict with the keys ``active_nodes``, ``uncommitted_messages``,
        ``total_committed`` and ``is_quiescent``, read under the lock.
        """
        async with self._cond:
            return {
                "active_nodes": list(self._active),
                "uncommitted_messages": self._uncommitted_messages,
                "total_committed": self._total_committed,
                "is_quiescent": self._is_quiescent_unlocked(),
            }
0 commit comments