Skip to content

Commit 8e54f84

Browse files
author
Dylan Huang
committed
Refactor SqliteEventBus for async event handling
- Updated SqliteEventBus to use asyncio for cross-process event listening, replacing the previous threading implementation. - Changed the processed field in the database from CharField to BooleanField for better data integrity. - Adjusted event processing logic to accommodate the new async structure, ensuring events are handled correctly across processes. - Enhanced test cases to support async operations and validate cross-process event communication.
1 parent 174b261 commit 8e54f84

File tree

5 files changed

+275
-107
lines changed

5 files changed

+275
-107
lines changed

eval_protocol/event_bus/sqlite_event_bus.py

Lines changed: 47 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import os
13
import threading
24
import time
35
from typing import Any, Optional
@@ -16,17 +18,14 @@ def __init__(self, db_path: Optional[str] = None):
1618

1719
# Use the same database as the evaluation row store
1820
if db_path is None:
19-
import os
20-
2121
from eval_protocol.directory_utils import find_eval_protocol_dir
2222

2323
eval_protocol_dir = find_eval_protocol_dir()
2424
db_path = os.path.join(eval_protocol_dir, "logs.db")
2525

26-
self._db = SqliteEventBusDatabase(db_path)
26+
self._db: SqliteEventBusDatabase = SqliteEventBusDatabase(db_path)
2727
self._running = False
28-
self._listener_thread: Optional[threading.Thread] = None
29-
self._process_id = str(uuid4())
28+
self._process_id = str(os.getpid())
3029

3130
def emit(self, event_type: str, data: Any) -> None:
3231
"""Emit an event to all subscribers.
@@ -64,67 +63,54 @@ def start_listening(self) -> None:
6463

6564
logger.debug("[CROSS_PROCESS_LISTEN] Starting cross-process event listening")
6665
self._running = True
67-
self._start_database_listener()
68-
logger.debug("[CROSS_PROCESS_LISTEN] Started database listener thread")
66+
loop = asyncio.get_running_loop()
67+
loop.create_task(self._database_listener_task())
68+
logger.debug("[CROSS_PROCESS_LISTEN] Started async database listener task")
6969

7070
def stop_listening(self) -> None:
7171
"""Stop listening for cross-process events."""
7272
logger.debug("[CROSS_PROCESS_LISTEN] Stopping cross-process event listening")
7373
self._running = False
74-
if self._listener_thread and self._listener_thread.is_alive():
75-
logger.debug("[CROSS_PROCESS_LISTEN] Waiting for listener thread to stop")
76-
self._listener_thread.join(timeout=1)
77-
logger.debug("[CROSS_PROCESS_LISTEN] Listener thread stopped")
78-
79-
def _start_database_listener(self) -> None:
80-
"""Start database-based event listener."""
81-
82-
def database_listener():
83-
logger.debug("[CROSS_PROCESS_LISTENER] Starting database listener loop")
84-
last_cleanup = time.time()
85-
86-
while self._running:
87-
try:
88-
# Get unprocessed events from other processes
89-
events = self._db.get_unprocessed_events(self._process_id)
90-
if events:
91-
logger.debug(f"[CROSS_PROCESS_LISTENER] Found {len(events)} unprocessed events")
92-
93-
for event in events:
94-
if not self._running:
95-
break
96-
97-
try:
98-
logger.debug(
99-
f"[CROSS_PROCESS_LISTENER] Processing event {event['event_id']} of type {event['event_type']}"
100-
)
101-
# Handle the event
102-
self._handle_cross_process_event(event["event_type"], event["data"])
103-
logger.debug(f"[CROSS_PROCESS_LISTENER] Successfully processed event {event['event_id']}")
104-
105-
# Mark as processed
106-
self._db.mark_event_processed(event["event_id"])
107-
logger.debug(f"[CROSS_PROCESS_LISTENER] Marked event {event['event_id']} as processed")
108-
109-
except Exception as e:
110-
logger.debug(f"[CROSS_PROCESS_LISTENER] Failed to process event {event['event_id']}: {e}")
111-
112-
# Clean up old events every hour
113-
current_time = time.time()
114-
if current_time - last_cleanup >= 3600:
115-
logger.debug("[CROSS_PROCESS_LISTENER] Cleaning up old events")
116-
self._db.cleanup_old_events()
117-
last_cleanup = current_time
118-
119-
# Small sleep to prevent busy waiting
120-
time.sleep(0.1)
121-
122-
except Exception as e:
123-
logger.debug(f"[CROSS_PROCESS_LISTENER] Database listener error: {e}")
124-
time.sleep(1)
125-
126-
self._listener_thread = threading.Thread(target=database_listener, daemon=True)
127-
self._listener_thread.start()
74+
75+
async def _database_listener_task(self) -> None:
76+
"""Single database listener task that processes events and recreates itself."""
77+
if not self._running:
78+
# this should end the task loop
79+
logger.debug("[CROSS_PROCESS_LISTENER] Stopping database listener task")
80+
return
81+
82+
# Get unprocessed events from other processes
83+
events = self._db.get_unprocessed_events(str(self._process_id))
84+
if events:
85+
logger.debug(f"[CROSS_PROCESS_LISTENER] Found {len(events)} unprocessed events")
86+
else:
87+
logger.debug(f"[CROSS_PROCESS_LISTENER] No unprocessed events found for process {self._process_id}")
88+
89+
for event in events:
90+
logger.debug(
91+
f"[CROSS_PROCESS_LISTENER] Processing event {event['event_id']} of type {event['event_type']}"
92+
)
93+
# Handle the event
94+
self._handle_cross_process_event(event["event_type"], event["data"])
95+
logger.debug(f"[CROSS_PROCESS_LISTENER] Successfully processed event {event['event_id']}")
96+
97+
# Mark as processed
98+
self._db.mark_event_processed(event["event_id"])
99+
logger.debug(f"[CROSS_PROCESS_LISTENER] Marked event {event['event_id']} as processed")
100+
101+
# Clean up old events every hour
102+
current_time = time.time()
103+
if not hasattr(self, "_last_cleanup"):
104+
self._last_cleanup = current_time
105+
elif current_time - self._last_cleanup >= 3600:
106+
logger.debug("[CROSS_PROCESS_LISTENER] Cleaning up old events")
107+
self._db.cleanup_old_events()
108+
self._last_cleanup = current_time
109+
110+
# Schedule the next task if still running
111+
await asyncio.sleep(1.0)
112+
loop = asyncio.get_running_loop()
113+
loop.create_task(self._database_listener_task())
128114

129115
def _handle_cross_process_event(self, event_type: str, data: Any) -> None:
130116
"""Handle events received from other processes."""

eval_protocol/event_bus/sqlite_event_bus_database.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, List
33
from uuid import uuid4
44

5-
from peewee import CharField, DateTimeField, Model, SqliteDatabase
5+
from peewee import BooleanField, CharField, DateTimeField, Model, SqliteDatabase
66
from playhouse.sqlite_ext import JSONField
77

88
from eval_protocol.event_bus.logger import logger
@@ -25,7 +25,7 @@ class Event(BaseModel): # type: ignore
2525
data = JSONField()
2626
timestamp = DateTimeField()
2727
process_id = CharField()
28-
processed = CharField(default="false") # Track if event has been processed
28+
processed = BooleanField(default=False) # Track if event has been processed
2929

3030
self._Event = Event
3131
self._db.connect()
@@ -46,7 +46,7 @@ def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
4646
data=serialized_data,
4747
timestamp=time.time(),
4848
process_id=process_id,
49-
processed="false",
49+
processed=False,
5050
)
5151
except Exception as e:
5252
logger.warning(f"Failed to publish event to database: {e}")
@@ -56,7 +56,7 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]:
5656
try:
5757
query = (
5858
self._Event.select()
59-
.where((self._Event.process_id != process_id) & (self._Event.processed == "false"))
59+
.where((self._Event.process_id != process_id) & (~self._Event.processed))
6060
.order_by(self._Event.timestamp)
6161
)
6262

@@ -80,16 +80,14 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]:
8080
def mark_event_processed(self, event_id: str) -> None:
8181
"""Mark an event as processed."""
8282
try:
83-
self._Event.update(processed="true").where(self._Event.event_id == event_id).execute()
83+
self._Event.update(processed=True).where(self._Event.event_id == event_id).execute()
8484
except Exception as e:
8585
logger.debug(f"Failed to mark event as processed: {e}")
8686

8787
def cleanup_old_events(self, max_age_hours: int = 24) -> None:
8888
"""Clean up old processed events."""
8989
try:
9090
cutoff_time = time.time() - (max_age_hours * 3600)
91-
self._Event.delete().where(
92-
(self._Event.processed == "true") & (self._Event.timestamp < cutoff_time)
93-
).execute()
91+
self._Event.delete().where((self._Event.processed) & (self._Event.timestamp < cutoff_time)).execute()
9492
except Exception as e:
9593
logger.debug(f"Failed to cleanup old events: {e}")

test_event_bus_helper.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#!/usr/bin/env python3
2+
"""Helper script for testing event bus cross-process communication."""
3+
4+
import asyncio
5+
import sys
6+
import json
7+
from eval_protocol.event_bus import SqliteEventBus
8+
from eval_protocol.models import EvaluationRow, InputMetadata
9+
10+
11+
async def listener_process(db_path: str):
12+
"""Run an event bus listener in a separate process."""
13+
try:
14+
event_bus = SqliteEventBus(db_path=db_path)
15+
16+
received_events = []
17+
18+
def test_listener(event_type: str, data):
19+
received_events.append((event_type, data))
20+
21+
event_bus.subscribe(test_listener)
22+
event_bus.start_listening()
23+
24+
# Wait for events for up to 5 seconds
25+
start_time = asyncio.get_event_loop().time()
26+
while asyncio.get_event_loop().time() - start_time < 5.0:
27+
await asyncio.sleep(0.1)
28+
if received_events:
29+
break
30+
31+
# Output results to stdout
32+
print(json.dumps(received_events))
33+
event_bus.stop_listening()
34+
35+
except Exception as e:
36+
print(f"Error in listener process: {e}", file=sys.stderr)
37+
sys.exit(1)
38+
39+
40+
async def emitter_process(db_path: str, event_type: str, data_json: str):
41+
"""Emit an event from a separate process."""
42+
try:
43+
event_bus = SqliteEventBus(db_path=db_path)
44+
45+
# Parse the data
46+
if data_json:
47+
data = json.loads(data_json)
48+
else:
49+
data = None
50+
51+
event_bus.emit(event_type, data)
52+
53+
except Exception as e:
54+
print(f"Error in emitter process: {e}", file=sys.stderr)
55+
sys.exit(1)
56+
57+
58+
if __name__ == "__main__":
59+
if len(sys.argv) < 3:
60+
print("Usage: python test_event_bus_helper.py <mode> <db_path> [event_type] [data_json]", file=sys.stderr)
61+
sys.exit(1)
62+
63+
mode = sys.argv[1]
64+
db_path = sys.argv[2]
65+
66+
if mode == "listener":
67+
asyncio.run(listener_process(db_path))
68+
elif mode == "emitter":
69+
event_type = sys.argv[3]
70+
data_json = sys.argv[4] if len(sys.argv) > 4 else ""
71+
asyncio.run(emitter_process(db_path, event_type, data_json))
72+
else:
73+
print(f"Unknown mode: {mode}", file=sys.stderr)
74+
sys.exit(1)

0 commit comments

Comments
 (0)