Skip to content

Commit cdf4542

Browse files
Grafi 53/improve async fix tests (#91)
* improve async call and tests
* add more tests and improve invoke finish signal
* improve codebase
* update unit tests
* update hitl tests
* address comments
* fix lint
1 parent fdddcb2 commit cdf4542

17 files changed

Lines changed: 2839 additions & 140 deletions

File tree

grafi/common/containers/container.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,38 +29,52 @@ def __init__(self) -> None:
2929
# Per-instance attributes:
3030
self._event_store: Optional[EventStore] = None
3131
self._tracer: Optional[Tracer] = None
32+
# Lock for thread-safe lazy initialization of properties
33+
self._init_lock: threading.Lock = threading.Lock()
3234

3335
def register_event_store(self, event_store: EventStore) -> None:
3436
"""Override the default EventStore implementation."""
35-
if isinstance(event_store, EventStoreInMemory):
36-
logger.warning(
37-
"Using EventStoreInMemory. This is ONLY suitable for local testing but not for production."
38-
)
39-
self._event_store = event_store
37+
with self._init_lock:
38+
if isinstance(event_store, EventStoreInMemory):
39+
logger.warning(
40+
"Using EventStoreInMemory. This is ONLY suitable for local testing but not for production."
41+
)
42+
self._event_store = event_store
4043

4144
def register_tracer(self, tracer: Tracer) -> None:
4245
"""Override the default Tracer implementation."""
43-
self._tracer = tracer
46+
with self._init_lock:
47+
self._tracer = tracer
4448

4549
@property
4650
def event_store(self) -> EventStore:
47-
if self._event_store is None:
48-
logger.warning(
49-
"Using EventStoreInMemory. This is ONLY suitable for local testing but not for production."
50-
)
51-
self._event_store = EventStoreInMemory()
52-
return self._event_store
51+
# Fast path: already initialized
52+
if self._event_store is not None:
53+
return self._event_store
54+
# Slow path: initialize with lock (double-checked locking)
55+
with self._init_lock:
56+
if self._event_store is None:
57+
logger.warning(
58+
"Using EventStoreInMemory. This is ONLY suitable for local testing but not for production."
59+
)
60+
self._event_store = EventStoreInMemory()
61+
return self._event_store
5362

5463
@property
5564
def tracer(self) -> Tracer:
56-
if self._tracer is None:
57-
self._tracer = setup_tracing(
58-
tracing_options=TracingOptions.AUTO,
59-
collector_endpoint="localhost",
60-
collector_port=4317,
61-
project_name="grafi-trace",
62-
)
63-
return self._tracer
65+
# Fast path: already initialized
66+
if self._tracer is not None:
67+
return self._tracer
68+
# Slow path: initialize with lock (double-checked locking)
69+
with self._init_lock:
70+
if self._tracer is None:
71+
self._tracer = setup_tracing(
72+
tracing_options=TracingOptions.AUTO,
73+
collector_endpoint="localhost",
74+
collector_port=4317,
75+
project_name="grafi-trace",
76+
)
77+
return self._tracer
6478

6579

6680
container: Container = Container()

grafi/common/models/async_result.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@ def __init__(self, source: AsyncGenerator[ConsumeFromTopicEvent, None]):
4343
self._done = asyncio.Event()
4444
self._started = False
4545
self._exc: Optional[BaseException] = None
46+
self._producer_task: Optional[asyncio.Task] = None
4647

4748
def _ensure_started(self) -> None:
4849
if not self._started:
4950
loop = asyncio.get_running_loop()
50-
loop.create_task(self._producer())
51+
self._producer_task = loop.create_task(self._producer())
5152
self._started = True
5253

5354
async def _producer(self) -> None:
@@ -94,10 +95,20 @@ async def to_list(self) -> list[ConsumeFromTopicEvent]:
9495
return result if isinstance(result, list) else [result]
9596

9697
async def aclose(self) -> None:
97-
"""Attempt to close the underlying async generator (if any)."""
98+
"""Cancel producer task and close the underlying async generator."""
99+
# Cancel the producer task if it's running
100+
if self._producer_task is not None and not self._producer_task.done():
101+
self._producer_task.cancel()
102+
try:
103+
await self._producer_task
104+
except asyncio.CancelledError:
105+
# The task was cancelled by aclose(); a CancelledError here is expected.
106+
pass
107+
# Close the underlying source generator
98108
try:
99109
await self._source.aclose()
100110
except Exception:
111+
# Best-effort cleanup: ignore errors from closing the underlying source.
101112
pass
102113

103114

grafi/tools/llms/impl/claude_tool.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -102,29 +102,29 @@ async def invoke(
102102
input_data: Messages,
103103
) -> MsgsAGen:
104104
messages, tools = self.prepare_api_input(input_data)
105-
client = AsyncAnthropic(api_key=self.api_key)
106105

107106
try:
108-
if self.is_streaming:
109-
async with client.messages.stream(
110-
max_tokens=self.max_tokens,
111-
model=self.model,
112-
messages=messages,
113-
tools=tools,
114-
**self.chat_params,
115-
) as stream:
116-
async for event in stream:
117-
if event.type == "text":
118-
yield self.to_stream_messages(event.text)
119-
else:
120-
resp: AnthropicMessage = await client.messages.create(
121-
max_tokens=self.max_tokens,
122-
model=self.model,
123-
messages=messages,
124-
tools=tools,
125-
**self.chat_params,
126-
)
127-
yield self.to_messages(resp)
107+
async with AsyncAnthropic(api_key=self.api_key) as client:
108+
if self.is_streaming:
109+
async with client.messages.stream(
110+
max_tokens=self.max_tokens,
111+
model=self.model,
112+
messages=messages,
113+
tools=tools,
114+
**self.chat_params,
115+
) as stream:
116+
async for event in stream:
117+
if event.type == "text":
118+
yield self.to_stream_messages(event.text)
119+
else:
120+
resp: AnthropicMessage = await client.messages.create(
121+
max_tokens=self.max_tokens,
122+
model=self.model,
123+
messages=messages,
124+
tools=tools,
125+
**self.chat_params,
126+
)
127+
yield self.to_messages(resp)
128128

129129
except asyncio.CancelledError:
130130
raise

grafi/tools/llms/impl/gemini_tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class GeminiTool(LLM):
6060
name: str = Field(default="GeminiTool")
6161
type: str = Field(default="GeminiTool")
6262
api_key: Optional[str] = Field(default_factory=lambda: os.getenv("GEMINI_API_KEY"))
63-
model: str = Field(default="gemini-2.0-flash-lite")
63+
model: str = Field(default="gemini-2.5-flash-lite")
6464

6565
@classmethod
6666
def builder(cls) -> "GeminiToolBuilder":

grafi/tools/llms/impl/openai_tool.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -107,31 +107,30 @@ async def invoke(
107107
) -> MsgsAGen:
108108
api_messages, api_tools = self.prepare_api_input(input_data)
109109
try:
110-
client = AsyncClient(api_key=self.api_key)
111-
112-
if self.is_streaming:
113-
async for chunk in await client.chat.completions.create(
114-
model=self.model,
115-
messages=api_messages,
116-
tools=api_tools,
117-
stream=True,
118-
**self.chat_params,
119-
):
120-
yield self.to_stream_messages(chunk)
121-
else:
122-
req_func = (
123-
client.chat.completions.create
124-
if not self.structured_output
125-
else client.beta.chat.completions.parse
126-
)
127-
response: ChatCompletion = await req_func(
128-
model=self.model,
129-
messages=api_messages,
130-
tools=api_tools,
131-
**self.chat_params,
132-
)
133-
134-
yield self.to_messages(response)
110+
async with AsyncClient(api_key=self.api_key) as client:
111+
if self.is_streaming:
112+
async for chunk in await client.chat.completions.create(
113+
model=self.model,
114+
messages=api_messages,
115+
tools=api_tools,
116+
stream=True,
117+
**self.chat_params,
118+
):
119+
yield self.to_stream_messages(chunk)
120+
else:
121+
req_func = (
122+
client.chat.completions.create
123+
if not self.structured_output
124+
else client.beta.chat.completions.parse
125+
)
126+
response: ChatCompletion = await req_func(
127+
model=self.model,
128+
messages=api_messages,
129+
tools=api_tools,
130+
**self.chat_params,
131+
)
132+
133+
yield self.to_messages(response)
135134
except asyncio.CancelledError:
136135
raise # let caller handle
137136
except OpenAIError as exc:

grafi/topics/queue_impl/in_mem_topic_event_queue.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ async def fetch(
6666

6767
async with self._cond:
6868
# If timeout is 0 or None and no data, return immediately
69-
while not await self.can_consume(consumer_id):
69+
while not self._can_consume_unlocked(consumer_id):
7070
try:
7171
logger.debug(
7272
f"Consumer {consumer_id} waiting for new messages with timeout={timeout}"
@@ -109,8 +109,17 @@ async def reset(self) -> None:
109109
self._consumed = defaultdict(int)
110110
self._committed = defaultdict(lambda: -1)
111111

112+
def _can_consume_unlocked(self, consumer_id: str) -> bool:
113+
"""
114+
Internal check without lock. MUST be called with self._cond held.
115+
"""
116+
return self._consumed[consumer_id] < len(self._records)
117+
112118
async def can_consume(self, consumer_id: str) -> bool:
113119
"""
114120
Check if there are events available for consumption by a consumer asynchronously.
121+
122+
This method acquires the lock to ensure consistent reads of shared state.
115123
"""
116-
return self._consumed[consumer_id] < len(self._records)
124+
async with self._cond:
125+
return self._can_consume_unlocked(consumer_id)

grafi/topics/topic_base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,18 @@ async def publish_data(
7979
"""
8080
Publish data to the topic if it meets the condition.
8181
"""
82-
if self.condition(publish_event):
82+
try:
83+
condition_met = self.condition(publish_event)
84+
except Exception as e:
85+
# Condition evaluation failed (e.g., IndexError on empty data)
86+
# Treat as condition not met
87+
logger.debug(
88+
f"[{self.name}] Condition evaluation failed: {e}. "
89+
"Treating as condition not met."
90+
)
91+
condition_met = False
92+
93+
if condition_met:
8394
event = publish_event.model_copy(
8495
update={
8596
"name": self.name,

grafi/workflows/impl/async_node_tracker.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,32 +44,55 @@ def __init__(self) -> None:
4444
self._cond = asyncio.Condition()
4545
self._quiescence_event = asyncio.Event()
4646

47-
# Work tracking (prevents premature quiescence before any work)
48-
self._total_committed: int = 0
49-
self._has_started: bool = False
50-
5147
# Force stop flag (for explicit workflow stop)
5248
self._force_stopped: bool = False
5349

5450
def reset(self) -> None:
55-
"""Reset for a new workflow run."""
51+
"""
52+
Reset for a new workflow run.
53+
54+
Note: This is a sync reset that replaces primitives. It should only be
55+
called when no coroutines are waiting on the old primitives (e.g., at
56+
the start of a new workflow invocation before any tasks are spawned).
57+
"""
5658
self._active.clear()
5759
self._processing_count.clear()
5860
self._uncommitted_messages = 0
5961
self._cond = asyncio.Condition()
6062
self._quiescence_event = asyncio.Event()
61-
self._total_committed = 0
62-
self._has_started = False
6363
self._force_stopped = False
6464

65+
async def reset_async(self) -> None:
66+
"""
67+
Reset for a new workflow run (async version).
68+
69+
This version properly wakes any waiting coroutines before resetting,
70+
preventing deadlocks if called while the workflow is still running.
71+
"""
72+
async with self._cond:
73+
# Wake all waiters so they can exit gracefully
74+
self._force_stopped = True
75+
self._quiescence_event.set()
76+
self._cond.notify_all()
77+
78+
# Give waiters a chance to wake up and exit
79+
await asyncio.sleep(0)
80+
81+
# Now safe to reset state
82+
async with self._cond:
83+
self._active.clear()
84+
self._processing_count.clear()
85+
self._uncommitted_messages = 0
86+
self._force_stopped = False
87+
self._quiescence_event.clear()
88+
6589
# ─────────────────────────────────────────────────────────────────────────
6690
# Node Lifecycle (called from _invoke_node)
6791
# ─────────────────────────────────────────────────────────────────────────
6892

6993
async def enter(self, node_name: str) -> None:
7094
"""Called when a node begins processing."""
7195
async with self._cond:
72-
self._has_started = True
7396
self._quiescence_event.clear()
7497
self._active.add(node_name)
7598
self._processing_count[node_name] += 1
@@ -94,7 +117,6 @@ async def on_messages_published(self, count: int = 1, source: str = "") -> None:
94117
if count <= 0:
95118
return
96119
async with self._cond:
97-
self._has_started = True
98120
self._quiescence_event.clear()
99121
self._uncommitted_messages += count
100122

@@ -112,13 +134,9 @@ async def on_messages_committed(self, count: int = 1, source: str = "") -> None:
112134
return
113135
async with self._cond:
114136
self._uncommitted_messages = max(0, self._uncommitted_messages - count)
115-
self._total_committed += count
116137
self._check_quiescence_unlocked()
117138

118-
logger.debug(
119-
f"Tracker: {count} messages committed from {source} "
120-
f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})"
121-
)
139+
logger.debug(f"Tracker: {count} messages committed from {source}")
122140
self._cond.notify_all()
123141

124142
# Aliases for clarity
@@ -144,14 +162,9 @@ def _check_quiescence_unlocked(self) -> None:
144162
logger.debug(
145163
f"Tracker: checking quiescence - active={list(self._active)}, "
146164
f"uncommitted={self._uncommitted_messages}, "
147-
f"has_started={self._has_started}, "
148-
f"total_committed={self._total_committed}, "
149165
f"is_quiescent={is_quiescent}"
150166
)
151167
if is_quiescent:
152-
logger.info(
153-
f"Tracker: quiescence detected (committed={self._total_committed})"
154-
)
155168
self._quiescence_event.set()
156169

157170
def _is_quiescent_unlocked(self) -> bool:
@@ -165,12 +178,12 @@ def _is_quiescent_unlocked(self) -> bool:
165178
- No messages waiting to be committed
166179
- At least some work was done
167180
"""
168-
return (
169-
not self._active
170-
and self._uncommitted_messages == 0
171-
and self._has_started
172-
and self._total_committed > 0
181+
is_quiescent = not self._active and self._uncommitted_messages == 0
182+
logger.debug(
183+
f"Tracker: _is_quiescent_unlocked check - active={list(self._active)}, "
184+
f"uncommitted={self._uncommitted_messages}, is_quiescent={is_quiescent}"
173185
)
186+
return is_quiescent
174187

175188
async def is_quiescent(self) -> bool:
176189
"""
@@ -263,6 +276,5 @@ async def get_metrics(self) -> Dict:
263276
return {
264277
"active_nodes": list(self._active),
265278
"uncommitted_messages": self._uncommitted_messages,
266-
"total_committed": self._total_committed,
267279
"is_quiescent": self._is_quiescent_unlocked(),
268280
}

0 commit comments

Comments (0)