-
Notifications
You must be signed in to change notification settings - Fork 16
Replace SQLite with TinyDB #366
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9b5e353
3ffc83a
6c76ad2
fd113c4
4ae1242
2afaa53
46da8d3
11221fb
ae7a2d3
dbab765
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| import os | ||
| from typing import TYPE_CHECKING, List, Optional | ||
|
|
||
| from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE, DatasetLogger | ||
| from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore | ||
| from eval_protocol.directory_utils import find_eval_protocol_dir | ||
| from eval_protocol.event_bus import event_bus | ||
| from eval_protocol.event_bus.logger import logger | ||
|
|
||
| if TYPE_CHECKING: | ||
| from eval_protocol.models import EvaluationRow | ||
|
|
||
|
|
||
class DatasetLoggerAdapter(DatasetLogger):
    """
    Dataset logger that uses the configured storage backend.

    The storage backend is selected based on the EP_STORAGE environment variable:
    - "tinydb" (default): Uses TinyDB with JSON file storage
    - "sqlite": Uses SQLite with peewee ORM
    """

    def __init__(self, db_path: Optional[str] = None, store: Optional[EvaluationRowStore] = None):
        """
        Args:
            db_path: Path to the database file. Mutually exclusive with `store`.
            store: Pre-constructed row store. Mutually exclusive with `db_path`.

        Raises:
            ValueError: If both `db_path` and `store` are provided.
        """
        if db_path is not None and store is not None:
            raise ValueError("Provide only one of db_path or store, not both.")
        if store is not None:
            self.db_path = store.db_path
            self._store = store
        else:
            # Import here to avoid circular imports
            from eval_protocol.dataset_logger import _get_default_db_filename, get_evaluation_row_store

            if db_path is not None:
                self.db_path = db_path
            else:
                # Only resolve the default directory when we actually need it
                # (no explicit store or path was supplied).
                eval_protocol_dir = find_eval_protocol_dir()
                self.db_path = os.path.join(eval_protocol_dir, _get_default_db_filename())
            self._store = get_evaluation_row_store(self.db_path)

    def log(self, row: "EvaluationRow") -> None:
        """Persist `row` in the store, then emit a LOG_EVENT_TYPE event.

        Event-bus emission failures are logged and swallowed so that storage
        always succeeds independently of event subscribers.
        """
        data = row.model_dump(exclude_none=True, mode="json")
        rollout_id = data.get("execution_metadata", {}).get("rollout_id", "unknown")
        logger.debug(f"[EVENT_BUS_EMIT] Starting to log row with rollout_id: {rollout_id}")

        self._store.upsert_row(data=data)
        logger.debug(f"[EVENT_BUS_EMIT] Successfully stored row in database for rollout_id: {rollout_id}")

        try:
            # Imported lazily to avoid a circular import at module load time.
            from eval_protocol.models import EvaluationRow as EvalRow

            logger.debug(f"[EVENT_BUS_EMIT] Emitting event '{LOG_EVENT_TYPE}' for rollout_id: {rollout_id}")
            event_bus.emit(LOG_EVENT_TYPE, EvalRow(**data))
            logger.debug(f"[EVENT_BUS_EMIT] Successfully emitted event for rollout_id: {rollout_id}")
        except Exception as e:
            # Avoid breaking storage due to event emission issues
            logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}")

    def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
        """Read rows from the store, optionally filtered by `rollout_id`."""
        # Imported lazily to avoid a circular import at module load time.
        from eval_protocol.models import EvaluationRow

        results = self._store.read_rows(rollout_id=rollout_id)
        return [EvaluationRow(**data) for data in results]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| from abc import ABC, abstractmethod | ||
| from typing import List, Optional | ||
|
|
||
|
|
||
class EvaluationRowStore(ABC):
    """
    Abstract base class for evaluation row storage.

    Rows are arbitrary JSON-serializable dicts keyed by the unique string
    ``execution_metadata.rollout_id``. Concrete subclasses may back this
    interface with different storage engines (SQLite, TinyDB, etc.).
    """

    @property
    @abstractmethod
    def db_path(self) -> str:
        """Path to the underlying database file."""

    @abstractmethod
    def upsert_row(self, data: dict) -> None:
        """Insert a new row, or replace the existing row with the same rollout_id.

        Args:
            data: Row payload; must contain ``execution_metadata.rollout_id``.
        """

    @abstractmethod
    def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
        """Return stored rows.

        Args:
            rollout_id: When given, restrict the result to that rollout.

        Returns:
            A list of row payload dictionaries.
        """

    @abstractmethod
    def delete_row(self, rollout_id: str) -> int:
        """Remove the row with the given rollout_id.

        Args:
            rollout_id: Identifier of the row to remove.

        Returns:
            The number of rows removed.
        """

    @abstractmethod
    def delete_all_rows(self) -> int:
        """Remove every stored row.

        Returns:
            The number of rows removed.
        """
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,45 +1,13 @@ | ||
"""
Backwards-compatible alias for DatasetLoggerAdapter.

This module is kept for backwards compatibility. New code should use
DatasetLoggerAdapter from dataset_logger_adapter.py instead.
"""

from eval_protocol.dataset_logger.dataset_logger_adapter import DatasetLoggerAdapter

# Backwards-compatible alias
SqliteDatasetLoggerAdapter = DatasetLoggerAdapter

__all__ = ["SqliteDatasetLoggerAdapter"]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| import json | ||
| import logging | ||
| import os | ||
| import time | ||
| from typing import List, Optional | ||
|
|
||
| from tinydb import Query, TinyDB | ||
| from tinyrecord.transaction import transaction | ||
|
|
||
| from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class TinyDBEvaluationRowStore(EvaluationRowStore):
    """
    TinyDB-based evaluation row store.

    Stores data as plain JSON files, which are human-readable and
    don't suffer from SQLite's binary format corruption issues.

    Uses tinyrecord for atomic transactions to handle concurrent access
    from multiple processes safely.
    """

    def __init__(self, db_path: str):
        """
        Args:
            db_path: Path to the JSON database file. Missing parent
                directories are created.
        """
        # Handle case where db_path might be in the root directory
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self._db_path = db_path
        self._db = self._open_db_with_retry()
        self._table = self._db.table("evaluation_rows")

    def _open_db_with_retry(self, max_retries: int = 3) -> TinyDB:
        """Open TinyDB with retry logic to handle transient JSON decode errors.

        On the final attempt the (presumed corrupted) file is deleted and
        recreated; if even that fails, the last decode error is re-raised.
        """
        last_error: Exception | None = None
        for attempt in range(max_retries):
            try:
                return TinyDB(self._db_path)
            except json.JSONDecodeError as e:
                last_error = e
                logger.warning(f"TinyDB JSON decode error on attempt {attempt + 1}: {e}")
                # Wait a bit and retry - the file might be mid-write
                time.sleep(0.1 * (attempt + 1))
                # Try to recover by removing the corrupted file
                if attempt == max_retries - 1 and os.path.exists(self._db_path):
                    try:
                        logger.warning(f"Removing corrupted TinyDB file: {self._db_path}")
                        os.remove(self._db_path)
                        return TinyDB(self._db_path)
                    except Exception:
                        pass
        raise last_error if last_error else RuntimeError("Failed to open TinyDB")

    @property
    def db_path(self) -> str:
        """Path to the JSON database file."""
        return self._db_path

    def upsert_row(self, data: dict) -> None:
        """Insert or update a row keyed by execution_metadata.rollout_id.

        Args:
            data: Row payload containing ``execution_metadata.rollout_id``.

        Raises:
            ValueError: If execution_metadata.rollout_id is missing or None.
        """
        # Use .get() so a missing key raises the intended ValueError below
        # rather than an opaque KeyError.
        rollout_id = data.get("execution_metadata", {}).get("rollout_id")
        if rollout_id is None:
            raise ValueError("execution_metadata.rollout_id is required to upsert a row")

        Row = Query()
        condition = Row.execution_metadata.rollout_id == rollout_id

        # tinyrecord doesn't support upsert directly, so we implement it manually
        # within a transaction for atomicity
        with transaction(self._table) as tr:
            # Clear cache to ensure fresh read in multi-process scenarios
            self._table.clear_cache()
            existing = self._table.search(condition)
            if existing:
                # Update existing document
                tr.update(data, condition)
            else:
                # Insert new document
                tr.insert(data)

    def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
        """Read rows with retry logic for transient JSON decode errors.

        Args:
            rollout_id: If provided, filter to this specific rollout.

        Returns:
            List of row data dictionaries; empty list if all read attempts
            fail with JSON decode errors.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Clear cache to ensure fresh read in multi-process scenarios
                self._table.clear_cache()
                if rollout_id is not None:
                    Row = Query()
                    return list(self._table.search(Row.execution_metadata.rollout_id == rollout_id))
                return list(self._table.all())
            except json.JSONDecodeError as e:
                logger.warning(f"TinyDB JSON decode error on read attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(0.1 * (attempt + 1))
        # All retries exhausted: degrade gracefully rather than crash.
        logger.warning("Failed to read TinyDB after retries, returning empty list")
        return []

    def delete_row(self, rollout_id: str) -> int:
        """Delete a row by rollout_id.

        Args:
            rollout_id: The rollout_id to delete.

        Returns:
            Number of rows deleted (0 if no match).
        """
        Row = Query()
        condition = Row.execution_metadata.rollout_id == rollout_id

        with transaction(self._table) as tr:
            # Clear cache to ensure fresh read in multi-process scenarios
            self._table.clear_cache()
            # Check if document exists before removal to get accurate count
            existing = self._table.search(condition)
            if existing:
                tr.remove(condition)
                return len(existing)
            return 0

    def delete_all_rows(self) -> int:
        """Delete all rows.

        Returns:
            Number of rows deleted.
        """
        # Clear cache first so the count reflects writes from other
        # processes, not a stale in-memory snapshot.
        self._table.clear_cache()
        count = len(self._table)
        self._table.truncate()
        return count
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Missing cache clear before counting rows in delete_allThe |
||
Uh oh!
There was an error while loading. Please reload this page.