diff --git a/eval_protocol/dataset_logger/__init__.py b/eval_protocol/dataset_logger/__init__.py index b3fc1cb2..d37abc85 100644 --- a/eval_protocol/dataset_logger/__init__.py +++ b/eval_protocol/dataset_logger/__init__.py @@ -1,14 +1,45 @@ import os from eval_protocol.dataset_logger.dataset_logger import DatasetLogger -from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter +from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore -# Allow disabling sqlite logger to avoid environment-specific constraints in simple CLI runs. -def _get_default_logger(): - if os.getenv("DISABLE_EP_SQLITE_LOG", "0").strip() != "1": - return SqliteDatasetLoggerAdapter() +def get_evaluation_row_store(db_path: str) -> EvaluationRowStore: + """ + Factory to get the configured storage backend. + + Uses EP_STORAGE environment variable to select backend: + - "tinydb" (default): Uses TinyDB with JSON file storage + - "sqlite": Uses SQLite with peewee ORM + + Args: + db_path: Path to the database file + + Returns: + EvaluationRowStore implementation + """ + storage_type = os.getenv("EP_STORAGE", "tinydb").lower() + + if storage_type == "sqlite": + from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore + + return SqliteEvaluationRowStore(db_path) else: + from eval_protocol.dataset_logger.tinydb_evaluation_row_store import TinyDBEvaluationRowStore + + return TinyDBEvaluationRowStore(db_path) + + +def _get_default_db_filename() -> str: + """Get the default database filename based on storage backend.""" + storage_type = os.getenv("EP_STORAGE", "tinydb").lower() + return "logs.db" if storage_type == "sqlite" else "logs.json" + + +def _get_default_logger(): + """Get the default logger based on configuration.""" + # Allow disabling logger to avoid environment-specific constraints in simple CLI runs. + if os.getenv("DISABLE_EP_SQLITE_LOG", "0").strip() == "1": class _NoOpLogger(DatasetLogger): def log(self, row): @@ -19,6 +50,11 @@ def read(self, rollout_id=None): return _NoOpLogger() + # Import here to avoid circular imports + from eval_protocol.dataset_logger.dataset_logger_adapter import DatasetLoggerAdapter + + return DatasetLoggerAdapter() + # Lazy property that creates the logger only when accessed class _LazyLogger(DatasetLogger): diff --git a/eval_protocol/dataset_logger/dataset_logger_adapter.py b/eval_protocol/dataset_logger/dataset_logger_adapter.py new file mode 100644 index 00000000..5414d63a --- /dev/null +++ b/eval_protocol/dataset_logger/dataset_logger_adapter.py @@ -0,0 +1,61 @@ +import os +from typing import TYPE_CHECKING, List, Optional + +from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE, DatasetLogger +from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore +from eval_protocol.directory_utils import find_eval_protocol_dir +from eval_protocol.event_bus import event_bus +from eval_protocol.event_bus.logger import logger + +if TYPE_CHECKING: + from eval_protocol.models import EvaluationRow + + +class DatasetLoggerAdapter(DatasetLogger): + """ + Dataset logger that uses the configured storage backend. + + The storage backend is selected based on the EP_STORAGE environment variable: + - "tinydb" (default): Uses TinyDB with JSON file storage + - "sqlite": Uses SQLite with peewee ORM + """ + + def __init__(self, db_path: Optional[str] = None, store: Optional[EvaluationRowStore] = None): + eval_protocol_dir = find_eval_protocol_dir() + if db_path is not None and store is not None: + raise ValueError("Provide only one of db_path or store, not both.") + if store is not None: + self.db_path = store.db_path + self._store = store + else: + # Import here to avoid circular imports + from eval_protocol.dataset_logger import _get_default_db_filename, get_evaluation_row_store + + default_db = _get_default_db_filename() + self.db_path = db_path if db_path is not None else os.path.join(eval_protocol_dir, default_db) + self._store = get_evaluation_row_store(self.db_path) + + def log(self, row: "EvaluationRow") -> None: + data = row.model_dump(exclude_none=True, mode="json") + rollout_id = data.get("execution_metadata", {}).get("rollout_id", "unknown") + logger.debug(f"[EVENT_BUS_EMIT] Starting to log row with rollout_id: {rollout_id}") + + self._store.upsert_row(data=data) + logger.debug(f"[EVENT_BUS_EMIT] Successfully stored row in database for rollout_id: {rollout_id}") + + try: + from eval_protocol.models import EvaluationRow as EvalRow + + logger.debug(f"[EVENT_BUS_EMIT] Emitting event '{LOG_EVENT_TYPE}' for rollout_id: {rollout_id}") + event_bus.emit(LOG_EVENT_TYPE, EvalRow(**data)) + logger.debug(f"[EVENT_BUS_EMIT] Successfully emitted event for rollout_id: {rollout_id}") + except Exception as e: + # Avoid breaking storage due to event emission issues + logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}") + pass + + def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]: + from eval_protocol.models import EvaluationRow + + results = self._store.read_rows(rollout_id=rollout_id) + return [EvaluationRow(**data) for data in results] diff --git a/eval_protocol/dataset_logger/evaluation_row_store.py b/eval_protocol/dataset_logger/evaluation_row_store.py new file mode 100644 index 00000000..28f841c4 --- /dev/null +++ b/eval_protocol/dataset_logger/evaluation_row_store.py @@ -0,0 +1,63 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + + +class EvaluationRowStore(ABC): + """ + Abstract base class for evaluation row storage. + + Stores arbitrary row data as JSON keyed by a unique string `rollout_id`. + Implementations can use different storage backends (SQLite, TinyDB, etc.) + """ + + @property + @abstractmethod + def db_path(self) -> str: + """Return the path to the database file.""" + pass + + @abstractmethod + def upsert_row(self, data: dict) -> None: + """ + Insert or update a row by rollout_id. + + Args: + data: Row data containing execution_metadata.rollout_id + """ + pass + + @abstractmethod + def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]: + """ + Read rows, optionally filtered by rollout_id. + + Args: + rollout_id: If provided, filter to this specific rollout + + Returns: + List of row data dictionaries + """ + pass + + @abstractmethod + def delete_row(self, rollout_id: str) -> int: + """ + Delete a row by rollout_id. + + Args: + rollout_id: The rollout_id to delete + + Returns: + Number of rows deleted + """ + pass + + @abstractmethod + def delete_all_rows(self) -> int: + """ + Delete all rows. + + Returns: + Number of rows deleted + """ + pass diff --git a/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py b/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py index 5f360bfc..3907b77a 100644 --- a/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py +++ b/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py @@ -1,45 +1,13 @@ -import os -from typing import List, Optional +""" +Backwards-compatible alias for DatasetLoggerAdapter. -from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE, DatasetLogger -from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore -from eval_protocol.directory_utils import find_eval_protocol_dir -from eval_protocol.event_bus import event_bus -from eval_protocol.event_bus.logger import logger -from eval_protocol.models import EvaluationRow +This module is kept for backwards compatibility. New code should use +DatasetLoggerAdapter from dataset_logger_adapter.py instead. +""" +from eval_protocol.dataset_logger.dataset_logger_adapter import DatasetLoggerAdapter -class SqliteDatasetLoggerAdapter(DatasetLogger): - def __init__(self, db_path: Optional[str] = None, store: Optional[SqliteEvaluationRowStore] = None): - eval_protocol_dir = find_eval_protocol_dir() - if db_path is not None and store is not None: - raise ValueError("Provide only one of db_path or store, not both.") - if store is not None: - self.db_path = store.db_path - self._store = store - else: - self.db_path = db_path if db_path is not None else os.path.join(eval_protocol_dir, "logs.db") - self._store = SqliteEvaluationRowStore(self.db_path) +# Backwards-compatible alias +SqliteDatasetLoggerAdapter = DatasetLoggerAdapter - def log(self, row: "EvaluationRow") -> None: - data = row.model_dump(exclude_none=True, mode="json") - rollout_id = data.get("execution_metadata", {}).get("rollout_id", "unknown") - logger.debug(f"[EVENT_BUS_EMIT] Starting to log row with rollout_id: {rollout_id}") - - self._store.upsert_row(data=data) - logger.debug(f"[EVENT_BUS_EMIT] Successfully stored row in database for rollout_id: {rollout_id}") - - try: - logger.debug(f"[EVENT_BUS_EMIT] Emitting event '{LOG_EVENT_TYPE}' for rollout_id: {rollout_id}") - event_bus.emit(LOG_EVENT_TYPE, EvaluationRow(**data)) - logger.debug(f"[EVENT_BUS_EMIT] Successfully emitted event for rollout_id: {rollout_id}") - except Exception as e: - # Avoid breaking storage due to event emission issues - logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}") - pass - - def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]: - from eval_protocol.models import EvaluationRow - - results = self._store.read_rows(rollout_id=rollout_id) - return [EvaluationRow(**data) for data in results] +__all__ = ["SqliteDatasetLoggerAdapter"] diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index a8e7b229..6dd2ff3f 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -1,13 +1,18 @@ import os from typing import List, Optional -from peewee import CharField, Model, SqliteDatabase -from playhouse.sqlite_ext import JSONField +try: + from peewee import CharField, Model, SqliteDatabase + from playhouse.sqlite_ext import JSONField +except ImportError: + raise ImportError( + "SQLite storage backend requires 'peewee' package. Install it with: pip install eval-protocol[sqlite_storage]" + ) -from eval_protocol.models import EvaluationRow +from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore -class SqliteEvaluationRowStore: +class SqliteEvaluationRowStore(EvaluationRowStore): """ Lightweight reusable SQLite store for evaluation rows. @@ -15,7 +20,10 @@ class SqliteEvaluationRowStore: """ def __init__(self, db_path: str): - os.makedirs(os.path.dirname(db_path), exist_ok=True) + # Handle case where db_path might be in the root directory + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) self._db_path = db_path self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"}) diff --git a/eval_protocol/dataset_logger/tinydb_evaluation_row_store.py b/eval_protocol/dataset_logger/tinydb_evaluation_row_store.py new file mode 100644 index 00000000..060eb080 --- /dev/null +++ b/eval_protocol/dataset_logger/tinydb_evaluation_row_store.py @@ -0,0 +1,120 @@ +import json +import logging +import os +import time +from typing import List, Optional + +from tinydb import Query, TinyDB +from tinyrecord.transaction import transaction + +from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore + +logger = logging.getLogger(__name__) + + +class TinyDBEvaluationRowStore(EvaluationRowStore): + """ + TinyDB-based evaluation row store. + + Stores data as plain JSON files, which are human-readable and + don't suffer from SQLite's binary format corruption issues. + + Uses tinyrecord for atomic transactions to handle concurrent access + from multiple processes safely. + """ + + def __init__(self, db_path: str): + # Handle case where db_path might be in the root directory + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + self._db_path = db_path + self._db = self._open_db_with_retry() + self._table = self._db.table("evaluation_rows") + + def _open_db_with_retry(self, max_retries: int = 3) -> TinyDB: + """Open TinyDB with retry logic to handle transient JSON decode errors.""" + last_error: Exception | None = None + for attempt in range(max_retries): + try: + return TinyDB(self._db_path) + except json.JSONDecodeError as e: + last_error = e + logger.warning(f"TinyDB JSON decode error on attempt {attempt + 1}: {e}") + # Wait a bit and retry - the file might be mid-write + time.sleep(0.1 * (attempt + 1)) + # Try to recover by removing the corrupted file + if attempt == max_retries - 1 and os.path.exists(self._db_path): + try: + logger.warning(f"Removing corrupted TinyDB file: {self._db_path}") + os.remove(self._db_path) + return TinyDB(self._db_path) + except Exception: + pass + raise last_error if last_error else RuntimeError("Failed to open TinyDB") + + @property + def db_path(self) -> str: + return self._db_path + + def upsert_row(self, data: dict) -> None: + rollout_id = data["execution_metadata"]["rollout_id"] + if rollout_id is None: + raise ValueError("execution_metadata.rollout_id is required to upsert a row") + + Row = Query() + condition = Row.execution_metadata.rollout_id == rollout_id + + # tinyrecord doesn't support upsert directly, so we implement it manually + # within a transaction for atomicity + with transaction(self._table) as tr: + # Clear cache to ensure fresh read in multi-process scenarios + self._table.clear_cache() + # Check if document exists + existing = self._table.search(condition) + if existing: + # Update existing document + tr.update(data, condition) + else: + # Insert new document + tr.insert(data) + + def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]: + """Read rows with retry logic for transient JSON decode errors.""" + max_retries = 3 + for attempt in range(max_retries): + try: + # Clear cache to ensure fresh read in multi-process scenarios + self._table.clear_cache() + if rollout_id is not None: + Row = Query() + return list(self._table.search(Row.execution_metadata.rollout_id == rollout_id)) + return list(self._table.all()) + except json.JSONDecodeError as e: + logger.warning(f"TinyDB JSON decode error on read attempt {attempt + 1}: {e}") + if attempt < max_retries - 1: + time.sleep(0.1 * (attempt + 1)) + else: + # Return empty list on final failure rather than crash + logger.warning("Failed to read TinyDB after retries, returning empty list") + return [] + return [] + + def delete_row(self, rollout_id: str) -> int: + Row = Query() + condition = Row.execution_metadata.rollout_id == rollout_id + + with transaction(self._table) as tr: + # Clear cache to ensure fresh read in multi-process scenarios + self._table.clear_cache() + # Check if document exists before removal to get accurate count + existing = self._table.search(condition) + if existing: + tr.remove(condition) + return len(existing) + return 0 + + def delete_all_rows(self) -> int: + count = len(self._table) + self._table.truncate() + return count diff --git a/eval_protocol/event_bus/__init__.py b/eval_protocol/event_bus/__init__.py index 86e572a9..5f2833fc 100644 --- a/eval_protocol/event_bus/__init__.py +++ b/eval_protocol/event_bus/__init__.py @@ -1,12 +1,47 @@ -# Global event bus instance - uses SqliteEventBus for cross-process functionality +# Global event bus instance - uses configured storage backend for cross-process functionality +import os from typing import Any, Callable + from eval_protocol.event_bus.event_bus import EventBus +from eval_protocol.event_bus.event_bus_database import EventBusDatabase + + +def get_event_bus_database(db_path: str) -> EventBusDatabase: + """ + Factory to get the configured event bus database backend. + + Uses EP_STORAGE environment variable to select backend: + - "tinydb" (default): Uses TinyDB with JSON file storage + - "sqlite": Uses SQLite with peewee ORM + + Args: + db_path: Path to the database file + + Returns: + EventBusDatabase implementation + """ + storage_type = os.getenv("EP_STORAGE", "tinydb").lower() + + if storage_type == "sqlite": + from eval_protocol.event_bus.sqlite_event_bus_database import SqliteEventBusDatabase + + return SqliteEventBusDatabase(db_path) + else: + from eval_protocol.event_bus.tinydb_event_bus_database import TinyDBEventBusDatabase + + return TinyDBEventBusDatabase(db_path) + + +def _get_default_db_filename() -> str: + """Get the default database filename based on storage backend.""" + storage_type = os.getenv("EP_STORAGE", "tinydb").lower() + return "logs.db" if storage_type == "sqlite" else "logs.json" def _get_default_event_bus(): - from eval_protocol.event_bus.sqlite_event_bus import SqliteEventBus + from eval_protocol.event_bus.cross_process_event_bus import CrossProcessEventBus - return SqliteEventBus() + return CrossProcessEventBus() # Lazy property that creates the event bus only when accessed diff --git a/eval_protocol/event_bus/cross_process_event_bus.py b/eval_protocol/event_bus/cross_process_event_bus.py new file mode 100644 index 00000000..3977d26c --- /dev/null +++ b/eval_protocol/event_bus/cross_process_event_bus.py @@ -0,0 +1,133 @@ +import asyncio +import os +import time +from typing import Any, Optional + +from eval_protocol.event_bus.event_bus import EventBus +from eval_protocol.event_bus.event_bus_database import EventBusDatabase +from eval_protocol.event_bus.logger import logger + + +class CrossProcessEventBus(EventBus): + """ + Cross-process event bus implementation using the configured storage backend. + + The storage backend is selected based on the EP_STORAGE environment variable: + - "tinydb" (default): Uses TinyDB with JSON file storage + - "sqlite": Uses SQLite with peewee ORM + """ + + def __init__(self, db_path: Optional[str] = None): + super().__init__() + + # Use the configured database backend + if db_path is None: + from eval_protocol.directory_utils import find_eval_protocol_dir + from eval_protocol.event_bus import _get_default_db_filename + + eval_protocol_dir = find_eval_protocol_dir() + db_path = os.path.join(eval_protocol_dir, _get_default_db_filename()) + + from eval_protocol.event_bus import get_event_bus_database + + self._db: EventBusDatabase = get_event_bus_database(db_path) + self._running = False + self._process_id = str(os.getpid()) + + def emit(self, event_type: str, data: Any) -> None: + """Emit an event to all subscribers. + + Args: + event_type: Type of event (e.g., "log") + data: Event data + """ + logger.debug(f"[CROSS_PROCESS_EMIT] Emitting event type: {event_type}") + + # Call local listeners immediately + logger.debug(f"[CROSS_PROCESS_EMIT] Calling {len(self._listeners)} local listeners") + super().emit(event_type, data) + logger.debug("[CROSS_PROCESS_EMIT] Completed local listener calls") + + # Publish to cross-process subscribers + logger.debug("[CROSS_PROCESS_EMIT] Publishing to cross-process subscribers") + self._publish_cross_process(event_type, data) + logger.debug("[CROSS_PROCESS_EMIT] Completed cross-process publish") + + def _publish_cross_process(self, event_type: str, data: Any) -> None: + """Publish event to cross-process subscribers via database.""" + logger.debug(f"[CROSS_PROCESS_PUBLISH] Publishing event {event_type} to database") + try: + self._db.publish_event(event_type, data, self._process_id) + logger.debug(f"[CROSS_PROCESS_PUBLISH] Successfully published event {event_type} to database") + except Exception as e: + logger.error(f"[CROSS_PROCESS_PUBLISH] Failed to publish event {event_type} to database: {e}") + + def start_listening(self) -> None: + """Start listening for cross-process events.""" + if self._running: + logger.debug("[CROSS_PROCESS_LISTEN] Already listening, skipping start") + return + + logger.debug("[CROSS_PROCESS_LISTEN] Starting cross-process event listening") + self._running = True + loop = asyncio.get_running_loop() + loop.create_task(self._database_listener_task()) + logger.debug("[CROSS_PROCESS_LISTEN] Started async database listener task") + + def stop_listening(self) -> None: + """Stop listening for cross-process events.""" + logger.debug("[CROSS_PROCESS_LISTEN] Stopping cross-process event listening") + self._running = False + + async def _database_listener_task(self) -> None: + """Single database listener task that processes events and recreates itself.""" + if not self._running: + # this should end the task loop + logger.debug("[CROSS_PROCESS_LISTENER] Stopping database listener task") + return + + # Get unprocessed events from other processes + events = self._db.get_unprocessed_events(str(self._process_id)) + if events: + logger.debug(f"[CROSS_PROCESS_LISTENER] Found {len(events)} unprocessed events") + else: + logger.debug(f"[CROSS_PROCESS_LISTENER] No unprocessed events found for process {self._process_id}") + + for event in events: + logger.debug( + f"[CROSS_PROCESS_LISTENER] Processing event {event['event_id']} of type {event['event_type']}" + ) + # Handle the event + self._handle_cross_process_event(event["event_type"], event["data"]) + logger.debug(f"[CROSS_PROCESS_LISTENER] Successfully processed event {event['event_id']}") + + # Mark as processed + self._db.mark_event_processed(event["event_id"]) + logger.debug(f"[CROSS_PROCESS_LISTENER] Marked event {event['event_id']} as processed") + + # Clean up old events every hour + current_time = time.time() + if not hasattr(self, "_last_cleanup"): + self._last_cleanup = current_time + elif current_time - self._last_cleanup >= 3600: + logger.debug("[CROSS_PROCESS_LISTENER] Cleaning up old events") + self._db.cleanup_old_events() + self._last_cleanup = current_time + + # Schedule the next task if still running + await asyncio.sleep(1.0) + loop = asyncio.get_running_loop() + loop.create_task(self._database_listener_task()) + + def _handle_cross_process_event(self, event_type: str, data: Any) -> None: + """Handle events received from other processes.""" + logger.debug(f"[CROSS_PROCESS_HANDLE] Handling cross-process event type: {event_type}") + logger.debug(f"[CROSS_PROCESS_HANDLE] Calling {len(self._listeners)} listeners") + + for i, listener in enumerate(self._listeners): + try: + logger.debug(f"[CROSS_PROCESS_HANDLE] Calling listener {i}") + listener(event_type, data) + logger.debug(f"[CROSS_PROCESS_HANDLE] Successfully called listener {i}") + except Exception as e: + logger.debug(f"[CROSS_PROCESS_HANDLE] Cross-process event listener {i} failed for {event_type}: {e}") diff --git a/eval_protocol/event_bus/event_bus_database.py b/eval_protocol/event_bus/event_bus_database.py new file mode 100644 index 00000000..fb23dd79 --- /dev/null +++ b/eval_protocol/event_bus/event_bus_database.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod +from typing import Any, List + + +class EventBusDatabase(ABC): + """ + Abstract base class for cross-process event communication storage. + + Implementations can use different storage backends (SQLite, TinyDB, etc.) + """ + + @abstractmethod + def publish_event(self, event_type: str, data: Any, process_id: str) -> None: + """ + Publish an event to the database. + + Args: + event_type: Type of event (e.g., "log") + data: Event data (will be serialized to JSON) + process_id: ID of the publishing process + """ + pass + + @abstractmethod + def get_unprocessed_events(self, process_id: str) -> List[dict]: + """ + Get unprocessed events from other processes. + + Args: + process_id: Current process ID (events from this process are excluded) + + Returns: + List of event dictionaries with keys: event_id, event_type, data, timestamp, process_id + """ + pass + + @abstractmethod + def mark_event_processed(self, event_id: str) -> None: + """ + Mark an event as processed. + + Args: + event_id: The event ID to mark as processed + """ + pass + + @abstractmethod + def cleanup_old_events(self, max_age_hours: int = 24) -> None: + """ + Clean up old processed events. + + Args: + max_age_hours: Maximum age in hours for processed events + """ + pass diff --git a/eval_protocol/event_bus/sqlite_event_bus.py b/eval_protocol/event_bus/sqlite_event_bus.py index 88125a5b..24bd49e6 100644 --- a/eval_protocol/event_bus/sqlite_event_bus.py +++ b/eval_protocol/event_bus/sqlite_event_bus.py @@ -1,126 +1,13 @@ -import asyncio -import os -import threading -import time -from typing import Any, Optional -from uuid import uuid4 +""" +Backwards-compatible alias for CrossProcessEventBus. -from eval_protocol.event_bus.event_bus import EventBus -from eval_protocol.event_bus.logger import logger -from eval_protocol.event_bus.sqlite_event_bus_database import SqliteEventBusDatabase +This module is kept for backwards compatibility. New code should use +CrossProcessEventBus from cross_process_event_bus.py instead. +""" +from eval_protocol.event_bus.cross_process_event_bus import CrossProcessEventBus -class SqliteEventBus(EventBus): - """SQLite-based event bus implementation that supports cross-process communication.""" +# Backwards-compatible alias +SqliteEventBus = CrossProcessEventBus - def __init__(self, db_path: Optional[str] = None): - super().__init__() - - # Use the same database as the evaluation row store - if db_path is None: - from eval_protocol.directory_utils import find_eval_protocol_dir - - eval_protocol_dir = find_eval_protocol_dir() - db_path = os.path.join(eval_protocol_dir, "logs.db") - - self._db: SqliteEventBusDatabase = SqliteEventBusDatabase(db_path) - self._running = False - self._process_id = str(os.getpid()) - - def emit(self, event_type: str, data: Any) -> None: - """Emit an event to all subscribers. - - Args: - event_type: Type of event (e.g., "log") - data: Event data - """ - logger.debug(f"[CROSS_PROCESS_EMIT] Emitting event type: {event_type}") - - # Call local listeners immediately - logger.debug(f"[CROSS_PROCESS_EMIT] Calling {len(self._listeners)} local listeners") - super().emit(event_type, data) - logger.debug("[CROSS_PROCESS_EMIT] Completed local listener calls") - - # Publish to cross-process subscribers - logger.debug("[CROSS_PROCESS_EMIT] Publishing to cross-process subscribers") - self._publish_cross_process(event_type, data) - logger.debug("[CROSS_PROCESS_EMIT] Completed cross-process publish") - - def _publish_cross_process(self, event_type: str, data: Any) -> None: - """Publish event to cross-process subscribers via database.""" - logger.debug(f"[CROSS_PROCESS_PUBLISH] Publishing event {event_type} to database") - try: - self._db.publish_event(event_type, data, self._process_id) - logger.debug(f"[CROSS_PROCESS_PUBLISH] Successfully published event {event_type} to database") - except Exception as e: - logger.error(f"[CROSS_PROCESS_PUBLISH] Failed to publish event {event_type} to database: {e}") - - def start_listening(self) -> None: - """Start listening for cross-process events.""" - if self._running: - logger.debug("[CROSS_PROCESS_LISTEN] Already listening, skipping start") - return - - logger.debug("[CROSS_PROCESS_LISTEN] Starting cross-process event listening") - self._running = True - loop = asyncio.get_running_loop() - loop.create_task(self._database_listener_task()) - logger.debug("[CROSS_PROCESS_LISTEN] Started async database listener task") - - def stop_listening(self) -> None: - """Stop listening for cross-process events.""" - logger.debug("[CROSS_PROCESS_LISTEN] Stopping cross-process event listening") - self._running = False - - async def _database_listener_task(self) -> None: - """Single database listener task that processes events and recreates itself.""" - if not self._running: - # this should end the task loop - logger.debug("[CROSS_PROCESS_LISTENER] Stopping database listener task") - return - - # Get unprocessed events from other processes - events = self._db.get_unprocessed_events(str(self._process_id)) - if events: - logger.debug(f"[CROSS_PROCESS_LISTENER] Found {len(events)} unprocessed events") - else: - logger.debug(f"[CROSS_PROCESS_LISTENER] No unprocessed events found for process {self._process_id}") - - for event in events: - logger.debug( - f"[CROSS_PROCESS_LISTENER] Processing event {event['event_id']} of type {event['event_type']}" - ) - # Handle the event - self._handle_cross_process_event(event["event_type"], event["data"]) - logger.debug(f"[CROSS_PROCESS_LISTENER] Successfully processed event {event['event_id']}") - - # Mark as processed - self._db.mark_event_processed(event["event_id"]) - logger.debug(f"[CROSS_PROCESS_LISTENER] Marked event {event['event_id']} as processed") - - # Clean up old events every hour - current_time = time.time() - if not hasattr(self, "_last_cleanup"): - self._last_cleanup = current_time - elif current_time - self._last_cleanup >= 3600: - logger.debug("[CROSS_PROCESS_LISTENER] Cleaning up old events") - self._db.cleanup_old_events() - self._last_cleanup = current_time - - # Schedule the next task if still running - await asyncio.sleep(1.0) - loop = asyncio.get_running_loop() - loop.create_task(self._database_listener_task()) - - def _handle_cross_process_event(self, event_type: str, data: Any) -> None: - """Handle events received from other processes.""" - logger.debug(f"[CROSS_PROCESS_HANDLE] Handling cross-process event type: {event_type}") - logger.debug(f"[CROSS_PROCESS_HANDLE] Calling {len(self._listeners)} listeners") - - for i, listener in enumerate(self._listeners): - try: - logger.debug(f"[CROSS_PROCESS_HANDLE] Calling listener {i}") - listener(event_type, data) - logger.debug(f"[CROSS_PROCESS_HANDLE] Successfully called listener {i}") - except Exception as e: - logger.debug(f"[CROSS_PROCESS_HANDLE] Cross-process event listener {i} failed for {event_type}: {e}") +__all__ = ["SqliteEventBus"] diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index 5d1f522a..3924cc6e 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -1,19 +1,31 @@ +import os import time from typing import Any, List from uuid import uuid4 -from peewee import BooleanField, CharField, DateTimeField, Model, SqliteDatabase -from playhouse.sqlite_ext import JSONField +try: + from peewee import BooleanField, CharField, DateTimeField, Model, SqliteDatabase + from playhouse.sqlite_ext import JSONField +except ImportError: + raise ImportError( + "SQLite storage backend requires 'peewee' package. Install it with: pip install eval-protocol[sqlite_storage]" + ) +from eval_protocol.event_bus.event_bus_database import EventBusDatabase from eval_protocol.event_bus.logger import logger -class SqliteEventBusDatabase: +class SqliteEventBusDatabase(EventBusDatabase): """SQLite database for cross-process event communication.""" def __init__(self, db_path: str): + # Handle case where db_path might be in the root directory + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) self._db_path = db_path - self._db = SqliteDatabase(db_path) + # Use WAL mode for better concurrent access + self._db = SqliteDatabase(db_path, pragmas={"journal_mode": "wal"}) class BaseModel(Model): class Meta: diff --git a/eval_protocol/event_bus/tinydb_event_bus_database.py b/eval_protocol/event_bus/tinydb_event_bus_database.py new file mode 100644 index 00000000..cc55426f --- /dev/null +++ b/eval_protocol/event_bus/tinydb_event_bus_database.py @@ -0,0 +1,143 @@ +import json +import os +import time +from typing import Any, List +from uuid import uuid4 + +from tinydb import Query, TinyDB +from tinyrecord.transaction import transaction + +from eval_protocol.event_bus.event_bus_database import EventBusDatabase +from eval_protocol.event_bus.logger import logger + + +class TinyDBEventBusDatabase(EventBusDatabase): + """ + TinyDB-based event bus database for cross-process event communication. + + Stores data as plain JSON files, which are human-readable and + don't suffer from SQLite's binary format corruption issues. + + Uses tinyrecord for atomic transactions to handle concurrent access + from multiple processes safely. + """ + + def __init__(self, db_path: str): + # Handle case where db_path might be in the root directory + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + self._db_path = db_path + self._db = self._open_db_with_retry() + self._table = self._db.table("events") + + def _open_db_with_retry(self, max_retries: int = 3) -> TinyDB: + """Open TinyDB with retry logic to handle transient JSON decode errors.""" + last_error: Exception | None = None + for attempt in range(max_retries): + try: + return TinyDB(self._db_path) + except json.JSONDecodeError as e: + last_error = e + logger.warning(f"TinyDB JSON decode error on attempt {attempt + 1}: {e}") + # Wait a bit and retry - the file might be mid-write + time.sleep(0.1 * (attempt + 1)) + # Try to recover by removing the corrupted file + if attempt == max_retries - 1 and os.path.exists(self._db_path): + try: + logger.warning(f"Removing corrupted TinyDB file: {self._db_path}") + os.remove(self._db_path) + return TinyDB(self._db_path) + except Exception: + pass + raise last_error if last_error else RuntimeError("Failed to open TinyDB") + + def publish_event(self, event_type: str, data: Any, process_id: str) -> None: + """Publish an event to the database using atomic transaction.""" + try: + # Serialize data, handling pydantic models + if hasattr(data, "model_dump"): + serialized_data = data.model_dump(mode="json", exclude_none=True) + else: + serialized_data = data + + document = { + "event_id": str(uuid4()), + "event_type": event_type, + "data": serialized_data, + "timestamp": time.time(), + "process_id": process_id, + "processed": False, + } + + # Use tinyrecord transaction for atomic, concurrent-safe insert + with transaction(self._table) as tr: + tr.insert(document) + except Exception as e: + logger.warning(f"Failed to publish event to database: {e}") + + def get_unprocessed_events(self, process_id: str) -> List[dict]: + """Get unprocessed events from other processes with retry logic.""" + max_retries = 3 + for attempt in range(max_retries): + try: + # Clear query cache to force fresh read from disk + # TinyDB caches query results, so we need to clear cache to see + # events written by other processes. The search() method will + # automatically call _read_table() on a cache miss. + self._table.clear_cache() + + Event = Query() + results = self._table.search((Event.process_id != process_id) & (Event.processed == False)) # noqa: E712 + + logger.debug( + f"TinyDBEventBusDatabase: Found {len(results)} unprocessed events for process_id: {process_id} in database: {self._db_path}" + ) + + events = [] + # Sort by timestamp + for event in sorted(results, key=lambda x: x.get("timestamp", 0)): + events.append( + { + "event_id": event["event_id"], + "event_type": event["event_type"], + "data": event["data"], + "timestamp": event["timestamp"], + "process_id": event["process_id"], + } + ) + + return events + except json.JSONDecodeError as e: + logger.warning(f"TinyDB JSON decode error on get_unprocessed_events attempt {attempt + 1}: {e}") + if attempt < max_retries - 1: + time.sleep(0.1 * (attempt + 1)) + else: + logger.warning("Failed to read events after retries, returning empty list") + return [] + except Exception as e: + logger.warning(f"Failed to get unprocessed events: {e}") + return [] + return [] + + def mark_event_processed(self, event_id: str) -> None: + """Mark an event as processed using atomic transaction.""" + try: + Event = Query() + with transaction(self._table) as tr: + tr.update({"processed": True}, Event.event_id == event_id) + except Exception as e: + logger.debug(f"Failed to mark event as processed: {e}") + + def cleanup_old_events(self, max_age_hours: int = 24) -> None: + """Clean up old processed events using atomic transaction.""" + try: + # Clear cache to see latest data before cleanup + self._table.clear_cache() + + cutoff_time = time.time() - (max_age_hours * 3600) + Event = Query() + with transaction(self._table) as tr: + tr.remove((Event.processed == True) & (Event.timestamp < cutoff_time)) # noqa: E712 + except Exception as e: + logger.debug(f"Failed to cleanup old events: {e}") diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py index 777547fe..7a51b121 100644 --- a/eval_protocol/fireworks_rft.py +++ b/eval_protocol/fireworks_rft.py @@ -217,11 +217,11 @@ def build_default_dataset_id(evaluator_id: str) -> str: def build_default_output_model(evaluator_id: str) -> str: base = evaluator_id.lower().replace("_", "-") uuid_suffix = str(uuid.uuid4())[:4] - + # suffix is "-rft-{4chars}" -> 9 chars suffix_len = 9 max_len = 63 - + # Check if we need to truncate if len(base) + suffix_len > max_len: # Calculate hash of the full base to preserve uniqueness @@ -229,10 +229,10 @@ def build_default_output_model(evaluator_id: str) -> str: # New structure: {truncated_base}-{hash}-{uuid_suffix} # Space needed for "-{hash}" is 1 + 6 = 7 hash_part_len = 7 - + allowed_base_len = max_len - suffix_len - hash_part_len truncated_base = base[:allowed_base_len].strip("-") - + return f"{truncated_base}-{hash_digest}-rft-{uuid_suffix}" return f"{base}-rft-{uuid_suffix}" diff --git a/eval_protocol/pytest/buffer.py b/eval_protocol/pytest/buffer.py index 88e2f2a5..6a94edac 100644 --- a/eval_protocol/pytest/buffer.py +++ b/eval_protocol/pytest/buffer.py @@ -5,11 +5,13 @@ from eval_protocol.models import EvaluationRow + class MicroBatchDataBuffer: """ Buffers evaluation results and writes them to disk in minibatches. Waits for all runs of a sample to complete before considering it ready and flush to disk. """ + def __init__(self, num_runs: int, batch_size: int, output_path_template: str): self.num_runs = num_runs self.batch_size = batch_size @@ -29,14 +31,14 @@ async def add_result(self, row: EvaluationRow): if not row_id: # Should not happen in valid EP workflow, unique row_id is required to group things together properly return - + self.pending_samples[row_id].append(row) - + if len(self.pending_samples[row_id]) >= self.num_runs: # Sample completed (all runs finished) completed_rows = self.pending_samples.pop(row_id) self.completed_samples_buffer.append(completed_rows) - + if len(self.completed_samples_buffer) >= self.batch_size: await self._flush_unsafe() @@ -56,13 +58,13 @@ async def _flush_unsafe(self): # Ensure directory exists os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) - + # Write flattened rows with open(output_path, mode) as f: for sample_rows in self.completed_samples_buffer: for row in sample_rows: f.write(row.model_dump_json() + "\n") - + self.completed_samples_buffer = [] self.batch_index += 1 @@ -79,4 +81,3 @@ async def close(self): if self.completed_samples_buffer: await self._flush_unsafe() - diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 2e6393dd..d3dec586 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -408,21 +408,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo rollout_processor.setup() - use_priority_scheduler = ( - ( - os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1" - and not isinstance(rollout_processor, MCPGymRolloutProcessor) - ) - ) + use_priority_scheduler = os.environ.get( + "EP_USE_PRIORITY_SCHEDULER", "0" + ) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor) if use_priority_scheduler: microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None) output_dir = os.environ.get("EP_OUTPUT_DIR", None) if microbatch_output_size and output_dir: - output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl")) + output_buffer = MicroBatchDataBuffer( + num_runs=num_runs, + batch_size=int(microbatch_output_size), + output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"), + ) else: output_buffer = None - + try: priority_results = await execute_priority_rollouts( dataset=data, @@ -440,12 +441,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo finally: if output_buffer: await output_buffer.close() - + for res in priority_results: run_idx = (res.execution_metadata.extra or {}).get("run_index", 0) if run_idx < len(all_results): all_results[run_idx].append(res) - + processed_rows_in_run.append(res) postprocess( @@ -461,6 +462,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo ) else: + async def execute_run(run_idx: int, config: RolloutProcessorConfig): nonlocal all_results @@ -513,7 +515,9 @@ async def _execute_groupwise_eval_with_semaphore( evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None group_rollout_ids = [ - r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id + r.execution_metadata.rollout_id + for r in rows + if r.execution_metadata.rollout_id ] async with rollout_logging_context( primary_rollout_id or "", @@ -587,7 +591,9 @@ async def _collect_result(config, lst): row_groups[row.input_metadata.row_id].append(row) tasks = [] for _, rows in row_groups.items(): - tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) + tasks.append( + asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)) + ) results = [] for task in tasks: res = await task @@ -678,9 +684,9 @@ async def _collect_result(config, lst): # For other processors, create all tasks at once and run in parallel # Concurrency is now controlled by the shared semaphore in each rollout processor await run_tasks_with_run_progress(execute_run, num_runs, config) - + experiment_duration_seconds = time.perf_counter() - experiment_start_time - + # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them # rollout_id is used to differentiate the result from different completion_params if mode == "groupwise": @@ -716,15 +722,12 @@ async def _collect_result(config, lst): experiment_duration_seconds, ) - - if not all(r.evaluation_result is not None for run_results in all_results for r in run_results): raise AssertionError( "Some EvaluationRow instances are missing evaluation_result. " "Your @evaluation_test function must set `row.evaluation_result`" ) - except AssertionError: _log_eval_error( Status.eval_finished(), diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py index 94d6f7fe..f68e1b86 100644 --- a/eval_protocol/pytest/evaluation_test_utils.py +++ b/eval_protocol/pytest/evaluation_test_utils.py @@ -365,7 +365,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow: retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False}) retry_tasks = rollout_processor([row], retry_config) result = await retry_tasks[0] - + # Apply post-processing quality checks if configured # This must be inside the retry function so ResponseQualityError can trigger retries if config.post_processor is not None: @@ -374,7 +374,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow: except ResponseQualityError as quality_error: # Re-raise ResponseQualityError to trigger retry logic raise quality_error - + return result async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow: diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py index a2244b2a..c2db55f6 100644 --- a/eval_protocol/pytest/exception_config.py +++ b/eval_protocol/pytest/exception_config.py @@ -81,7 +81,7 @@ class BackoffConfig: def get_backoff_decorator(self, exceptions: Set[Type[Exception]]): """Get the appropriate backoff decorator based on configuration. - + Args: exceptions: Set of exception types to retry """ @@ -141,9 +141,7 @@ def __post_init__(self): def get_backoff_decorator(self): """Get the backoff decorator configured for this exception handler.""" - return self.backoff_config.get_backoff_decorator( - self.retryable_exceptions - ) + return self.backoff_config.get_backoff_decorator(self.retryable_exceptions) def get_default_exception_handler_config() -> ExceptionHandlerConfig: diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index eaddacc5..b214075c 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -17,6 +17,7 @@ ENABLE_SPECULATION = os.getenv("ENABLE_SPECULATION", "0").strip() == "1" + @dataclass(order=True) class RolloutTask: """ @@ -26,34 +27,37 @@ class RolloutTask: 1 = Low Priority (e.g., starting a new sample) - row_index: Used to maintain dataset order for initial scheduling """ + priority: tuple[int, int] - + # Payload (excluded from comparison) row: EvaluationRow = field(compare=False) run_indices: List[int] = field(compare=False) # Which runs to execute in this task config: RolloutProcessorConfig = field(compare=False) - row_index: int = field(compare=False) # To track which sample this belongs to - + row_index: int = field(compare=False) # To track which sample this belongs to + # History for speculation (injected from previous micro-batches) history: List[str] = field(compare=False, default_factory=list) + class PriorityRolloutScheduler: """ Manages a priority queue of rollout tasks and a pool of workers. Ensures that once a sample starts processing, its subsequent micro-batches are prioritized to complete the sample as quickly as possible. """ + def __init__( self, rollout_processor: RolloutProcessor, max_concurrent_rollouts: int, active_logger: DatasetLogger, max_concurrent_evaluations: int, - eval_executor: TestFunction, # Callback to run evaluation + eval_executor: TestFunction, # Callback to run evaluation output_buffer: Optional[MicroBatchDataBuffer] = None, rollout_n: int = 0, mode: str = "pointwise", - in_group_minibatch_size: int = 0, # for one sample, how many runs to execute at the same time + in_group_minibatch_size: int = 0, # for one sample, how many runs to execute at the same time evaluation_test_kwargs: Dict[str, Any] = {}, ): self.rollout_processor = rollout_processor @@ -63,19 +67,21 @@ def __init__( self.eval_executor = eval_executor self.output_buffer = output_buffer self.mode = mode - + # Priority Queue: Stores RolloutTask self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue() - + # Concurrency Control self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) - + # Results storage - self.results: List[EvaluationRow] = [] # for backward compatibility reason, we save all results here to return - self.groups_buffer: Dict[int, List[EvaluationRow]] = defaultdict(list) # buffer for group results. only flush to output buffer when a whole group is ready - - self.background_tasks = set() # run evaluations in the background asynchronously - + self.results: List[EvaluationRow] = [] # for backward compatibility reason, we save all results here to return + self.groups_buffer: Dict[int, List[EvaluationRow]] = defaultdict( + list + ) # buffer for group results. only flush to output buffer when a whole group is ready + + self.background_tasks = set() # run evaluations in the background asynchronously + self.rollout_n = rollout_n self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n self.evaluation_test_kwargs = evaluation_test_kwargs @@ -93,17 +99,17 @@ async def schedule_dataset( batch_start = 0 batch_end = min(self.in_group_minibatch_size, self.rollout_n) run_indices = list(range(batch_start, batch_end)) - + # Initial priority: Low (1), ordered by dataset index priority = (1, i) - + task = RolloutTask( priority=priority, row=row, run_indices=run_indices, config=base_config, row_index=i, - history=[] # Initial batch has no history + history=[], # Initial batch has no history ) self.queue.put_nowait(task) @@ -112,7 +118,7 @@ async def worker(self): Worker loop: fetch task -> execute micro-batch -> schedule next batch (if any). """ while True: - # Get a task from the priority queue + # Get a task from the priority queue task: RolloutTask = await self.queue.get() try: @@ -126,13 +132,26 @@ async def _process_task(self, task: RolloutTask): """ Executes a single micro-batch task. """ + async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): """Background evaluation task.""" - rollout_id = rows_to_eval[0].execution_metadata.rollout_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.rollout_id - experiment_id = rows_to_eval[0].execution_metadata.experiment_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.experiment_id - run_id = rows_to_eval[0].execution_metadata.run_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.run_id + rollout_id = ( + rows_to_eval[0].execution_metadata.rollout_id + if isinstance(rows_to_eval, list) + else rows_to_eval.execution_metadata.rollout_id + ) + experiment_id = ( + rows_to_eval[0].execution_metadata.experiment_id + if isinstance(rows_to_eval, list) + else rows_to_eval.execution_metadata.experiment_id + ) + run_id = ( + rows_to_eval[0].execution_metadata.run_id + if isinstance(rows_to_eval, list) + else rows_to_eval.execution_metadata.run_id + ) eval_res = None - + async with self.eval_sem: async with rollout_logging_context( rollout_id or "", @@ -151,7 +170,7 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): evaluation_test_kwargs=self.evaluation_test_kwargs, processed_row=rows_to_eval, ) - + # push result to the output buffer if self.output_buffer: if isinstance(eval_res, list): @@ -161,7 +180,7 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): else: self._post_process_result(eval_res) await self.output_buffer.add_result(eval_res) - + if isinstance(eval_res, list): self.results.extend(eval_res) else: @@ -172,19 +191,19 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): current_batch_rows = [] for run_idx in task.run_indices: row_copy = task.row.model_copy(deep=True) - + row_copy.execution_metadata.run_id = generate_id() row_copy.execution_metadata.rollout_id = generate_id() if row_copy.execution_metadata.extra is None: row_copy.execution_metadata.extra = {} row_copy.execution_metadata.extra["run_index"] = run_idx - + # Inject Speculation History if ENABLE_SPECULATION and task.history: cp = row_copy.input_metadata.completion_params max_tokens = cp.get("max_tokens", 2048) # Ensure safe dict access - if not isinstance(cp, dict): + if not isinstance(cp, dict): cp = {} # Need to check and initialize nested dicts extra_body = cp.get("extra_body") @@ -196,25 +215,22 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): extra_body["prediction"] = {"type": "content", "content": " ".join(task.history)[:max_tokens]} cp["extra_body"] = extra_body row_copy.input_metadata.completion_params = cp - + current_batch_rows.append((run_idx, row_copy)) self.active_logger.log(row_copy) - # 2. Execute Rollout batch_results: List[EvaluationRow] = [] if current_batch_rows: for idx, row in current_batch_rows: - async for result_row in rollout_processor_with_retry( - self.rollout_processor, [row], task.config, idx - ): + async for result_row in rollout_processor_with_retry(self.rollout_processor, [row], task.config, idx): batch_results.append(result_row) # in pointwise, we start evaluation immediately if self.mode == "pointwise": t = asyncio.create_task(_run_eval(result_row)) self.background_tasks.add(t) t.add_done_callback(self.background_tasks.discard) - + # 3. Evaluate and Collect History current_batch_history_updates = [] # Extract history from rollout results (assuming eval doesn't change content needed for history) @@ -229,31 +245,31 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): # in groupwise, we send all rows to evaluator in one go when the whole group is complete if self.mode == "groupwise": self.groups_buffer[task.row_index].extend(batch_results) - if len(self.groups_buffer[task.row_index]) >= self.rollout_n: - full_group = self.groups_buffer.pop(task.row_index) - t = asyncio.create_task(_run_eval(full_group)) - self.background_tasks.add(t) - t.add_done_callback(self.background_tasks.discard) + if len(self.groups_buffer[task.row_index]) >= self.rollout_n: + full_group = self.groups_buffer.pop(task.row_index) + t = asyncio.create_task(_run_eval(full_group)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) # 4. Schedule Next Micro-batch (High Priority) last_run_idx = task.run_indices[-1] if task.run_indices else -1 next_start = last_run_idx + 1 - + if next_start < self.rollout_n: next_end = min(next_start + self.in_group_minibatch_size, self.rollout_n) next_indices = list(range(next_start, next_end)) new_history = task.history + current_batch_history_updates - + # Priority 0 (High) to ensure we finish this sample ASAP new_priority = (0, task.row_index) - + new_task = RolloutTask( priority=new_priority, row=task.row, run_indices=next_indices, config=task.config, row_index=task.row_index, - history=new_history + history=new_history, ) self.queue.put_nowait(new_task) @@ -264,14 +280,10 @@ def _post_process_result(self, res: EvaluationRow): add_cost_metrics(res) if res.eval_metadata is not None: if res.rollout_status.is_error(): - res.eval_metadata.status = Status.error( - res.rollout_status.message, res.rollout_status.details - ) - elif not ( - res.eval_metadata.status and res.eval_metadata.status.code != Status.Code.RUNNING - ): + res.eval_metadata.status = Status.error(res.rollout_status.message, res.rollout_status.details) + elif not (res.eval_metadata.status and res.eval_metadata.status.code != Status.Code.RUNNING): res.eval_metadata.status = Status.eval_finished() - + if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1": try: preview = [ @@ -293,33 +305,34 @@ def _post_process_result(self, res: EvaluationRow): async def run(self, dataset: List[EvaluationRow], num_runs: int, base_config: RolloutProcessorConfig): self.num_runs = num_runs - + # 1. Schedule initial tasks await self.schedule_dataset(dataset, base_config) - + # 2. Start Workers # If we have separate limits, we need enough workers to saturate both stages num_workers = self.max_concurrent_rollouts workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)] - + # 3. Wait for completion await self.queue.join() - + # Wait for background evaluations to finish if self.background_tasks: await asyncio.gather(*self.background_tasks, return_exceptions=True) - + # 4. Cleanup for w in workers: w.cancel() - + if workers: await asyncio.gather(*workers, return_exceptions=True) - + # Return collected results return self.results + async def execute_priority_rollouts( dataset: List[EvaluationRow], num_runs: int, diff --git a/eval_protocol/pytest/rollout_result_post_processor.py b/eval_protocol/pytest/rollout_result_post_processor.py index cdaa98d5..e54175bd 100644 --- a/eval_protocol/pytest/rollout_result_post_processor.py +++ b/eval_protocol/pytest/rollout_result_post_processor.py @@ -54,4 +54,3 @@ def process(self, result: EvaluationRow) -> None: result: The EvaluationRow result from the rollout """ pass - diff --git a/pyproject.toml b/pyproject.toml index a43f773a..4544871e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,8 @@ dependencies = [ "litellm<1.75.0", "pytest>=6.0.0", "pytest-asyncio>=0.21.0", - "peewee>=3.18.2", + "tinydb>=4.8.0", + "tinyrecord>=0.2.0", "backoff>=2.2.0", "questionary>=2.0.0", # Dependencies for vendored tau2 package @@ -133,6 +134,9 @@ braintrust = [ openenv = [ "openenv-core", ] +sqlite_storage = [ + "peewee>=3.18.2", +] # Optional deps for LangGraph example/tests langgraph = [ @@ -217,8 +221,8 @@ combine-as-imports = true [tool.pyright] typeCheckingMode = "basic" # Changed from "standard" to reduce memory usage pythonVersion = "3.10" -include = ["eval_protocol"] # Reduced scope to just the main package -exclude = ["vite-app", "vendor", "examples", "tests", "development", "local_evals"] +include = ["eval_protocol", "tests"] +exclude = ["vite-app", "vendor", "examples", "development", "local_evals"] # Ignore diagnostics for vendored generator code ignore = ["versioneer.py"] reportUnusedCallResult = "none" diff --git a/tests/conftest.py b/tests/conftest.py index 9c93cbf8..12f80ff5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import os import sys from pathlib import Path @@ -15,3 +16,56 @@ # Decorator to skip E2B tests when E2B is not available skip_e2b = pytest.mark.skipif(not _HAS_E2B, reason="E2B not installed") + + +# ============================================================================ +# Test isolation for TinyDB storage +# ============================================================================ +# Each test session gets an isolated .eval_protocol directory to prevent +# concurrent test workers from corrupting the shared logs.json file. +# This is especially important in CI where pytest-xdist runs tests in parallel. + +# Store the original function before any patching +import eval_protocol.directory_utils as dir_utils + +_original_find_eval_protocol_dir = dir_utils.find_eval_protocol_dir + + +@pytest.fixture(scope="session", autouse=True) +def isolated_eval_protocol_dir(tmp_path_factory, request): + """ + Create an isolated .eval_protocol directory for the test session. + + This prevents concurrent test workers from corrupting the shared + ~/.eval_protocol/logs.json file when using TinyDB storage. + + Note: Tests in test_directory_utils.py are excluded from this fixture + as they need to test the actual find_eval_protocol_dir behavior. + """ + # Create a unique temp directory for this test session/worker + isolated_dir = tmp_path_factory.mktemp("eval_protocol") + + def isolated_find_eval_protocol_dir() -> str: + os.makedirs(str(isolated_dir), exist_ok=True) + return str(isolated_dir) + + dir_utils.find_eval_protocol_dir = isolated_find_eval_protocol_dir + + yield isolated_dir + + # Restore original function after tests + dir_utils.find_eval_protocol_dir = _original_find_eval_protocol_dir + + +@pytest.fixture +def restore_original_find_eval_protocol_dir(): + """ + Fixture to restore the original find_eval_protocol_dir for tests that + need to test the actual implementation (e.g., test_directory_utils.py). + + Use this fixture in tests that need to test the real directory behavior. + """ + # Temporarily restore the original function + dir_utils.find_eval_protocol_dir = _original_find_eval_protocol_dir + yield _original_find_eval_protocol_dir + # The session fixture will clean up when tests complete diff --git a/tests/data_loader/test_data_loader_stable_row_id.py b/tests/data_loader/test_data_loader_stable_row_id.py index d9aaab96..97f3051d 100644 --- a/tests/data_loader/test_data_loader_stable_row_id.py +++ b/tests/data_loader/test_data_loader_stable_row_id.py @@ -3,9 +3,11 @@ from eval_protocol.pytest import evaluation_test from typing import List + def generator() -> list[EvaluationRow]: return [EvaluationRow(messages=[Message(role="user", content="What is 2 + 2?")]) for _ in range(2)] + @evaluation_test( data_loaders=DynamicDataLoader( generators=[generator], diff --git a/tests/pytest/datasets/klavis_mcp_test.jsonl b/tests/pytest/datasets/klavis_mcp_test.jsonl index 9cee59a7..fd861350 100644 --- a/tests/pytest/datasets/klavis_mcp_test.jsonl +++ b/tests/pytest/datasets/klavis_mcp_test.jsonl @@ -12,4 +12,3 @@ {"messages": [ { "role": "system", "content": "You are a helpful assistant that can answer questions about Outlook Calendar. You have access to Outlook Calendar to help you find information." }, { "role": "user", "content": "How many events do I have during business days of the week of Oct 15 2025?" } ], "ground_truth": "9" } {"messages": [ { "role": "system", "content": "You are a helpful assistant that can answer questions about Outlook Calendar. You have access to Outlook Calendar to help you find information." }, { "role": "user", "content": "How many events do I have on next week's Thursday?" } ], "ground_truth": "2" } {"messages": [ { "role": "system", "content": "You are a helpful assistant that can answer questions about Outlook Calendar. You have access to Outlook Calendar to help you find information." }, { "role": "user", "content": "How many events do I have on next week's buisiness day?" } ], "ground_truth": "5" } - diff --git a/tests/pytest/test_pytest_ensure_logging.py b/tests/pytest/test_pytest_ensure_logging.py index 9f46b7a3..1dd0331e 100644 --- a/tests/pytest/test_pytest_ensure_logging.py +++ b/tests/pytest/test_pytest_ensure_logging.py @@ -15,17 +15,17 @@ async def test_ensure_logging(monkeypatch): monkeypatch.setattr(_dl, "_logger", None, raising=False) except Exception: pass - # Mock the SqliteEvaluationRowStore to track calls + # Mock the EvaluationRowStore to track calls mock_store = Mock() mock_store.upsert_row = Mock() mock_store.read_rows = Mock(return_value=[]) mock_store.db_path = "/tmp/test.db" - # Mock the SqliteEvaluationRowStore constructor so that when SqliteDatasetLoggerAdapter - # creates its store, it gets our mock instead - with patch( - "eval_protocol.dataset_logger.sqlite_dataset_logger_adapter.SqliteEvaluationRowStore", return_value=mock_store - ): + # Mock get_evaluation_row_store so that when DatasetLoggerAdapter + # creates its store, it gets our mock instead. + # We patch at the module level where it's defined, which is where + # dataset_logger_adapter imports it from. + with patch("eval_protocol.dataset_logger.get_evaluation_row_store", return_value=mock_store): from eval_protocol.models import EvaluationRow, EvaluateResult from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor from eval_protocol.pytest.evaluation_test import evaluation_test @@ -55,7 +55,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: ) # Verify that the store's upsert_row method was called - assert mock_store.upsert_row.called, "SqliteEvaluationRowStore.upsert_row should have been called" + assert mock_store.upsert_row.called, "EvaluationRowStore.upsert_row should have been called" # Check that it was called multiple times (once for each row) call_count = mock_store.upsert_row.call_count diff --git a/tests/pytest/test_rollout_scheduler.py b/tests/pytest/test_rollout_scheduler.py index 1a1ff7a9..4cb7004c 100644 --- a/tests/pytest/test_rollout_scheduler.py +++ b/tests/pytest/test_rollout_scheduler.py @@ -2,6 +2,7 @@ from eval_protocol.models import EvaluationRow, Message, EvaluateResult, InputMetadata from typing import List + @evaluation_test( completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_rows=[ @@ -44,6 +45,6 @@ def test_rollout_scheduler(row: EvaluationRow) -> EvaluationRow: mode="groupwise", ) def test_rollout_scheduler_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]: - for i,row in enumerate(rows): + for i, row in enumerate(rows): row.evaluation_result = EvaluateResult(score=0.1 * i, reason="Dummy evaluation result") - return rows \ No newline at end of file + return rows diff --git a/tests/test_directory_utils.py b/tests/test_directory_utils.py index dcae2cdc..e2daa676 100644 --- a/tests/test_directory_utils.py +++ b/tests/test_directory_utils.py @@ -1,9 +1,22 @@ import os import tempfile from unittest.mock import patch + import pytest -from eval_protocol.directory_utils import find_eval_protocol_dir, find_eval_protocol_datasets_dir +import eval_protocol.directory_utils as dir_utils + + +@pytest.fixture(autouse=True) +def use_real_directory_utils(restore_original_find_eval_protocol_dir): + """ + Automatically use the real find_eval_protocol_dir for all tests in this module. + + This is necessary because the session-scoped isolated_eval_protocol_dir fixture + patches find_eval_protocol_dir globally, but these tests need to test the + actual implementation behavior. + """ + yield class TestDirectoryUtils: @@ -13,7 +26,7 @@ def test_find_eval_protocol_dir_uses_home_folder(self): """Test that find_eval_protocol_dir always maps to home folder.""" with tempfile.TemporaryDirectory() as temp_dir: with patch.dict(os.environ, {"HOME": temp_dir}): - result = find_eval_protocol_dir() + result = dir_utils.find_eval_protocol_dir() expected = os.path.expanduser("~/.eval_protocol") assert result == expected assert result.endswith(".eval_protocol") @@ -29,7 +42,7 @@ def test_find_eval_protocol_dir_creates_directory(self): os.rmdir(eval_protocol_dir) # Call the function - result = find_eval_protocol_dir() + result = dir_utils.find_eval_protocol_dir() # Verify the directory was created assert result == eval_protocol_dir @@ -40,7 +53,7 @@ def test_find_eval_protocol_dir_handles_tilde_expansion(self): """Test that find_eval_protocol_dir properly handles tilde expansion.""" with tempfile.TemporaryDirectory() as temp_dir: with patch.dict(os.environ, {"HOME": temp_dir}): - result = find_eval_protocol_dir() + result = dir_utils.find_eval_protocol_dir() expected = os.path.expanduser("~/.eval_protocol") assert result == expected assert result.startswith(temp_dir) @@ -49,7 +62,7 @@ def test_find_eval_protocol_datasets_dir_uses_home_folder(self): """Test that find_eval_protocol_datasets_dir also uses home folder.""" with tempfile.TemporaryDirectory() as temp_dir: with patch.dict(os.environ, {"HOME": temp_dir}): - result = find_eval_protocol_datasets_dir() + result = dir_utils.find_eval_protocol_datasets_dir() expected = os.path.expanduser("~/.eval_protocol/datasets") assert result == expected assert result.endswith(".eval_protocol/datasets") @@ -69,7 +82,7 @@ def test_find_eval_protocol_datasets_dir_creates_directory(self): os.rmdir(eval_protocol_dir) # Call the function - result = find_eval_protocol_datasets_dir() + result = dir_utils.find_eval_protocol_datasets_dir() # Verify both directories were created assert result == datasets_dir @@ -82,14 +95,14 @@ def test_find_eval_protocol_dir_consistency(self): """Test that multiple calls to find_eval_protocol_dir return the same path.""" with tempfile.TemporaryDirectory() as temp_dir: with patch.dict(os.environ, {"HOME": temp_dir}): - result1 = find_eval_protocol_dir() - result2 = find_eval_protocol_dir() + result1 = dir_utils.find_eval_protocol_dir() + result2 = dir_utils.find_eval_protocol_dir() assert result1 == result2 def test_find_eval_protocol_datasets_dir_consistency(self): """Test that multiple calls to find_eval_protocol_datasets_dir return the same path.""" with tempfile.TemporaryDirectory() as temp_dir: with patch.dict(os.environ, {"HOME": temp_dir}): - result1 = find_eval_protocol_datasets_dir() - result2 = find_eval_protocol_datasets_dir() + result1 = dir_utils.find_eval_protocol_datasets_dir() + result2 = dir_utils.find_eval_protocol_datasets_dir() assert result1 == result2 diff --git a/tests/test_exception_config.py b/tests/test_exception_config.py index 90db182a..77d1696d 100644 --- a/tests/test_exception_config.py +++ b/tests/test_exception_config.py @@ -17,11 +17,11 @@ def test_backoff_config_no_exceptions(): """Test that BackoffConfig returns no-op decorator when no exceptions specified.""" config = BackoffConfig() decorator = config.get_backoff_decorator(set()) - + # Should be a no-op decorator def test_func(): return "test" - + decorated = decorator(test_func) assert decorated() == "test" assert decorated is test_func # Should be the same function @@ -31,14 +31,14 @@ def test_backoff_config_no_overrides(): """Test that BackoffConfig creates a single decorator.""" config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2) exceptions = {ConnectionError, TimeoutError} - + decorator = config.get_backoff_decorator(exceptions) assert decorator is not None - + # Decorator should be callable def test_func(): raise ConnectionError("test") - + decorated = decorator(test_func) assert callable(decorated) @@ -46,7 +46,7 @@ def test_func(): def test_exception_handler_config_default_response_quality_error(): """Test that ExceptionHandlerConfig includes ResponseQualityError by default.""" config = ExceptionHandlerConfig() - + # ResponseQualityError should be in retryable_exceptions assert ResponseQualityError in config.retryable_exceptions @@ -55,30 +55,29 @@ def test_exception_handler_config_get_backoff_decorator(): """Test that ExceptionHandlerConfig.get_backoff_decorator() works correctly.""" config = ExceptionHandlerConfig() decorator = config.get_backoff_decorator() - + assert decorator is not None assert callable(decorator) - + # Should be able to decorate a function def test_func(): raise ConnectionError("test") - + decorated = decorator(test_func) assert callable(decorated) def test_backoff_config_expo_strategy(): - """Test that BackoffConfig creates expo decorator correctly.""" config = BackoffConfig(strategy="expo", base_delay=1.0, max_tries=2) exceptions = {ConnectionError} - + decorator = config.get_backoff_decorator(exceptions) assert decorator is not None - + def test_func(): raise ConnectionError("test") - + decorated = decorator(test_func) assert callable(decorated) @@ -87,13 +86,13 @@ def test_backoff_config_constant_strategy(): """Test that BackoffConfig creates constant decorator correctly.""" config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2) exceptions = {ConnectionError} - + decorator = config.get_backoff_decorator(exceptions) assert decorator is not None - + def test_func(): raise ConnectionError("test") - + decorated = decorator(test_func) assert callable(decorated) @@ -102,7 +101,7 @@ def test_backoff_config_invalid_strategy(): """Test that BackoffConfig raises ValueError for invalid strategy.""" config = BackoffConfig(strategy="invalid", base_delay=1.0, max_tries=2) exceptions = {ConnectionError} - + with pytest.raises(ValueError, match="Unknown backoff strategy"): config.get_backoff_decorator(exceptions) @@ -110,5 +109,3 @@ def test_backoff_config_invalid_strategy(): def test_exception_handler_config_response_quality_error_in_defaults(): """Test that ResponseQualityError is in DEFAULT_RETRYABLE_EXCEPTIONS.""" assert ResponseQualityError in DEFAULT_RETRYABLE_EXCEPTIONS - - diff --git a/tests/test_priority_scheduler.py b/tests/test_priority_scheduler.py index 27e748eb..9dd95bce 100644 --- a/tests/test_priority_scheduler.py +++ b/tests/test_priority_scheduler.py @@ -9,53 +9,56 @@ from eval_protocol.pytest.types import RolloutProcessorConfig from eval_protocol.dataset_logger.dataset_logger import DatasetLogger + # Mock models def create_mock_row(row_id: str = "test-row") -> EvaluationRow: return EvaluationRow( - input_metadata=InputMetadata( - row_id=row_id, - completion_params={"model": "test-model"} - ), - execution_metadata=ExecutionMetadata() + input_metadata=InputMetadata(row_id=row_id, completion_params={"model": "test-model"}), + execution_metadata=ExecutionMetadata(), ) + @pytest.fixture def mock_rollout_processor(): processor = MagicMock() + # Mocking the rollout to be an async generator async def mock_rollout_gen(rows, config, run_idx): for row in rows: # Simulate some work yield row + processor.side_effect = mock_rollout_gen return processor + @pytest.fixture def mock_logger(): return MagicMock(spec=DatasetLogger) + @pytest.fixture def mock_eval_executor(): return AsyncMock() + @pytest.fixture def base_config(): return RolloutProcessorConfig( completion_params={"model": "test-model"}, mcp_config_path="test_config.yaml", semaphore=asyncio.Semaphore(10), - steps=10 + steps=10, ) + @pytest.mark.asyncio -async def test_scheduler_basic_execution( - mock_logger, mock_eval_executor, base_config -): +async def test_scheduler_basic_execution(mock_logger, mock_eval_executor, base_config): """Test that the scheduler processes all rows and completes.""" dataset = [create_mock_row(f"row-{i}") for i in range(5)] num_runs = 2 micro_batch_size = 1 - + # Mock rollout processor with delay async def delayed_rollout(processor, rows, config, run_idx): await asyncio.sleep(0.01) @@ -66,9 +69,9 @@ async def mock_eval(row): row.evaluation_result = EvaluateResult(score=1.0, is_score_valid=True) return row - with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=delayed_rollout): + with patch("eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry", side_effect=delayed_rollout): processor_instance = MagicMock() - + scheduler = PriorityRolloutScheduler( rollout_processor=processor_instance, max_concurrent_rollouts=2, @@ -76,11 +79,11 @@ async def mock_eval(row): eval_executor=mock_eval, max_concurrent_evaluations=2, rollout_n=num_runs, - in_group_minibatch_size=micro_batch_size + in_group_minibatch_size=micro_batch_size, ) - + results = await scheduler.run(dataset, num_runs, base_config) - + assert len(results) == 5 * num_runs for res in results: assert res.evaluation_result is not None @@ -88,25 +91,23 @@ async def mock_eval(row): @pytest.mark.asyncio -async def test_concurrency_control( - mock_logger, mock_eval_executor, base_config -): +async def test_concurrency_control(mock_logger, mock_eval_executor, base_config): """ Verify that max_concurrent_rollouts and max_concurrent_evaluations are respected. """ dataset = [create_mock_row(f"row-{i}") for i in range(10)] num_runs = 1 micro_batch_size = 1 - + max_rollouts = 4 max_evals = 2 - + active_rollouts = 0 max_active_rollouts_seen = 0 - + active_evals = 0 max_active_evals_seen = 0 - + rollout_lock = asyncio.Lock() eval_lock = asyncio.Lock() @@ -115,13 +116,13 @@ async def mock_rollout_gen(processor, rows, config, run_idx): async with rollout_lock: active_rollouts += 1 max_active_rollouts_seen = max(max_active_rollouts_seen, active_rollouts) - + # Simulate slow rollout await asyncio.sleep(0.05) - + for row in rows: yield row - + async with rollout_lock: active_rollouts -= 1 @@ -131,19 +132,18 @@ async def mock_eval(row): async with eval_lock: active_evals += 1 max_active_evals_seen = max(max_active_evals_seen, active_evals) - + # Simulate evaluation await asyncio.sleep(0.05) - + async with eval_lock: active_evals -= 1 return row - with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): - + with patch("eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry", side_effect=mock_rollout_gen): # Mock processor instance (can be anything since we patched the wrapper) processor_instance = MagicMock() - + scheduler = PriorityRolloutScheduler( rollout_processor=processor_instance, max_concurrent_rollouts=max_rollouts, @@ -151,32 +151,33 @@ async def mock_eval(row): eval_executor=mock_eval, max_concurrent_evaluations=max_evals, rollout_n=num_runs, - in_group_minibatch_size=micro_batch_size + in_group_minibatch_size=micro_batch_size, ) - + await scheduler.run(dataset, num_runs, base_config) - + # Verify limits were respected - assert max_active_rollouts_seen <= max_rollouts, f"Rollout concurrency exceeded: {max_active_rollouts_seen} > {max_rollouts}" + assert max_active_rollouts_seen <= max_rollouts, ( + f"Rollout concurrency exceeded: {max_active_rollouts_seen} > {max_rollouts}" + ) assert max_active_evals_seen <= max_evals, f"Eval concurrency exceeded: {max_active_evals_seen} > {max_evals}" - + # Verify everything ran # 10 rows * 1 run = 10 results assert len(scheduler.results) == 10 + @pytest.mark.asyncio -async def test_priority_scheduling( - mock_logger, mock_eval_executor, base_config -): +async def test_priority_scheduling(mock_logger, mock_eval_executor, base_config): """ Test that subsequent micro-batches are prioritized. """ dataset = [create_mock_row(f"row-{i}") for i in range(2)] num_runs = 2 micro_batch_size = 1 - + execution_order = [] - + async def mock_rollout_gen(processor, rows, config, run_idx): row_id = rows[0].input_metadata.row_id execution_order.append(f"{row_id}_run_{run_idx}") @@ -186,38 +187,32 @@ async def mock_rollout_gen(processor, rows, config, run_idx): async def mock_eval(row): return row - with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): + with patch("eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry", side_effect=mock_rollout_gen): processor_instance = MagicMock() - + scheduler = PriorityRolloutScheduler( rollout_processor=processor_instance, - max_concurrent_rollouts=1, # Force serial execution to test priority + max_concurrent_rollouts=1, # Force serial execution to test priority active_logger=mock_logger, eval_executor=mock_eval, max_concurrent_evaluations=1, rollout_n=num_runs, - in_group_minibatch_size=micro_batch_size + in_group_minibatch_size=micro_batch_size, ) - + await scheduler.run(dataset, num_runs, base_config) - + # Expected order: row-0_run_0, row-0_run_1, row-1_run_0, row-1_run_1 # Note: Since row-0_run_0 finishes, it schedules row-0_run_1 with HIGH priority (0). # row-1_run_0 is in queue with LOW priority (1). # So row-0_run_1 should run before row-1_run_0. - expected = [ - "row-0_run_0", - "row-0_run_1", - "row-1_run_0", - "row-1_run_1" - ] - + expected = ["row-0_run_0", "row-0_run_1", "row-1_run_0", "row-1_run_1"] + assert execution_order == expected + @pytest.mark.asyncio -async def test_worker_scaling( - mock_logger, mock_eval_executor, base_config -): +async def test_worker_scaling(mock_logger, mock_eval_executor, base_config): """ Test that the number of workers scales with the sum of limits. """ @@ -226,9 +221,9 @@ async def test_worker_scaling( max_evals = 3 # Updated expectation: workers only scale with rollout concurrency now expected_workers = max_rollouts - + worker_start_count = 0 - + class InstrumentedScheduler(PriorityRolloutScheduler): async def worker(self): nonlocal worker_start_count @@ -242,17 +237,12 @@ async def worker(self): pass async def schedule_dataset(self, *args): - # Put enough items to ensure all workers wake up and grab one - for i in range(expected_workers): - task = RolloutTask( - priority=(1, i), - row=dataset[0], - run_indices=[], - config=base_config, - row_index=0, - history=[] - ) - await self.queue.put(task) + # Put enough items to ensure all workers wake up and grab one + for i in range(expected_workers): + task = RolloutTask( + priority=(1, i), row=dataset[0], run_indices=[], config=base_config, row_index=0, history=[] + ) + await self.queue.put(task) processor_instance = MagicMock() scheduler = InstrumentedScheduler( @@ -262,41 +252,40 @@ async def schedule_dataset(self, *args): eval_executor=mock_eval_executor, max_concurrent_evaluations=max_evals, rollout_n=1, - in_group_minibatch_size=1 + in_group_minibatch_size=1, ) - + await scheduler.run(dataset, 1, base_config) - + assert worker_start_count == expected_workers + @pytest.mark.asyncio -async def test_groupwise_mode( - mock_logger, mock_eval_executor, base_config -): +async def test_groupwise_mode(mock_logger, mock_eval_executor, base_config): """ Test that groupwise mode collects all runs before evaluating. """ dataset = [create_mock_row("row-0")] num_runs = 4 micro_batch_size = 2 - + # We expect 2 batches of 2 runs each. # Batch 1 (Runs 0,1): Should buffer and update history, NOT call eval. # Batch 2 (Runs 2,3): Should buffer, update history, AND call eval with all 4 runs. - + eval_calls = [] - + async def mock_eval(rows): eval_calls.append(rows) - return rows # Pass through + return rows # Pass through async def mock_rollout_gen(processor, rows, config, run_idx): for row in rows: yield row - - with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): + + with patch("eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry", side_effect=mock_rollout_gen): processor_instance = MagicMock() - + scheduler = PriorityRolloutScheduler( rollout_processor=processor_instance, max_concurrent_rollouts=1, @@ -305,18 +294,18 @@ async def mock_rollout_gen(processor, rows, config, run_idx): max_concurrent_evaluations=1, mode="groupwise", rollout_n=num_runs, - in_group_minibatch_size=micro_batch_size + in_group_minibatch_size=micro_batch_size, ) - + results = await scheduler.run(dataset, num_runs, base_config) - + # Verify evaluation was called EXACTLY ONCE assert len(eval_calls) == 1, f"Expected 1 eval call, got {len(eval_calls)}" - + # Verify it was called with ALL 4 rows evaluated_rows = eval_calls[0] assert len(evaluated_rows) == 4, f"Expected 4 rows in group eval, got {len(evaluated_rows)}" - + # Verify results contains all 4 runs (returned from eval) # Note: eval returns a list of 4 rows. scheduler.results extends this list. assert len(results) == 4 diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py index 861793c1..9b664ecb 100644 --- a/tests/test_retry_mechanism.py +++ b/tests/test_retry_mechanism.py @@ -266,7 +266,7 @@ def custom_http_giveup(e: Exception) -> bool: return True # Give up immediately on bad requests elif isinstance(e, litellm.RateLimitError): return False # Retry rate limits with backoff - + return False # Retry everything else @@ -388,7 +388,7 @@ def test_simple_giveup_function(row: EvaluationRow) -> EvaluationRow: def test_simple_giveup_verification(): """Verify that giveup function prevents retries.""" mock_tracker = shared_processor_simple_giveup.mock_tracker - + print("\nšŸ”„ SIMPLE GIVEUP TEST ANALYSIS:") print(f" Batch calls made: {mock_tracker.batch_call.call_count}") print(f" Total row processing calls: {mock_tracker.process_row_call.call_count}") @@ -418,7 +418,9 @@ async def process_single_row(row: EvaluationRow) -> EvaluationRow: # Determine attempt number by counting previous calls for this rollout_id previous_calls = [ - call for call in self.mock_tracker.process_row_call.call_args_list if call[0][0] == row.execution_metadata.rollout_id + call + for call in self.mock_tracker.process_row_call.call_args_list + if call[0][0] == row.execution_metadata.rollout_id ] attempt_number = len(previous_calls) @@ -477,9 +479,6 @@ def test_response_quality_error_verification(): # Should have exactly 1 rollout_id called twice call_count_values = list(call_counts.values()) - assert call_count_values.count(2) == 1, ( - f"Expected 1 rollout with 2 calls, got {call_count_values}" - ) + assert call_count_values.count(2) == 1, f"Expected 1 rollout with 2 calls, got {call_count_values}" print("āœ… ResponseQualityError test passed! Error was retried.") - diff --git a/tests/test_storage_backends.py b/tests/test_storage_backends.py new file mode 100644 index 00000000..9d550cb5 --- /dev/null +++ b/tests/test_storage_backends.py @@ -0,0 +1,477 @@ +""" +Tests for the storage backend abstraction and implementations. + +Tests both TinyDB (default) and SQLite backends to ensure they work correctly +and can be selected via the EP_STORAGE environment variable. +""" + +import os +import tempfile +from typing import Generator + +import pytest + +from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore +from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore +from eval_protocol.dataset_logger.tinydb_evaluation_row_store import TinyDBEvaluationRowStore +from eval_protocol.event_bus.event_bus_database import EventBusDatabase +from eval_protocol.event_bus.sqlite_event_bus_database import SqliteEventBusDatabase +from eval_protocol.event_bus.tinydb_event_bus_database import TinyDBEventBusDatabase + + +@pytest.fixture +def temp_dir() -> Generator[str, None, None]: + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +class TestEvaluationRowStoreABC: + """Tests that both implementations correctly implement the ABC.""" + + def test_sqlite_implements_abc(self, temp_dir: str): + """SqliteEvaluationRowStore should implement EvaluationRowStore.""" + db_path = os.path.join(temp_dir, "test.db") + store = SqliteEvaluationRowStore(db_path) + assert isinstance(store, EvaluationRowStore) + + def test_tinydb_implements_abc(self, temp_dir: str): + """TinyDBEvaluationRowStore should implement EvaluationRowStore.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + assert isinstance(store, EvaluationRowStore) + + +class TestTinyDBEvaluationRowStore: + """Tests for TinyDBEvaluationRowStore.""" + + def test_upsert_and_read_row(self, temp_dir: str): + """Test basic upsert and read operations.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + + data = { + "execution_metadata": {"rollout_id": "test-rollout-1"}, + "input_metadata": {"row_id": "row-1"}, + "messages": [{"role": "user", "content": "Hello"}], + } + + store.upsert_row(data) + rows = store.read_rows(rollout_id="test-rollout-1") + + assert len(rows) == 1 + assert rows[0]["execution_metadata"]["rollout_id"] == "test-rollout-1" + assert rows[0]["input_metadata"]["row_id"] == "row-1" + + def test_upsert_updates_existing_row(self, temp_dir: str): + """Test that upsert updates an existing row.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + + data1 = { + "execution_metadata": {"rollout_id": "test-rollout-1"}, + "input_metadata": {"row_id": "row-1"}, + "messages": [{"role": "user", "content": "Hello"}], + } + data2 = { + "execution_metadata": {"rollout_id": "test-rollout-1"}, + "input_metadata": {"row_id": "row-1-updated"}, + "messages": [{"role": "user", "content": "Updated"}], + } + + store.upsert_row(data1) + store.upsert_row(data2) + + rows = store.read_rows() + assert len(rows) == 1 + assert rows[0]["input_metadata"]["row_id"] == "row-1-updated" + + def test_read_all_rows(self, temp_dir: str): + """Test reading all rows without filter.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + + for i in range(3): + data = { + "execution_metadata": {"rollout_id": f"test-rollout-{i}"}, + "input_metadata": {"row_id": f"row-{i}"}, + } + store.upsert_row(data) + + rows = store.read_rows() + assert len(rows) == 3 + + def test_delete_row(self, temp_dir: str): + """Test deleting a specific row.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + + data = { + "execution_metadata": {"rollout_id": "test-rollout-1"}, + "input_metadata": {"row_id": "row-1"}, + } + store.upsert_row(data) + + deleted = store.delete_row("test-rollout-1") + assert deleted == 1 + + rows = store.read_rows() + assert len(rows) == 0 + + def test_delete_row_nonexistent(self, temp_dir: str): + """Test deleting a row that doesn't exist returns 0.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + + # Try to delete a row that doesn't exist + deleted = store.delete_row("nonexistent-rollout") + assert deleted == 0 + + # Verify store is still empty + rows = store.read_rows() + assert len(rows) == 0 + + def test_delete_all_rows(self, temp_dir: str): + """Test deleting all rows.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + + for i in range(3): + data = { + "execution_metadata": {"rollout_id": f"test-rollout-{i}"}, + "input_metadata": {"row_id": f"row-{i}"}, + } + store.upsert_row(data) + + deleted = store.delete_all_rows() + assert deleted == 3 + + rows = store.read_rows() + assert len(rows) == 0 + + def test_db_path_property(self, temp_dir: str): + """Test that db_path property returns correct path.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + assert store.db_path == db_path + + def test_raises_on_missing_rollout_id(self, temp_dir: str): + """Test that upsert raises when rollout_id is None.""" + db_path = os.path.join(temp_dir, "test.json") + store = TinyDBEvaluationRowStore(db_path) + + data = { + "execution_metadata": {"rollout_id": None}, + "input_metadata": {"row_id": "row-1"}, + } + + with pytest.raises(ValueError, match="rollout_id is required"): + store.upsert_row(data) + + +class TestTinyDBEventBusDatabase: + """Tests for TinyDBEventBusDatabase.""" + + def test_publish_and_get_events(self, temp_dir: str): + """Test publishing and retrieving events.""" + db_path = os.path.join(temp_dir, "events.json") + db = TinyDBEventBusDatabase(db_path) + + db.publish_event("test_event", {"key": "value"}, "process-1") + + # Get events from a different process + events = db.get_unprocessed_events("process-2") + + assert len(events) == 1 + assert events[0]["event_type"] == "test_event" + assert events[0]["data"] == {"key": "value"} + assert events[0]["process_id"] == "process-1" + + def test_events_filtered_by_process_id(self, temp_dir: str): + """Test that events from same process are not returned.""" + db_path = os.path.join(temp_dir, "events.json") + db = TinyDBEventBusDatabase(db_path) + + db.publish_event("test_event", {"key": "value"}, "process-1") + + # Get events from the same process - should be empty + events = db.get_unprocessed_events("process-1") + assert len(events) == 0 + + def test_mark_event_processed(self, temp_dir: str): + """Test marking events as processed.""" + db_path = os.path.join(temp_dir, "events.json") + db = TinyDBEventBusDatabase(db_path) + + db.publish_event("test_event", {"key": "value"}, "process-1") + + events = db.get_unprocessed_events("process-2") + assert len(events) == 1 + + db.mark_event_processed(events[0]["event_id"]) + + # Should no longer be returned + events = db.get_unprocessed_events("process-2") + assert len(events) == 0 + + def test_cleanup_old_events(self, temp_dir: str): + """Test cleaning up old processed events.""" + db_path = os.path.join(temp_dir, "events.json") + db = TinyDBEventBusDatabase(db_path) + + db.publish_event("test_event", {"key": "value"}, "process-1") + events = db.get_unprocessed_events("process-2") + db.mark_event_processed(events[0]["event_id"]) + + # Cleanup with 0 hours should remove all processed events + db.cleanup_old_events(max_age_hours=0) + + # The event should still be gone (processed) + events = db.get_unprocessed_events("process-2") + assert len(events) == 0 + + +class TestEventBusDatabaseABC: + """Tests that both implementations correctly implement the ABC.""" + + def test_sqlite_implements_abc(self, temp_dir: str): + """SqliteEventBusDatabase should implement EventBusDatabase.""" + db_path = os.path.join(temp_dir, "events.db") + db = SqliteEventBusDatabase(db_path) + assert isinstance(db, EventBusDatabase) + + def test_tinydb_implements_abc(self, temp_dir: str): + """TinyDBEventBusDatabase should implement EventBusDatabase.""" + db_path = os.path.join(temp_dir, "events.json") + db = TinyDBEventBusDatabase(db_path) + assert isinstance(db, EventBusDatabase) + + +class TestFactoryFunctions: + """Tests for factory functions that select storage backends.""" + + def test_get_evaluation_row_store_default_tinydb(self, temp_dir: str, monkeypatch): + """Default should be TinyDB.""" + monkeypatch.delenv("EP_STORAGE", raising=False) + + from eval_protocol.dataset_logger import get_evaluation_row_store + + db_path = os.path.join(temp_dir, "test.json") + store = get_evaluation_row_store(db_path) + + assert isinstance(store, TinyDBEvaluationRowStore) + + def test_get_evaluation_row_store_sqlite(self, temp_dir: str, monkeypatch): + """EP_STORAGE=sqlite should use SQLite.""" + monkeypatch.setenv("EP_STORAGE", "sqlite") + + from eval_protocol.dataset_logger import get_evaluation_row_store + + db_path = os.path.join(temp_dir, "test.db") + store = get_evaluation_row_store(db_path) + + assert isinstance(store, SqliteEvaluationRowStore) + + def test_get_event_bus_database_default_tinydb(self, temp_dir: str, monkeypatch): + """Default should be TinyDB.""" + monkeypatch.delenv("EP_STORAGE", raising=False) + + from eval_protocol.event_bus import get_event_bus_database + + db_path = os.path.join(temp_dir, "events.json") + db = get_event_bus_database(db_path) + + assert isinstance(db, TinyDBEventBusDatabase) + + def test_get_event_bus_database_sqlite(self, temp_dir: str, monkeypatch): + """EP_STORAGE=sqlite should use SQLite.""" + monkeypatch.setenv("EP_STORAGE", "sqlite") + + from eval_protocol.event_bus import get_event_bus_database + + db_path = os.path.join(temp_dir, "events.db") + db = get_event_bus_database(db_path) + + assert isinstance(db, SqliteEventBusDatabase) + + +class TestCrossProcessCacheInvalidation: + """ + Tests that query cache is properly invalidated when another process modifies the database. + + This simulates cross-process scenarios by creating separate store instances + pointing to the same database file. Each instance represents a different "process" + that might have cached query results. + """ + + @pytest.mark.parametrize( + "store_class,file_ext", + [ + (TinyDBEvaluationRowStore, ".json"), + (SqliteEvaluationRowStore, ".db"), + ], + ) + def test_evaluation_row_store_sees_writes_from_other_process(self, temp_dir: str, store_class, file_ext: str): + """ + Ensure a store instance can read fresh data written by another instance. + + This verifies that cached query results don't prevent seeing new data + written by a separate process. + """ + db_path = os.path.join(temp_dir, f"test{file_ext}") + + # Simulate two processes with separate store instances + process1_store = store_class(db_path) + process2_store = store_class(db_path) + + # Process 1 reads initially (may cache empty result) + initial_rows = process1_store.read_rows() + assert len(initial_rows) == 0 + + # Process 2 writes new data + data = { + "execution_metadata": {"rollout_id": "cross-process-test"}, + "input_metadata": {"row_id": "row-from-process-2"}, + "messages": [{"role": "user", "content": "Hello from process 2"}], + } + process2_store.upsert_row(data) + + # Process 1 should see the new data (cache should be invalidated/bypassed) + rows = process1_store.read_rows() + assert len(rows) == 1 + assert rows[0]["execution_metadata"]["rollout_id"] == "cross-process-test" + assert rows[0]["input_metadata"]["row_id"] == "row-from-process-2" + + @pytest.mark.parametrize( + "store_class,file_ext", + [ + (TinyDBEvaluationRowStore, ".json"), + (SqliteEvaluationRowStore, ".db"), + ], + ) + def test_evaluation_row_store_sees_updates_from_other_process(self, temp_dir: str, store_class, file_ext: str): + """ + Ensure a store instance sees updates made by another instance. + + This verifies that cached query results are properly invalidated + when another process updates existing data. + """ + db_path = os.path.join(temp_dir, f"test{file_ext}") + + # Both processes start with the same initial data + process1_store = store_class(db_path) + initial_data = { + "execution_metadata": {"rollout_id": "shared-rollout"}, + "input_metadata": {"row_id": "initial-row"}, + "value": "initial", + } + process1_store.upsert_row(initial_data) + + # Process 2 opens the same database + process2_store = store_class(db_path) + + # Process 1 reads and potentially caches the result + rows = process1_store.read_rows(rollout_id="shared-rollout") + assert len(rows) == 1 + assert rows[0]["value"] == "initial" + + # Process 2 updates the data + updated_data = { + "execution_metadata": {"rollout_id": "shared-rollout"}, + "input_metadata": {"row_id": "updated-row"}, + "value": "updated-by-process-2", + } + process2_store.upsert_row(updated_data) + + # Process 1 should see the updated data + rows = process1_store.read_rows(rollout_id="shared-rollout") + assert len(rows) == 1 + assert rows[0]["value"] == "updated-by-process-2" + assert rows[0]["input_metadata"]["row_id"] == "updated-row" + + @pytest.mark.parametrize( + "db_class,file_ext", + [ + (TinyDBEventBusDatabase, ".json"), + (SqliteEventBusDatabase, ".db"), + ], + ) + def test_event_bus_database_sees_events_from_other_process(self, temp_dir: str, db_class, file_ext: str): + """ + Ensure an event bus instance can read events published by another instance. + + This verifies that cached query results don't prevent seeing new events + written by a separate process. + """ + db_path = os.path.join(temp_dir, f"events{file_ext}") + + # Simulate two processes with separate event bus instances + process1_db = db_class(db_path) + process2_db = db_class(db_path) + + # Process 1 checks for events initially (may cache empty result) + initial_events = process1_db.get_unprocessed_events("process-1") + assert len(initial_events) == 0 + + # Process 2 publishes an event + process2_db.publish_event("test_event", {"key": "value"}, "process-2") + + # Process 1 should see the new event (cache should be invalidated/bypassed) + events = process1_db.get_unprocessed_events("process-1") + assert len(events) == 1 + assert events[0]["event_type"] == "test_event" + assert events[0]["data"] == {"key": "value"} + assert events[0]["process_id"] == "process-2" + + @pytest.mark.parametrize( + "db_class,file_ext", + [ + (TinyDBEventBusDatabase, ".json"), + (SqliteEventBusDatabase, ".db"), + ], + ) + def test_event_bus_database_sees_processed_status_from_other_process(self, temp_dir: str, db_class, file_ext: str): + """ + Ensure an event bus instance sees when another instance marks events as processed. + + This verifies that cached query results are properly invalidated + when another process updates event status. + """ + db_path = os.path.join(temp_dir, f"events{file_ext}") + + # Process 1 publishes an event + process1_db = db_class(db_path) + process1_db.publish_event("test_event", {"key": "value"}, "process-1") + + # Process 2 opens the same database and sees the event + process2_db = db_class(db_path) + events = process2_db.get_unprocessed_events("process-2") + assert len(events) == 1 + + # Process 3 opens and marks the event as processed + process3_db = db_class(db_path) + events = process3_db.get_unprocessed_events("process-3") + assert len(events) == 1 + process3_db.mark_event_processed(events[0]["event_id"]) + + # Process 2 should no longer see the event (it's been processed) + events = process2_db.get_unprocessed_events("process-2") + assert len(events) == 0 + + +class TestBackwardsCompatibility: + """Tests for backwards compatibility aliases.""" + + def test_sqlite_dataset_logger_adapter_alias(self): + """SqliteDatasetLoggerAdapter should be an alias for DatasetLoggerAdapter.""" + from eval_protocol.dataset_logger.dataset_logger_adapter import DatasetLoggerAdapter + from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter + + assert SqliteDatasetLoggerAdapter is DatasetLoggerAdapter + + def test_sqlite_event_bus_alias(self): + """SqliteEventBus should be an alias for CrossProcessEventBus.""" + from eval_protocol.event_bus.cross_process_event_bus import CrossProcessEventBus + from eval_protocol.event_bus.sqlite_event_bus import SqliteEventBus + + assert SqliteEventBus is CrossProcessEventBus diff --git a/uv.lock b/uv.lock index 38b07c4a..7d8c425b 100644 --- a/uv.lock +++ b/uv.lock @@ -1228,7 +1228,6 @@ dependencies = [ { name = "mcp" }, { name = "omegaconf" }, { name = "openai" }, - { name = "peewee" }, { name = "psutil" }, { name = "pydantic" }, { name = "pytest" }, @@ -1238,6 +1237,8 @@ dependencies = [ { name = "questionary" }, { name = "requests" }, { name = "rich" }, + { name = "tinydb" }, + { name = "tinyrecord" }, { name = "toml" }, { name = "uvicorn" }, { name = "websockets" }, @@ -1324,6 +1325,9 @@ proxy = [ pydantic = [ { name = "pydantic-ai" }, ] +sqlite-storage = [ + { name = "peewee" }, +] supabase = [ { name = "supabase" }, ] @@ -1392,7 +1396,7 @@ requires-dist = [ { name = "openenv-core", marker = "extra == 'openenv'" }, { name = "openevals", marker = "extra == 'openevals'", specifier = ">=0.1.0" }, { name = "pandas", marker = "extra == 'dev'", specifier = ">=1.5.0" }, - { name = "peewee", specifier = ">=3.18.2" }, + { name = "peewee", marker = "extra == 'sqlite-storage'", specifier = ">=3.18.2" }, { name = "peft", marker = "extra == 'trl'", specifier = ">=0.7.0" }, { name = "pillow", marker = "extra == 'box2d'" }, { name = "pip", marker = "extra == 'dev'", specifier = ">=25.1.1" }, @@ -1417,6 +1421,8 @@ requires-dist = [ { name = "supabase", marker = "extra == 'supabase'", specifier = ">=2.18.1" }, { name = "swig", marker = "extra == 'box2d'" }, { name = "syrupy", marker = "extra == 'dev'", specifier = ">=4.0.0" }, + { name = "tinydb", specifier = ">=4.8.0" }, + { name = "tinyrecord", specifier = ">=0.2.0" }, { name = "toml", specifier = ">=0.10.0" }, { name = "torch", marker = "extra == 'trl'", specifier = ">=1.9" }, { name = "transformers", marker = "extra == 'dev'", specifier = ">=4.0.0" }, @@ -1434,7 +1440,7 @@ requires-dist = [ { name = "websockets", specifier = ">=15.0.1" }, { name = "werkzeug", marker = "extra == 'dev'", specifier = ">=2.0.0" }, ] -provides-extras = ["dev", "trl", "openevals", "fireworks", "box2d", "langfuse", "huggingface", "langsmith", "bigquery", "svgbench", "pydantic", "supabase", "chinook", "langchain", "braintrust", "openenv", "langgraph", "langgraph-tools", "proxy"] +provides-extras = ["dev", "trl", "openevals", "fireworks", "box2d", "langfuse", "huggingface", "langsmith", "bigquery", "svgbench", "pydantic", "supabase", "chinook", "langchain", "braintrust", "openenv", "sqlite-storage", "langgraph", "langgraph-tools", "proxy"] [package.metadata.requires-dev] dev = [ @@ -6512,6 +6518,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" }, ] +[[package]] +name = "tinydb" +version = "4.8.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/79/4af51e2bb214b6ea58f857c51183d92beba85b23f7ba61c983ab3de56c33/tinydb-4.8.2.tar.gz", hash = "sha256:f7dfc39b8d7fda7a1ca62a8dbb449ffd340a117c1206b68c50b1a481fb95181d", size = 32566, upload-time = "2024-10-12T15:24:01.13Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/17/853354204e1ca022d6b7d011ca7f3206c4f8faa3cc743e92609b49c1d83f/tinydb-4.8.2-py3-none-any.whl", hash = "sha256:f97030ee5cbc91eeadd1d7af07ab0e48ceb04aa63d4a983adbaca4cba16e86c3", size = 24888, upload-time = "2024-10-12T15:23:59.833Z" }, +] + +[[package]] +name = "tinyrecord" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/6c/a004d025d9bf1d8af014fd6f22bd6e49e4533d6fbac54f7d4ae7acb3ef5f/tinyrecord-0.2.0.tar.gz", hash = "sha256:eb6dc23601be359ee00f5a3d31a46adf3bad0a16f8d60af216cd67982ca75cf4", size = 5556, upload-time = "2020-07-05T15:43:06.056Z" } + [[package]] name = "tokenizers" version = "0.21.2"