Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ jobs:
--ignore=tests/pytest/test_svgbench.py \
--ignore=tests/pytest/test_livesvgbench.py \
--ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \
--ignore=tests/logging/test_elasticsearch_direct_http_handler.py \
--ignore=eval_protocol/benchmarks/ \
--cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: RemoteRolloutProcessor Propagate Status Test
name: Elasticsearch Tests

on:
push:
Expand All @@ -13,8 +13,8 @@ on:
workflow_dispatch: # Allow manual triggering

jobs:
remote-rollout-processor-propagate-status-smoke-test:
name: Fireworks Propagate Status Smoke Test
elasticsearch-tests:
name: Elasticsearch Integration Tests
runs-on: ubuntu-latest

steps:
Expand All @@ -36,10 +36,14 @@ jobs:
- name: Install the project
run: uv sync --locked --all-extras --dev

- name: Run RemoteRolloutProcessor Propagate Status Smoke Test
- name: Run Elasticsearch Tests
env:
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
run: |
# Run Elasticsearch direct HTTP handler tests
uv run pytest tests/logging/test_elasticsearch_direct_http_handler.py -v --tb=short

# Run RemoteRolloutProcessor Propagate Status Smoke Test (also uses Elasticsearch)
uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \
-v --tb=short
1 change: 1 addition & 0 deletions eval_protocol/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def parse_args(args=None):
# Logs command
logs_parser = subparsers.add_parser("logs", help="Serve logs with file watching and real-time updates")
logs_parser.add_argument("--port", type=int, default=8000, help="Port to bind to (default: 8000)")
logs_parser.add_argument("--debug", action="store_true", help="Enable debug mode")

# Upload command
upload_parser = subparsers.add_parser(
Expand Down
3 changes: 2 additions & 1 deletion eval_protocol/cli_commands/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def logs_command(args):
print(f"🌐 URL: http://localhost:{port}")
print(f"🔌 WebSocket: ws://localhost:{port}/ws")
print(f"👀 Watching paths: {['current directory']}")
print(f"🔍 Debug mode: {args.debug}")
print("Press Ctrl+C to stop the server")
print("-" * 50)

Expand All @@ -25,7 +26,7 @@ def logs_command(args):
elasticsearch_config = ElasticsearchSetup().setup_elasticsearch()

try:
serve_logs(port=args.port, elasticsearch_config=elasticsearch_config)
serve_logs(port=args.port, elasticsearch_config=elasticsearch_config, debug=args.debug)
return 0
except KeyboardInterrupt:
print("\n🛑 Server stopped by user")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@ def __init__(self, db_path: Optional[str] = None, store: Optional[SqliteEvaluati

def log(self, row: "EvaluationRow") -> None:
data = row.model_dump(exclude_none=True, mode="json")
rollout_id = data.get("execution_metadata", {}).get("rollout_id", "unknown")
logger.debug(f"[EVENT_BUS_EMIT] Starting to log row with rollout_id: {rollout_id}")

self._store.upsert_row(data=data)
logger.debug(f"[EVENT_BUS_EMIT] Successfully stored row in database for rollout_id: {rollout_id}")

try:
logger.debug(f"[EVENT_BUS_EMIT] Emitting event '{LOG_EVENT_TYPE}' for rollout_id: {rollout_id}")
event_bus.emit(LOG_EVENT_TYPE, EvaluationRow(**data))
logger.debug(f"[EVENT_BUS_EMIT] Successfully emitted event for rollout_id: {rollout_id}")
except Exception as e:
# Avoid breaking storage due to event emission issues
logger.error(f"Failed to emit row_upserted event: {e}")
logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}")
pass

def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
Expand Down
117 changes: 67 additions & 50 deletions eval_protocol/event_bus/sqlite_event_bus.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import os
import threading
import time
from typing import Any, Optional
Expand All @@ -16,17 +18,14 @@ def __init__(self, db_path: Optional[str] = None):

# Use the same database as the evaluation row store
if db_path is None:
import os

from eval_protocol.directory_utils import find_eval_protocol_dir

eval_protocol_dir = find_eval_protocol_dir()
db_path = os.path.join(eval_protocol_dir, "logs.db")

self._db = SqliteEventBusDatabase(db_path)
self._db: SqliteEventBusDatabase = SqliteEventBusDatabase(db_path)
self._running = False
self._listener_thread: Optional[threading.Thread] = None
self._process_id = str(uuid4())
self._process_id = str(os.getpid())

def emit(self, event_type: str, data: Any) -> None:
"""Emit an event to all subscribers.
Expand All @@ -35,75 +34,93 @@ def emit(self, event_type: str, data: Any) -> None:
event_type: Type of event (e.g., "log")
data: Event data
"""
logger.debug(f"[CROSS_PROCESS_EMIT] Emitting event type: {event_type}")

# Call local listeners immediately
logger.debug(f"[CROSS_PROCESS_EMIT] Calling {len(self._listeners)} local listeners")
super().emit(event_type, data)
logger.debug("[CROSS_PROCESS_EMIT] Completed local listener calls")

# Publish to cross-process subscribers
logger.debug("[CROSS_PROCESS_EMIT] Publishing to cross-process subscribers")
self._publish_cross_process(event_type, data)
logger.debug("[CROSS_PROCESS_EMIT] Completed cross-process publish")

def _publish_cross_process(self, event_type: str, data: Any) -> None:
"""Publish event to cross-process subscribers via database."""
self._db.publish_event(event_type, data, self._process_id)
logger.debug(f"[CROSS_PROCESS_PUBLISH] Publishing event {event_type} to database")
try:
self._db.publish_event(event_type, data, self._process_id)
logger.debug(f"[CROSS_PROCESS_PUBLISH] Successfully published event {event_type} to database")
except Exception as e:
logger.error(f"[CROSS_PROCESS_PUBLISH] Failed to publish event {event_type} to database: {e}")

def start_listening(self) -> None:
"""Start listening for cross-process events."""
if self._running:
logger.debug("[CROSS_PROCESS_LISTEN] Already listening, skipping start")
return

logger.debug("[CROSS_PROCESS_LISTEN] Starting cross-process event listening")
self._running = True
self._start_database_listener()
loop = asyncio.get_running_loop()
loop.create_task(self._database_listener_task())
logger.debug("[CROSS_PROCESS_LISTEN] Started async database listener task")

def stop_listening(self) -> None:
"""Stop listening for cross-process events."""
logger.debug("[CROSS_PROCESS_LISTEN] Stopping cross-process event listening")
self._running = False
if self._listener_thread and self._listener_thread.is_alive():
self._listener_thread.join(timeout=1)

def _start_database_listener(self) -> None:
"""Start database-based event listener."""

def database_listener():
last_cleanup = time.time()

while self._running:
try:
# Get unprocessed events from other processes
events = self._db.get_unprocessed_events(self._process_id)

for event in events:
if not self._running:
break

try:
# Handle the event
self._handle_cross_process_event(event["event_type"], event["data"])

# Mark as processed
self._db.mark_event_processed(event["event_id"])

except Exception as e:
logger.debug(f"Failed to process event {event['event_id']}: {e}")

# Clean up old events every hour
current_time = time.time()
if current_time - last_cleanup >= 3600:
self._db.cleanup_old_events()
last_cleanup = current_time

# Small sleep to prevent busy waiting
time.sleep(0.1)

except Exception as e:
logger.debug(f"Database listener error: {e}")
time.sleep(1)
async def _database_listener_task(self) -> None:
"""Single database listener task that processes events and recreates itself."""
if not self._running:
# this should end the task loop
logger.debug("[CROSS_PROCESS_LISTENER] Stopping database listener task")
return

self._listener_thread = threading.Thread(target=database_listener, daemon=True)
self._listener_thread.start()
# Get unprocessed events from other processes
events = self._db.get_unprocessed_events(str(self._process_id))
if events:
logger.debug(f"[CROSS_PROCESS_LISTENER] Found {len(events)} unprocessed events")
else:
logger.debug(f"[CROSS_PROCESS_LISTENER] No unprocessed events found for process {self._process_id}")

for event in events:
logger.debug(
f"[CROSS_PROCESS_LISTENER] Processing event {event['event_id']} of type {event['event_type']}"
)
# Handle the event
self._handle_cross_process_event(event["event_type"], event["data"])
logger.debug(f"[CROSS_PROCESS_LISTENER] Successfully processed event {event['event_id']}")

# Mark as processed
self._db.mark_event_processed(event["event_id"])
logger.debug(f"[CROSS_PROCESS_LISTENER] Marked event {event['event_id']} as processed")

# Clean up old events every hour
current_time = time.time()
if not hasattr(self, "_last_cleanup"):
self._last_cleanup = current_time
elif current_time - self._last_cleanup >= 3600:
logger.debug("[CROSS_PROCESS_LISTENER] Cleaning up old events")
self._db.cleanup_old_events()
self._last_cleanup = current_time

# Schedule the next task if still running
await asyncio.sleep(1.0)
loop = asyncio.get_running_loop()
loop.create_task(self._database_listener_task())

def _handle_cross_process_event(self, event_type: str, data: Any) -> None:
"""Handle events received from other processes."""
for listener in self._listeners:
logger.debug(f"[CROSS_PROCESS_HANDLE] Handling cross-process event type: {event_type}")
logger.debug(f"[CROSS_PROCESS_HANDLE] Calling {len(self._listeners)} listeners")

for i, listener in enumerate(self._listeners):
try:
logger.debug(f"[CROSS_PROCESS_HANDLE] Calling listener {i}")
listener(event_type, data)
logger.debug(f"[CROSS_PROCESS_HANDLE] Successfully called listener {i}")
except Exception as e:
logger.debug(f"Cross-process event listener failed for {event_type}: {e}")
logger.debug(f"[CROSS_PROCESS_HANDLE] Cross-process event listener {i} failed for {event_type}: {e}")
14 changes: 6 additions & 8 deletions eval_protocol/event_bus/sqlite_event_bus_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, List
from uuid import uuid4

from peewee import CharField, DateTimeField, Model, SqliteDatabase
from peewee import BooleanField, CharField, DateTimeField, Model, SqliteDatabase
from playhouse.sqlite_ext import JSONField

from eval_protocol.event_bus.logger import logger
Expand All @@ -25,7 +25,7 @@ class Event(BaseModel): # type: ignore
data = JSONField()
timestamp = DateTimeField()
process_id = CharField()
processed = CharField(default="false") # Track if event has been processed
processed = BooleanField(default=False) # Track if event has been processed

self._Event = Event
self._db.connect()
Expand All @@ -46,7 +46,7 @@ def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
data=serialized_data,
timestamp=time.time(),
process_id=process_id,
processed="false",
processed=False,
)
except Exception as e:
logger.warning(f"Failed to publish event to database: {e}")
Expand All @@ -56,7 +56,7 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]:
try:
query = (
self._Event.select()
.where((self._Event.process_id != process_id) & (self._Event.processed == "false"))
.where((self._Event.process_id != process_id) & (~self._Event.processed))
.order_by(self._Event.timestamp)
)

Expand All @@ -80,16 +80,14 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]:
def mark_event_processed(self, event_id: str) -> None:
"""Mark an event as processed."""
try:
self._Event.update(processed="true").where(self._Event.event_id == event_id).execute()
self._Event.update(processed=True).where(self._Event.event_id == event_id).execute()
except Exception as e:
logger.debug(f"Failed to mark event as processed: {e}")

def cleanup_old_events(self, max_age_hours: int = 24) -> None:
"""Clean up old processed events."""
try:
cutoff_time = time.time() - (max_age_hours * 3600)
self._Event.delete().where(
(self._Event.processed == "true") & (self._Event.timestamp < cutoff_time)
).execute()
self._Event.delete().where((self._Event.processed) & (self._Event.timestamp < cutoff_time)).execute()
except Exception as e:
logger.debug(f"Failed to cleanup old events: {e}")
19 changes: 19 additions & 0 deletions eval_protocol/log_utils/elasticsearch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,25 @@ def delete_index(self) -> bool:
except Exception:
return False

def clear_index(self) -> bool:
"""Clear all documents from the index.

Returns:
bool: True if successful, False otherwise
"""
try:
# Delete all documents by query
response = self._make_request(
"POST", f"{self.index_url}/_delete_by_query", json_data={"query": {"match_all": {}}}
)
if response.status_code == 200:
# Refresh the index to ensure changes are visible
refresh_response = self._make_request("POST", f"{self.index_url}/_refresh")
return refresh_response.status_code == 200
return False
except Exception:
return False

def get_mapping(self) -> Optional[Dict[str, Any]]:
"""Get the index mapping.

Expand Down
6 changes: 3 additions & 3 deletions eval_protocol/log_utils/elasticsearch_direct_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Any, Dict
from datetime import datetime
from datetime import datetime, timezone

from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
from .elasticsearch_client import ElasticsearchClient
Expand Down Expand Up @@ -36,8 +36,8 @@ def configure(self, elasticsearch_config: ElasticsearchConfig) -> None:
def emit(self, record: logging.LogRecord) -> None:
"""Emit a log record by scheduling it for async transmission."""
try:
# Create proper ISO 8601 timestamp
timestamp = datetime.fromtimestamp(record.created).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
# Create proper ISO 8601 timestamp in UTC
timestamp = datetime.fromtimestamp(record.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")

rollout_id = self._get_rollout_id(record)
logger.debug(f"Emitting log record: {record.getMessage()} with rollout_id: {rollout_id}")
Expand Down
Loading
Loading