|
1 | 1 | import os |
2 | 2 | import time |
3 | | -from typing import Any, List |
| 3 | +from typing import Any, Callable, List, TypeVar |
4 | 4 | from uuid import uuid4 |
5 | 5 |
|
6 | | -from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, SqliteDatabase |
| 6 | +import backoff |
| 7 | +from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, OperationalError, SqliteDatabase |
7 | 8 | from playhouse.sqlite_ext import JSONField |
8 | 9 |
|
9 | 10 | from eval_protocol.event_bus.logger import logger |
10 | 11 |
|
11 | 12 |
|
# Retry configuration for database operations that hit SQLite lock contention.
SQLITE_RETRY_MAX_TRIES = 5  # maximum attempts (first try + up to 4 retries)
SQLITE_RETRY_MAX_TIME = 30  # seconds; total wall-clock budget across all retries
| 17 | + |
| 18 | +def _is_database_locked_error(e: Exception) -> bool: |
| 19 | + """Check if an exception is a database locked error.""" |
| 20 | + error_str = str(e).lower() |
| 21 | + return "database is locked" in error_str or "locked" in error_str |
| 22 | + |
| 23 | + |
T = TypeVar("T")


# Decorated ONCE at import time: the original built a fresh
# backoff-decorated closure (and giveup lambda) on every call, which is
# avoidable per-call overhead.  backoff creates a new retry-state
# generator per invocation, so sharing one decorated dispatcher across
# calls is safe.
@backoff.on_exception(
    backoff.expo,
    OperationalError,
    max_tries=SQLITE_RETRY_MAX_TRIES,
    max_time=SQLITE_RETRY_MAX_TIME,
    # Retry only lock contention; re-raise other OperationalErrors immediately.
    giveup=lambda e: not _is_database_locked_error(e),
    jitter=backoff.full_jitter,
)
def _retry_call(operation: Callable[[], T]) -> T:
    return operation()


def execute_with_sqlite_retry(operation: Callable[[], T]) -> T:
    """
    Execute a database operation with exponential backoff retry on lock errors.

    Uses the backoff library for consistent retry behavior across the codebase.
    Retries only on OperationalError whose message indicates lock contention
    (see ``_is_database_locked_error``); other errors propagate unchanged.

    Args:
        operation: A zero-argument callable that performs the database operation.

    Returns:
        The result of the operation.

    Raises:
        OperationalError: If the operation still fails after all retries,
            or immediately for non-lock errors.
    """
    return _retry_call(operation)
| 56 | + |
| 57 | + |
12 | 58 | # SQLite pragmas for hardened concurrency safety |
13 | 59 | SQLITE_HARDENED_PRAGMAS = { |
14 | 60 | "journal_mode": "wal", # Write-Ahead Logging for concurrent reads/writes |
@@ -148,13 +194,15 @@ def publish_event(self, event_type: str, data: Any, process_id: str) -> None: |
148 | 194 | else: |
149 | 195 | serialized_data = data |
150 | 196 |
|
151 | | - self._Event.create( |
152 | | - event_id=str(uuid4()), |
153 | | - event_type=event_type, |
154 | | - data=serialized_data, |
155 | | - timestamp=time.time(), |
156 | | - process_id=process_id, |
157 | | - processed=False, |
| 197 | + execute_with_sqlite_retry( |
| 198 | + lambda: self._Event.create( |
| 199 | + event_id=str(uuid4()), |
| 200 | + event_type=event_type, |
| 201 | + data=serialized_data, |
| 202 | + timestamp=time.time(), |
| 203 | + process_id=process_id, |
| 204 | + processed=False, |
| 205 | + ) |
158 | 206 | ) |
159 | 207 | except Exception as e: |
160 | 208 | logger.warning(f"Failed to publish event to database: {e}") |
@@ -188,14 +236,20 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]: |
188 | 236 | def mark_event_processed(self, event_id: str) -> None: |
189 | 237 | """Mark an event as processed.""" |
190 | 238 | try: |
191 | | - self._Event.update(processed=True).where(self._Event.event_id == event_id).execute() |
| 239 | + execute_with_sqlite_retry( |
| 240 | + lambda: self._Event.update(processed=True).where(self._Event.event_id == event_id).execute() |
| 241 | + ) |
192 | 242 | except Exception as e: |
193 | 243 | logger.debug(f"Failed to mark event as processed: {e}") |
194 | 244 |
|
195 | 245 | def cleanup_old_events(self, max_age_hours: int = 24) -> None: |
196 | 246 | """Clean up old processed events.""" |
197 | 247 | try: |
198 | 248 | cutoff_time = time.time() - (max_age_hours * 3600) |
199 | | - self._Event.delete().where((self._Event.processed) & (self._Event.timestamp < cutoff_time)).execute() |
| 249 | + execute_with_sqlite_retry( |
| 250 | + lambda: self._Event.delete() |
| 251 | + .where((self._Event.processed) & (self._Event.timestamp < cutoff_time)) |
| 252 | + .execute() |
| 253 | + ) |
200 | 254 | except Exception as e: |
201 | 255 | logger.debug(f"Failed to cleanup old events: {e}") |
0 commit comments