diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index 9233e8b7..f6a81e1e 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -1,13 +1,13 @@ import os from typing import List, Optional -from peewee import CharField, DatabaseError, Model, SqliteDatabase +from peewee import CharField, Model, SqliteDatabase from playhouse.sqlite_ext import JSONField from eval_protocol.event_bus.sqlite_event_bus_database import ( SQLITE_HARDENED_PRAGMAS, - DatabaseCorruptedError, check_and_repair_database, + execute_with_sqlite_retry, ) from eval_protocol.models import EvaluationRow @@ -55,7 +55,13 @@ def upsert_row(self, data: dict) -> None: if rollout_id is None: raise ValueError("execution_metadata.rollout_id is required to upsert a row") - with self._db.atomic("EXCLUSIVE"): + execute_with_sqlite_retry(lambda: self._do_upsert(rollout_id, data)) + + def _do_upsert(self, rollout_id: str, data: dict) -> None: + """Internal method to perform the actual upsert within a transaction.""" + # Use IMMEDIATE instead of EXCLUSIVE for better concurrency + # IMMEDIATE acquires a reserved lock immediately but allows concurrent reads + with self._db.atomic("IMMEDIATE"): if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists(): self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute() else: diff --git a/eval_protocol/event_bus/__init__.py b/eval_protocol/event_bus/__init__.py index b51ebac2..00e10c0b 100644 --- a/eval_protocol/event_bus/__init__.py +++ b/eval_protocol/event_bus/__init__.py @@ -4,6 +4,7 @@ from eval_protocol.event_bus.sqlite_event_bus_database import ( DatabaseCorruptedError, check_and_repair_database, + execute_with_sqlite_retry, SQLITE_HARDENED_PRAGMAS, ) diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py 
b/eval_protocol/event_bus/sqlite_event_bus_database.py index f148991d..5086d6e3 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -1,14 +1,60 @@ import os import time -from typing import Any, List +from typing import Any, Callable, List, TypeVar from uuid import uuid4 -from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, SqliteDatabase +import backoff +from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, OperationalError, SqliteDatabase from playhouse.sqlite_ext import JSONField from eval_protocol.event_bus.logger import logger +# Retry configuration for database operations +SQLITE_RETRY_MAX_TRIES = 5 +SQLITE_RETRY_MAX_TIME = 30 # seconds + + +def _is_database_locked_error(e: Exception) -> bool: + """Check if an exception message indicates a SQLite lock (e.g. "database is locked").""" + error_str = str(e).lower() + return "locked" in error_str + + +T = TypeVar("T") + + +def execute_with_sqlite_retry(operation: Callable[[], T]) -> T: + """ + Execute a database operation with exponential backoff retry on lock errors. + + Uses the backoff library for consistent retry behavior across the codebase. + Retries only on OperationalError whose message indicates a lock (e.g. "database is locked").
+ + Args: + operation: A callable that performs the database operation + + Returns: + The result of the operation + + Raises: + OperationalError: If the operation fails after all retries + """ + + @backoff.on_exception( + backoff.expo, + OperationalError, + max_tries=SQLITE_RETRY_MAX_TRIES, + max_time=SQLITE_RETRY_MAX_TIME, + giveup=lambda e: not _is_database_locked_error(e), + jitter=backoff.full_jitter, + ) + def _execute() -> T: + return operation() + + return _execute() + + # SQLite pragmas for hardened concurrency safety SQLITE_HARDENED_PRAGMAS = { "journal_mode": "wal", # Write-Ahead Logging for concurrent reads/writes @@ -148,13 +194,15 @@ def publish_event(self, event_type: str, data: Any, process_id: str) -> None: else: serialized_data = data - self._Event.create( - event_id=str(uuid4()), - event_type=event_type, - data=serialized_data, - timestamp=time.time(), - process_id=process_id, - processed=False, + execute_with_sqlite_retry( + lambda: self._Event.create( + event_id=str(uuid4()), + event_type=event_type, + data=serialized_data, + timestamp=time.time(), + process_id=process_id, + processed=False, + ) ) except Exception as e: logger.warning(f"Failed to publish event to database: {e}") @@ -188,7 +236,9 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]: def mark_event_processed(self, event_id: str) -> None: """Mark an event as processed.""" try: - self._Event.update(processed=True).where(self._Event.event_id == event_id).execute() + execute_with_sqlite_retry( + lambda: self._Event.update(processed=True).where(self._Event.event_id == event_id).execute() + ) except Exception as e: logger.debug(f"Failed to mark event as processed: {e}") @@ -196,6 +246,10 @@ def cleanup_old_events(self, max_age_hours: int = 24) -> None: """Clean up old processed events.""" try: cutoff_time = time.time() - (max_age_hours * 3600) - self._Event.delete().where((self._Event.processed) & (self._Event.timestamp < cutoff_time)).execute() + 
execute_with_sqlite_retry( + lambda: self._Event.delete() + .where((self._Event.processed) & (self._Event.timestamp < cutoff_time)) + .execute() + ) except Exception as e: logger.debug(f"Failed to cleanup old events: {e}")