Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions eval_protocol/dataset_logger/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,45 @@
import os

from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore


# Allow disabling sqlite logger to avoid environment-specific constraints in simple CLI runs.
def _get_default_logger():
if os.getenv("DISABLE_EP_SQLITE_LOG", "0").strip() != "1":
return SqliteDatasetLoggerAdapter()
def get_evaluation_row_store(db_path: str) -> EvaluationRowStore:
    """
    Build the storage backend selected by configuration.

    The ``EP_STORAGE`` environment variable picks the implementation:
    - "tinydb" (default): TinyDB with JSON file storage
    - "sqlite": SQLite with peewee ORM

    Args:
        db_path: Path to the database file

    Returns:
        EvaluationRowStore implementation
    """
    backend = os.getenv("EP_STORAGE", "tinydb").lower()

    # Any value other than "sqlite" falls back to the TinyDB default.
    if backend != "sqlite":
        from eval_protocol.dataset_logger.tinydb_evaluation_row_store import TinyDBEvaluationRowStore

        return TinyDBEvaluationRowStore(db_path)

    from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore

    return SqliteEvaluationRowStore(db_path)


def _get_default_db_filename() -> str:
    """Return the default database filename for the configured backend.

    The SQLite backend stores a binary ``.db`` file; the TinyDB default
    stores plain JSON.
    """
    if os.getenv("EP_STORAGE", "tinydb").lower() == "sqlite":
        return "logs.db"
    return "logs.json"


def _get_default_logger():
"""Get the default logger based on configuration."""
# Allow disabling logger to avoid environment-specific constraints in simple CLI runs.
if os.getenv("DISABLE_EP_SQLITE_LOG", "0").strip() == "1":

class _NoOpLogger(DatasetLogger):
def log(self, row):
Expand All @@ -19,6 +50,11 @@ def read(self, rollout_id=None):

return _NoOpLogger()

# Import here to avoid circular imports
from eval_protocol.dataset_logger.dataset_logger_adapter import DatasetLoggerAdapter

return DatasetLoggerAdapter()


# Lazy property that creates the logger only when accessed
class _LazyLogger(DatasetLogger):
Expand Down
61 changes: 61 additions & 0 deletions eval_protocol/dataset_logger/dataset_logger_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
from typing import TYPE_CHECKING, List, Optional

from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE, DatasetLogger
from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore
from eval_protocol.directory_utils import find_eval_protocol_dir
from eval_protocol.event_bus import event_bus
from eval_protocol.event_bus.logger import logger

if TYPE_CHECKING:
from eval_protocol.models import EvaluationRow


class DatasetLoggerAdapter(DatasetLogger):
    """
    Dataset logger that uses the configured storage backend.

    The storage backend is selected based on the EP_STORAGE environment variable:
    - "tinydb" (default): Uses TinyDB with JSON file storage
    - "sqlite": Uses SQLite with peewee ORM
    """

    def __init__(self, db_path: Optional[str] = None, store: Optional[EvaluationRowStore] = None):
        """Create an adapter backed by an explicit store or a database path.

        Args:
            db_path: Path to the database file. Mutually exclusive with ``store``.
            store: Pre-built EvaluationRowStore. Mutually exclusive with ``db_path``.

        Raises:
            ValueError: If both ``db_path`` and ``store`` are provided.
        """
        eval_protocol_dir = find_eval_protocol_dir()
        if db_path is not None and store is not None:
            raise ValueError("Provide only one of db_path or store, not both.")
        if store is not None:
            self.db_path = store.db_path
            self._store = store
        else:
            # Import here to avoid circular imports
            from eval_protocol.dataset_logger import _get_default_db_filename, get_evaluation_row_store

            default_db = _get_default_db_filename()
            self.db_path = db_path if db_path is not None else os.path.join(eval_protocol_dir, default_db)
            self._store = get_evaluation_row_store(self.db_path)

    def log(self, row: "EvaluationRow") -> None:
        """Persist ``row`` and emit a row-upserted event.

        Event emission is best-effort: failures are logged but never
        propagate, so a bad event-bus subscriber cannot break storage.
        """
        data = row.model_dump(exclude_none=True, mode="json")
        rollout_id = data.get("execution_metadata", {}).get("rollout_id", "unknown")
        logger.debug(f"[EVENT_BUS_EMIT] Starting to log row with rollout_id: {rollout_id}")

        self._store.upsert_row(data=data)
        logger.debug(f"[EVENT_BUS_EMIT] Successfully stored row in database for rollout_id: {rollout_id}")

        try:
            # Import here to avoid circular imports at module load time.
            from eval_protocol.models import EvaluationRow as EvalRow

            logger.debug(f"[EVENT_BUS_EMIT] Emitting event '{LOG_EVENT_TYPE}' for rollout_id: {rollout_id}")
            event_bus.emit(LOG_EVENT_TYPE, EvalRow(**data))
            logger.debug(f"[EVENT_BUS_EMIT] Successfully emitted event for rollout_id: {rollout_id}")
        except Exception as e:
            # Avoid breaking storage due to event emission issues
            # (dead `pass` after this log call removed).
            logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}")

    def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
        """Read rows from the store, optionally filtered by ``rollout_id``.

        Returns:
            List of EvaluationRow models reconstructed from stored JSON.
        """
        from eval_protocol.models import EvaluationRow

        results = self._store.read_rows(rollout_id=rollout_id)
        return [EvaluationRow(**data) for data in results]
63 changes: 63 additions & 0 deletions eval_protocol/dataset_logger/evaluation_row_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from abc import ABC, abstractmethod
from typing import List, Optional


class EvaluationRowStore(ABC):
    """
    Abstract interface for persisting evaluation rows.

    Each row is an arbitrary JSON-serializable payload identified by a
    unique string ``rollout_id``. Concrete subclasses may back this
    contract with any storage engine (SQLite, TinyDB, etc.).
    """

    @property
    @abstractmethod
    def db_path(self) -> str:
        """Path to the underlying database file."""
        ...

    @abstractmethod
    def upsert_row(self, data: dict) -> None:
        """
        Insert a new row, or update the existing one with the same rollout_id.

        Args:
            data: Row payload; must contain ``execution_metadata.rollout_id``.
        """
        ...

    @abstractmethod
    def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
        """
        Return stored rows.

        Args:
            rollout_id: When given, restrict the result to that rollout.

        Returns:
            A list of row payload dictionaries.
        """
        ...

    @abstractmethod
    def delete_row(self, rollout_id: str) -> int:
        """
        Remove the row identified by ``rollout_id``.

        Args:
            rollout_id: Identifier of the row to remove.

        Returns:
            The number of rows deleted.
        """
        ...

    @abstractmethod
    def delete_all_rows(self) -> int:
        """
        Remove every stored row.

        Returns:
            The number of rows deleted.
        """
        ...
50 changes: 9 additions & 41 deletions eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,13 @@
import os
from typing import List, Optional
"""
Backwards-compatible alias for DatasetLoggerAdapter.

from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE, DatasetLogger
from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore
from eval_protocol.directory_utils import find_eval_protocol_dir
from eval_protocol.event_bus import event_bus
from eval_protocol.event_bus.logger import logger
from eval_protocol.models import EvaluationRow
This module is kept for backwards compatibility. New code should use
DatasetLoggerAdapter from dataset_logger_adapter.py instead.
"""

from eval_protocol.dataset_logger.dataset_logger_adapter import DatasetLoggerAdapter

class SqliteDatasetLoggerAdapter(DatasetLogger):
def __init__(self, db_path: Optional[str] = None, store: Optional[SqliteEvaluationRowStore] = None):
eval_protocol_dir = find_eval_protocol_dir()
if db_path is not None and store is not None:
raise ValueError("Provide only one of db_path or store, not both.")
if store is not None:
self.db_path = store.db_path
self._store = store
else:
self.db_path = db_path if db_path is not None else os.path.join(eval_protocol_dir, "logs.db")
self._store = SqliteEvaluationRowStore(self.db_path)
# Backwards-compatible alias
SqliteDatasetLoggerAdapter = DatasetLoggerAdapter

def log(self, row: "EvaluationRow") -> None:
data = row.model_dump(exclude_none=True, mode="json")
rollout_id = data.get("execution_metadata", {}).get("rollout_id", "unknown")
logger.debug(f"[EVENT_BUS_EMIT] Starting to log row with rollout_id: {rollout_id}")

self._store.upsert_row(data=data)
logger.debug(f"[EVENT_BUS_EMIT] Successfully stored row in database for rollout_id: {rollout_id}")

try:
logger.debug(f"[EVENT_BUS_EMIT] Emitting event '{LOG_EVENT_TYPE}' for rollout_id: {rollout_id}")
event_bus.emit(LOG_EVENT_TYPE, EvaluationRow(**data))
logger.debug(f"[EVENT_BUS_EMIT] Successfully emitted event for rollout_id: {rollout_id}")
except Exception as e:
# Avoid breaking storage due to event emission issues
logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}")
pass

def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
from eval_protocol.models import EvaluationRow

results = self._store.read_rows(rollout_id=rollout_id)
return [EvaluationRow(**data) for data in results]
__all__ = ["SqliteDatasetLoggerAdapter"]
18 changes: 13 additions & 5 deletions eval_protocol/dataset_logger/sqlite_evaluation_row_store.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import os
from typing import List, Optional

from peewee import CharField, Model, SqliteDatabase
from playhouse.sqlite_ext import JSONField
try:
from peewee import CharField, Model, SqliteDatabase
from playhouse.sqlite_ext import JSONField
except ImportError:
raise ImportError(
"SQLite storage backend requires 'peewee' package. Install it with: pip install eval-protocol[sqlite_storage]"
)

from eval_protocol.models import EvaluationRow
from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore


class SqliteEvaluationRowStore:
class SqliteEvaluationRowStore(EvaluationRowStore):
"""
Lightweight reusable SQLite store for evaluation rows.

Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
"""

def __init__(self, db_path: str):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
# Handle case where db_path might be in the root directory
db_dir = os.path.dirname(db_path)
if db_dir:
os.makedirs(db_dir, exist_ok=True)
self._db_path = db_path
self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"})

Expand Down
120 changes: 120 additions & 0 deletions eval_protocol/dataset_logger/tinydb_evaluation_row_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import json
import logging
import os
import time
from typing import List, Optional

from tinydb import Query, TinyDB
from tinyrecord.transaction import transaction

from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore

logger = logging.getLogger(__name__)


class TinyDBEvaluationRowStore(EvaluationRowStore):
    """
    TinyDB-based evaluation row store.

    Stores data as plain JSON files, which are human-readable and
    don't suffer from SQLite's binary format corruption issues.

    Uses tinyrecord for atomic transactions to handle concurrent access
    from multiple processes safely.
    """

    def __init__(self, db_path: str):
        """Open (or create) the TinyDB file at ``db_path``.

        Args:
            db_path: Path to the JSON database file.
        """
        # Handle case where db_path might be in the root directory
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self._db_path = db_path
        self._db = self._open_db_with_retry()
        self._table = self._db.table("evaluation_rows")

    def _open_db_with_retry(self, max_retries: int = 3) -> TinyDB:
        """Open TinyDB with retry logic to handle transient JSON decode errors.

        A concurrent writer may leave the file mid-write; back off and retry.
        On the final attempt the corrupted file is removed and recreated,
        sacrificing its contents to restore a usable store.
        """
        last_error: Exception | None = None
        for attempt in range(max_retries):
            try:
                return TinyDB(self._db_path)
            except json.JSONDecodeError as e:
                last_error = e
                logger.warning(f"TinyDB JSON decode error on attempt {attempt + 1}: {e}")
                # Wait a bit and retry - the file might be mid-write
                time.sleep(0.1 * (attempt + 1))
                # Try to recover by removing the corrupted file
                if attempt == max_retries - 1 and os.path.exists(self._db_path):
                    try:
                        logger.warning(f"Removing corrupted TinyDB file: {self._db_path}")
                        os.remove(self._db_path)
                        return TinyDB(self._db_path)
                    except Exception:
                        pass
        raise last_error if last_error else RuntimeError("Failed to open TinyDB")

    @property
    def db_path(self) -> str:
        """Path to the JSON database file."""
        return self._db_path

    def upsert_row(self, data: dict) -> None:
        """Insert or update a row keyed on ``execution_metadata.rollout_id``.

        Raises:
            KeyError: If ``execution_metadata`` or ``rollout_id`` is absent.
            ValueError: If ``rollout_id`` is present but None.
        """
        rollout_id = data["execution_metadata"]["rollout_id"]
        if rollout_id is None:
            raise ValueError("execution_metadata.rollout_id is required to upsert a row")

        Row = Query()
        condition = Row.execution_metadata.rollout_id == rollout_id

        # tinyrecord doesn't support upsert directly, so we implement it manually
        # within a transaction for atomicity
        with transaction(self._table) as tr:
            # Clear cache to ensure fresh read in multi-process scenarios
            self._table.clear_cache()
            # Check if document exists
            existing = self._table.search(condition)
            if existing:
                # Update existing document
                tr.update(data, condition)
            else:
                # Insert new document
                tr.insert(data)

    def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
        """Read rows with retry logic for transient JSON decode errors."""
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Clear cache to ensure fresh read in multi-process scenarios
                self._table.clear_cache()
                if rollout_id is not None:
                    Row = Query()
                    return list(self._table.search(Row.execution_metadata.rollout_id == rollout_id))
                return list(self._table.all())
            except json.JSONDecodeError as e:
                logger.warning(f"TinyDB JSON decode error on read attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(0.1 * (attempt + 1))
                else:
                    # Return empty list on final failure rather than crash
                    logger.warning("Failed to read TinyDB after retries, returning empty list")
                    return []
        return []

    def delete_row(self, rollout_id: str) -> int:
        """Delete the row for ``rollout_id``; return the number deleted."""
        Row = Query()
        condition = Row.execution_metadata.rollout_id == rollout_id

        with transaction(self._table) as tr:
            # Clear cache to ensure fresh read in multi-process scenarios
            self._table.clear_cache()
            # Check if document exists before removal to get accurate count
            existing = self._table.search(condition)
            if existing:
                tr.remove(condition)
                return len(existing)
        return 0

    def delete_all_rows(self) -> int:
        """Delete every row; return the number deleted.

        Clears the cache before counting so the reported count is fresh in
        multi-process scenarios, consistent with the other methods.
        """
        # Clear cache to ensure fresh count in multi-process scenarios
        self._table.clear_cache()
        count = len(self._table)
        self._table.truncate()
        return count
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Missing cache clear before counting rows in delete_all

The delete_all_rows method calls len(self._table) without first calling self._table.clear_cache(). In multi-process scenarios, the cached table data may be stale, causing the returned count to be inaccurate. Other methods like upsert_row, read_rows, and delete_row correctly call clear_cache() before reading from the table, but this method is inconsistent.

Fix in Cursor Fix in Web

Loading
Loading