From 50ea6db3b32a92ebdee728b08c9088984d3a937c Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Sat, 13 Dec 2025 01:18:05 -0800 Subject: [PATCH 1/4] add tests to pyright --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a43f773a..60ed1346 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,8 +217,8 @@ combine-as-imports = true [tool.pyright] typeCheckingMode = "basic" # Changed from "standard" to reduce memory usage pythonVersion = "3.10" -include = ["eval_protocol"] # Reduced scope to just the main package -exclude = ["vite-app", "vendor", "examples", "tests", "development", "local_evals"] +include = ["eval_protocol", "tests"] # Reduced scope to just the main package +exclude = ["vite-app", "vendor", "examples", "development", "local_evals"] # Ignore diagnostics for vendored generator code ignore = ["versioneer.py"] reportUnusedCallResult = "none" From 64025a4d73bcec8acff447bc0386c20df80ee521 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Sat, 13 Dec 2025 01:22:12 -0800 Subject: [PATCH 2/4] sqlite hardening --- eval_protocol/agent/resources/sql_resource.py | 27 +- eval_protocol/cli_commands/logs.py | 120 ++++- .../sqlite_evaluation_row_store.py | 21 +- eval_protocol/event_bus/__init__.py | 5 + .../event_bus/sqlite_event_bus_database.py | 115 ++++- tests/test_sqlite_hardening.py | 417 ++++++++++++++++++ 6 files changed, 680 insertions(+), 25 deletions(-) create mode 100644 tests/test_sqlite_hardening.py diff --git a/eval_protocol/agent/resources/sql_resource.py b/eval_protocol/agent/resources/sql_resource.py index c98582b7..ed8c0360 100644 --- a/eval_protocol/agent/resources/sql_resource.py +++ b/eval_protocol/agent/resources/sql_resource.py @@ -12,6 +12,21 @@ from ..resource_abc import ForkableResource +# SQLite connection settings for hardened concurrency safety +SQLITE_CONNECTION_TIMEOUT = 30 # 30 seconds + + +def _apply_hardened_pragmas(conn: sqlite3.Connection) -> None: + """Apply hardened SQLite pragmas for concurrency safety.""" + conn.execute("PRAGMA journal_mode=WAL") # Write-Ahead Logging + conn.execute("PRAGMA synchronous=NORMAL") # Balance safety and performance + conn.execute("PRAGMA busy_timeout=30000") # 30 second timeout + conn.execute("PRAGMA wal_autocheckpoint=1000") # Checkpoint every 1000 pages + conn.execute("PRAGMA cache_size=-64000") # 64MB cache + conn.execute("PRAGMA foreign_keys=ON") # Enable foreign key constraints + conn.execute("PRAGMA temp_store=MEMORY") # Store temp tables in memory + + class SQLResource(ForkableResource): """ A ForkableResource for managing SQL database states, primarily SQLite. @@ -20,6 +35,8 @@ class SQLResource(ForkableResource): and seed data, forked (by copying the DB file), checkpointed (by copying), and restored. + Uses hardened SQLite settings for concurrency safety. + Attributes: _config (Dict[str, Any]): Configuration for the resource. _db_path (Optional[Path]): Path to the current SQLite database file. @@ -38,8 +55,14 @@ def __init__(self) -> None: def _get_db_connection(self) -> sqlite3.Connection: if not self._db_path: raise ConnectionError("Database path not set. Call setup() or fork() first.") - # Set timeout to prevent indefinite hangs - return sqlite3.connect(str(self._db_path), timeout=10) + # Set timeout to prevent indefinite hangs with hardened settings + conn = sqlite3.connect( + str(self._db_path), + timeout=SQLITE_CONNECTION_TIMEOUT, + isolation_level="DEFERRED", # Better for concurrent access + ) + _apply_hardened_pragmas(conn) + return conn async def setup(self, config: Dict[str, Any]) -> None: """ diff --git a/eval_protocol/cli_commands/logs.py b/eval_protocol/cli_commands/logs.py index 89929a82..a952b789 100644 --- a/eval_protocol/cli_commands/logs.py +++ b/eval_protocol/cli_commands/logs.py @@ -7,6 +7,82 @@ import os from ..utils.logs_server import serve_logs +from ..event_bus.sqlite_event_bus_database import DatabaseCorruptedError, _backup_and_remove_database + + +def _handle_database_corruption(db_path: str) -> bool: + """ + Handle database corruption by prompting user to fix it. + + Args: + db_path: Path to the corrupted database + + Returns: + True if user chose to fix and database was reset, False otherwise + """ + print("\n" + "=" * 60) + print("⚠️ DATABASE CORRUPTION DETECTED") + print("=" * 60) + print(f"\nThe database file at:\n {db_path}\n") + print("appears to be corrupted or is not a valid SQLite database.") + print("\nThis can happen due to:") + print(" • Incomplete writes during a crash") + print(" • Concurrent access issues") + print(" • File system errors") + print("\n" + "-" * 60) + print("Would you like to automatically fix this?") + print(" • The corrupted file will be backed up") + print(" • A fresh database will be created") + print(" • You will lose existing log data, but can continue using the tool") + print("-" * 60) + + try: + response = input("\nFix database automatically? [Y/n]: ").strip().lower() + if response in ("", "y", "yes"): + _backup_and_remove_database(db_path) + print("\n✅ Database has been reset. Restarting server...") + return True + else: + print("\n❌ Database repair cancelled.") + print(f" You can manually delete the corrupted file: {db_path}") + return False + except (EOFError, KeyboardInterrupt): + print("\n❌ Database repair cancelled.") + return False + + +def _is_database_corruption_error(error: Exception) -> tuple[bool, str]: + """ + Check if an exception is related to database corruption. + + Returns: + Tuple of (is_corruption_error, db_path) + """ + error_str = str(error).lower() + corruption_indicators = [ + "file is not a database", + "database disk image is malformed", + "database is locked", + "unable to open database file", + ] + + for indicator in corruption_indicators: + if indicator in error_str: + # Try to find the database path + from ..directory_utils import find_eval_protocol_dir + + try: + eval_protocol_dir = find_eval_protocol_dir() + db_path = os.path.join(eval_protocol_dir, "logs.db") + return True, db_path + except Exception: + return True, "" + + # Check if it's a DatabaseCorruptedError + if isinstance(error, DatabaseCorruptedError): + return True, error.db_path + + return False, "" def logs_command(args): @@ -40,18 +116,32 @@ def logs_command(args): or "https://tracing.fireworks.ai" ) - try: - serve_logs( - port=args.port, - elasticsearch_config=elasticsearch_config, - debug=args.debug, - backend="fireworks" if use_fireworks else "elasticsearch", - fireworks_base_url=fireworks_base_url if use_fireworks else None, - ) - return 0 - except KeyboardInterrupt: - print("\n🛑 Server stopped by user") - return 0 - except Exception as e: - print(f"❌ Error starting server: {e}") - return 1 + max_retries = 2 + for attempt in range(max_retries): + try: + serve_logs( + port=args.port, + elasticsearch_config=elasticsearch_config, + debug=args.debug, + backend="fireworks" if use_fireworks else "elasticsearch", + fireworks_base_url=fireworks_base_url if use_fireworks else None, + ) + return 0 + except KeyboardInterrupt: + print("\n🛑 Server stopped by user") + return 0 + except Exception as e: + is_corruption, db_path = _is_database_corruption_error(e) + + if is_corruption and db_path and attempt < max_retries - 1: + if _handle_database_corruption(db_path): + # User chose to fix, retry + continue + else: + # User declined fix + return 1 + + print(f"❌ Error starting server: {e}") + return 1 + + return 1 diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index a8e7b229..9233e8b7 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -1,9 +1,14 @@ import os from typing import List, Optional -from peewee import CharField, Model, SqliteDatabase +from peewee import CharField, DatabaseError, Model, SqliteDatabase from playhouse.sqlite_ext import JSONField +from eval_protocol.event_bus.sqlite_event_bus_database import ( + SQLITE_HARDENED_PRAGMAS, + DatabaseCorruptedError, + check_and_repair_database, +) from eval_protocol.models import EvaluationRow @@ -12,12 +17,20 @@ class SqliteEvaluationRowStore: Lightweight reusable SQLite store for evaluation rows. Stores arbitrary row data as JSON keyed by a unique string `rollout_id`. + Uses hardened SQLite settings for concurrency safety. """ - def __init__(self, db_path: str): - os.makedirs(os.path.dirname(db_path), exist_ok=True) + def __init__(self, db_path: str, auto_repair: bool = True): + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) self._db_path = db_path - self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"}) + + # Check and optionally repair corrupted database + check_and_repair_database(db_path, auto_repair=auto_repair) + + # Use hardened pragmas for concurrency safety + self._db = SqliteDatabase(self._db_path, pragmas=SQLITE_HARDENED_PRAGMAS) class BaseModel(Model): class Meta: diff --git a/eval_protocol/event_bus/__init__.py b/eval_protocol/event_bus/__init__.py index 86e572a9..b51ebac2 100644 --- a/eval_protocol/event_bus/__init__.py +++ b/eval_protocol/event_bus/__init__.py @@ -1,6 +1,11 @@ # Global event bus instance - uses SqliteEventBus for cross-process functionality from typing import Any, Callable from eval_protocol.event_bus.event_bus import EventBus +from eval_protocol.event_bus.sqlite_event_bus_database import ( + DatabaseCorruptedError, + check_and_repair_database, + SQLITE_HARDENED_PRAGMAS, +) def _get_default_event_bus(): diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index 5d1f522a..f7d9bb5b 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -1,19 +1,125 @@ +import os import time from typing import Any, List from uuid import uuid4 -from peewee import BooleanField, CharField, DateTimeField, Model, SqliteDatabase +from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, SqliteDatabase from playhouse.sqlite_ext import JSONField from eval_protocol.event_bus.logger import logger +# SQLite pragmas for hardened concurrency safety +SQLITE_HARDENED_PRAGMAS = { + "journal_mode": "wal", # Write-Ahead Logging for concurrent reads/writes + "synchronous": "normal", # Balance between safety and performance + "busy_timeout": 30000, # 30 second timeout for locked database + "wal_autocheckpoint": 1000, # Checkpoint every 1000 pages + "cache_size": -64000, # 64MB cache (negative = KB) + "foreign_keys": 1, # Enable foreign key constraints + "temp_store": "memory", # Store temp tables in memory +} + + +class DatabaseCorruptedError(Exception): + """Raised when the database file is corrupted or not a valid SQLite database.""" + + def __init__(self, db_path: str, original_error: Exception): + self.db_path = db_path + self.original_error = original_error + super().__init__(f"Database file is corrupted: {db_path}. Original error: {original_error}") + + +def check_and_repair_database(db_path: str, auto_repair: bool = False) -> bool: + """ + Check if a database file is valid and optionally repair it. + + Args: + db_path: Path to the database file + auto_repair: If True, automatically delete and recreate corrupted database + + Returns: + True if database is valid or was repaired, False otherwise + + Raises: + DatabaseCorruptedError: If database is corrupted and auto_repair is False + """ + if not os.path.exists(db_path): + return True # New database, nothing to check + + try: + # Try to open the database and run an integrity check + test_db = SqliteDatabase(db_path, pragmas={"busy_timeout": 5000}) + test_db.connect() + cursor = test_db.execute_sql("PRAGMA integrity_check") + result = cursor.fetchone() + test_db.close() + + if result and result[0] == "ok": + return True + else: + logger.warning(f"Database integrity check failed for {db_path}: {result}") + if auto_repair: + _backup_and_remove_database(db_path) + return True + raise DatabaseCorruptedError(db_path, Exception(f"Integrity check failed: {result}")) + + except DatabaseError as e: + error_str = str(e).lower() + if "file is not a database" in error_str or "database disk image is malformed" in error_str: + logger.warning(f"Database file is corrupted: {db_path}") + if auto_repair: + _backup_and_remove_database(db_path) + return True + raise DatabaseCorruptedError(db_path, e) + raise + except Exception as e: + logger.warning(f"Error checking database {db_path}: {e}") + if auto_repair: + _backup_and_remove_database(db_path) + return True + raise DatabaseCorruptedError(db_path, e) + + +def _backup_and_remove_database(db_path: str) -> None: + """Backup a corrupted database file and remove it.""" + backup_path = f"{db_path}.corrupted.{int(time.time())}" + try: + os.rename(db_path, backup_path) + logger.info(f"Backed up corrupted database to: {backup_path}") + except OSError as e: + logger.warning(f"Failed to backup corrupted database, removing: {e}") + try: + os.remove(db_path) + except OSError: + pass + + # Also try to remove WAL and SHM files if they exist + for suffix in ["-wal", "-shm"]: + wal_file = f"{db_path}{suffix}" + if os.path.exists(wal_file): + try: + os.remove(wal_file) + except OSError: + pass + + class SqliteEventBusDatabase: """SQLite database for cross-process event communication.""" - def __init__(self, db_path: str): + def __init__(self, db_path: str, auto_repair: bool = True): self._db_path = db_path - self._db = SqliteDatabase(db_path) + + # Ensure directory exists + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + + # Check and optionally repair corrupted database + check_and_repair_database(db_path, auto_repair=auto_repair) + + # Initialize database with hardened concurrency settings + self._db = SqliteDatabase(db_path, pragmas=SQLITE_HARDENED_PRAGMAS) class BaseModel(Model): class Meta: @@ -29,7 +135,8 @@ class Event(BaseModel): # type: ignore self._Event = Event self._db.connect() - self._db.create_tables([Event]) + # Use safe=True to avoid errors when tables already exist + self._db.create_tables([Event], safe=True) def publish_event(self, event_type: str, data: Any, process_id: str) -> None: """Publish an event to the database.""" diff --git a/tests/test_sqlite_hardening.py b/tests/test_sqlite_hardening.py new file mode 100644 index 00000000..301240ff --- /dev/null +++ b/tests/test_sqlite_hardening.py @@ -0,0 +1,417 @@ +""" +Tests for SQLite hardening and concurrency safety. + +These tests verify that: +1. WAL mode and other concurrency pragmas are correctly applied +2. Database corruption detection works +3. Auto-repair functionality works +4. Multiple concurrent operations don't corrupt the database +""" + +import os +import sqlite3 +import tempfile +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import List + +import pytest + +from eval_protocol.event_bus.sqlite_event_bus_database import ( + SQLITE_HARDENED_PRAGMAS, + DatabaseCorruptedError, + SqliteEventBusDatabase, + _backup_and_remove_database, + check_and_repair_database, +) +from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore + + +class TestSqliteHardenedPragmas: + """Test that hardened pragmas are correctly defined and applied.""" + + def test_pragmas_are_defined(self): + """Test that all required pragmas are defined.""" + required_pragmas = [ + "journal_mode", + "synchronous", + "busy_timeout", + "wal_autocheckpoint", + "cache_size", + "foreign_keys", + "temp_store", + ] + for pragma in required_pragmas: + assert pragma in SQLITE_HARDENED_PRAGMAS, f"Missing pragma: {pragma}" + + def test_wal_mode_is_enabled(self): + """Test that WAL mode is set in pragmas.""" + assert SQLITE_HARDENED_PRAGMAS["journal_mode"] == "wal" + + def test_busy_timeout_is_set(self): + """Test that busy_timeout is set to a reasonable value.""" + # Should be at least 10 seconds (10000ms) + assert SQLITE_HARDENED_PRAGMAS["busy_timeout"] >= 10000 + + +class TestSqliteEventBusDatabaseHardening: + """Test SqliteEventBusDatabase hardening features.""" + + def test_creates_database_with_wal_mode(self): + """Test that database is created with WAL journal mode.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + db = SqliteEventBusDatabase(db_path) + + cursor = db._db.execute_sql("PRAGMA journal_mode") + journal_mode = cursor.fetchone()[0] + assert journal_mode == "wal", f"Expected WAL mode, got {journal_mode}" + + def test_creates_database_with_busy_timeout(self): + """Test that database has busy_timeout set.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + db = SqliteEventBusDatabase(db_path) + + cursor = db._db.execute_sql("PRAGMA busy_timeout") + timeout = cursor.fetchone()[0] + assert timeout == SQLITE_HARDENED_PRAGMAS["busy_timeout"] + + def test_creates_database_with_synchronous_normal(self): + """Test that synchronous mode is set to normal.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + db = SqliteEventBusDatabase(db_path) + + cursor = db._db.execute_sql("PRAGMA synchronous") + sync_mode = cursor.fetchone()[0] + # 1 = NORMAL in SQLite + assert sync_mode == 1, f"Expected synchronous=1 (NORMAL), got {sync_mode}" + + def test_creates_directory_if_not_exists(self): + """Test that parent directory is created if it doesn't exist.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "subdir", "nested", "test.db") + db = SqliteEventBusDatabase(db_path) + + assert os.path.exists(db_path) + assert os.path.isfile(db_path) + + +class TestSqliteEvaluationRowStoreHardening: + """Test SqliteEvaluationRowStore hardening features.""" + + def test_creates_database_with_wal_mode(self): + """Test that database is created with WAL journal mode.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "eval.db") + store = SqliteEvaluationRowStore(db_path) + + cursor = store._db.execute_sql("PRAGMA journal_mode") + journal_mode = cursor.fetchone()[0] + assert journal_mode == "wal", f"Expected WAL mode, got {journal_mode}" + + def test_creates_database_with_busy_timeout(self): + """Test that database has busy_timeout set.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "eval.db") + store = SqliteEvaluationRowStore(db_path) + + cursor = store._db.execute_sql("PRAGMA busy_timeout") + timeout = cursor.fetchone()[0] + assert timeout == SQLITE_HARDENED_PRAGMAS["busy_timeout"] + + +class TestDatabaseCorruptionDetection: + """Test database corruption detection functionality.""" + + def test_nonexistent_database_passes_check(self): + """Test that check passes for a non-existent database.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "nonexistent.db") + result = check_and_repair_database(db_path) + assert result is True + + def test_valid_database_passes_check(self): + """Test that a valid database passes the integrity check.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "valid.db") + + # Create a valid database + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, data TEXT)") + conn.execute("INSERT INTO test VALUES (1, 'test')") + conn.commit() + conn.close() + + result = check_and_repair_database(db_path) + assert result is True + + def test_corrupted_file_raises_error_without_auto_repair(self): + """Test that a corrupted file raises DatabaseCorruptedError when auto_repair=False.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "corrupted.db") + + # Create a corrupted file (not a valid SQLite database) + with open(db_path, "w") as f: + f.write("This is not a valid SQLite database!") + + with pytest.raises(DatabaseCorruptedError) as exc_info: + check_and_repair_database(db_path, auto_repair=False) + + assert exc_info.value.db_path == db_path + + def test_corrupted_file_auto_repaired(self): + """Test that a corrupted file is auto-repaired when auto_repair=True.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "corrupted.db") + + # Create a corrupted file + with open(db_path, "w") as f: + f.write("This is not a valid SQLite database!") + + result = check_and_repair_database(db_path, auto_repair=True) + assert result is True + + # Original file should be removed (or renamed to backup) + assert not os.path.exists(db_path) + + def test_corrupted_file_backup_created(self): + """Test that a backup is created when auto-repairing a corrupted file.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "corrupted.db") + + # Create a corrupted file + with open(db_path, "w") as f: + f.write("This is not a valid SQLite database!") + + check_and_repair_database(db_path, auto_repair=True) + + # Check for backup file + files = os.listdir(tmpdir) + backup_files = [f for f in files if "corrupted" in f and f != "corrupted.db"] + assert len(backup_files) == 1 + assert "corrupted" in backup_files[0] + + +class TestBackupAndRemoveDatabase: + """Test the backup and remove database functionality.""" + + def test_backup_creates_timestamped_file(self): + """Test that backup creates a timestamped backup file.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + + # Create a file + with open(db_path, "w") as f: + f.write("test content") + + _backup_and_remove_database(db_path) + + # Original should be gone + assert not os.path.exists(db_path) + + # Backup should exist with timestamp + files = os.listdir(tmpdir) + assert len(files) == 1 + assert files[0].startswith("test.db.corrupted.") + + def test_removes_wal_and_shm_files(self): + """Test that WAL and SHM files are also removed.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + + # Create main db and WAL/SHM files + for suffix in ["", "-wal", "-shm"]: + with open(f"{db_path}{suffix}", "w") as f: + f.write("test") + + _backup_and_remove_database(db_path) + + # WAL and SHM should be removed + assert not os.path.exists(f"{db_path}-wal") + assert not os.path.exists(f"{db_path}-shm") + + +class TestDatabaseAutoRepairOnInit: + """Test that databases are auto-repaired on initialization.""" + + def test_event_bus_database_auto_repairs(self): + """Test that SqliteEventBusDatabase auto-repairs corrupted database.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "events.db") + + # Create a corrupted file + with open(db_path, "w") as f: + f.write("corrupted!") + + # Should not raise, should auto-repair + db = SqliteEventBusDatabase(db_path, auto_repair=True) + + # Should be usable + db.publish_event("test", {"data": "test"}, "test-process") + + def test_evaluation_row_store_auto_repairs(self): + """Test that SqliteEvaluationRowStore auto-repairs corrupted database.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "eval.db") + + # Create a corrupted file + with open(db_path, "w") as f: + f.write("corrupted!") + + # Should not raise, should auto-repair + store = SqliteEvaluationRowStore(db_path, auto_repair=True) + + # Should be usable - verify by checking that it has the expected table + cursor = store._db.execute_sql( + "SELECT name FROM sqlite_master WHERE type='table' AND name='evaluationrow'" + ) + result = cursor.fetchone() + assert result is not None + + +class TestConcurrencySafety: + """Test concurrent access to SQLite databases.""" + + def test_concurrent_writes_to_event_bus(self): + """Test that concurrent writes to event bus don't fail.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "concurrent.db") + db = SqliteEventBusDatabase(db_path) + + errors: List[Exception] = [] + num_threads = 10 + events_per_thread = 50 + + def write_events(thread_id: int): + try: + for i in range(events_per_thread): + db.publish_event( + f"event_{thread_id}_{i}", + {"thread": thread_id, "index": i}, + f"process_{thread_id}", + ) + except Exception as e: + errors.append(e) + + # Run concurrent writes + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(write_events, i) for i in range(num_threads)] + for future in futures: + future.result() + + # Should have no errors + assert len(errors) == 0, f"Concurrent write errors: {errors}" + + def test_concurrent_upserts_to_evaluation_store(self): + """Test that concurrent upserts to evaluation store don't fail.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "concurrent_eval.db") + store = SqliteEvaluationRowStore(db_path) + + errors: List[Exception] = [] + num_threads = 10 + upserts_per_thread = 20 + + def upsert_rows(thread_id: int): + try: + for i in range(upserts_per_thread): + rollout_id = f"rollout_{thread_id}_{i}" + data = { + "execution_metadata": {"rollout_id": rollout_id}, + "data": {"thread": thread_id, "index": i}, + } + store.upsert_row(data) + except Exception as e: + errors.append(e) + + # Run concurrent upserts + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(upsert_rows, i) for i in range(num_threads)] + for future in futures: + future.result() + + # Should have no errors + assert len(errors) == 0, f"Concurrent upsert errors: {errors}" + + # Verify all rows were written + all_rows = store.read_rows() + assert len(all_rows) == num_threads * upserts_per_thread + + def test_concurrent_reads_and_writes(self): + """Test that concurrent reads and writes work correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "rw_concurrent.db") + store = SqliteEvaluationRowStore(db_path) + + # Pre-populate with some data + for i in range(10): + store.upsert_row( + { + "execution_metadata": {"rollout_id": f"initial_{i}"}, + "data": {"initial": True}, + } + ) + + errors: List[Exception] = [] + read_counts: List[int] = [] + + def writer(): + try: + for i in range(50): + store.upsert_row( + { + "execution_metadata": {"rollout_id": f"write_{i}"}, + "data": {"written": True}, + } + ) + time.sleep(0.001) + except Exception as e: + errors.append(e) + + def reader(): + try: + for _ in range(100): + rows = store.read_rows() + read_counts.append(len(rows)) + time.sleep(0.001) + except Exception as e: + errors.append(e) + + # Run concurrent reads and writes + threads = [ + threading.Thread(target=writer), + threading.Thread(target=reader), + threading.Thread(target=reader), + ] + + for t in threads: + t.start() + for t in threads: + t.join() + + # Should have no errors + assert len(errors) == 0, f"Concurrent read/write errors: {errors}" + + # All reads should have returned valid counts + assert all(count >= 10 for count in read_counts), "Reads should return at least initial rows" + + +class TestDatabaseCorruptedErrorClass: + """Test the DatabaseCorruptedError exception class.""" + + def test_error_contains_db_path(self): + """Test that error contains the database path.""" + original_error = Exception("original error") + error = DatabaseCorruptedError("/path/to/db.sqlite", original_error) + + assert error.db_path == "/path/to/db.sqlite" + assert error.original_error == original_error + assert "/path/to/db.sqlite" in str(error) + + def test_error_is_exception(self): + """Test that DatabaseCorruptedError is an Exception.""" + error = DatabaseCorruptedError("/path/to/db", Exception("test")) + assert isinstance(error, Exception) From 74ef7cfe7208a95762da4007867227bf9fb821a6 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Sat, 13 Dec 2025 01:29:42 -0800 Subject: [PATCH 3/4] Refactor database error handling to prevent deletion on transient errors and improve corruption detection. Add tests to ensure valid databases are not deleted on non-corruption DatabaseErrors. --- eval_protocol/cli_commands/logs.py | 1 - .../event_bus/sqlite_event_bus_database.py | 15 ++--- tests/test_sqlite_hardening.py | 57 +++++++++++++++++++ 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/eval_protocol/cli_commands/logs.py b/eval_protocol/cli_commands/logs.py index a952b789..63c48649 100644 --- a/eval_protocol/cli_commands/logs.py +++ b/eval_protocol/cli_commands/logs.py @@ -62,7 +62,6 @@ def _is_database_corruption_error(error: Exception) -> tuple[bool, str]: corruption_indicators = [ "file is not a database", "database disk image is malformed", - "database is locked", "unable to open database file", ] diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index f7d9bb5b..f148991d 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -66,19 +66,20 @@ def check_and_repair_database(db_path: str, auto_repair: bool = False) -> bool: except DatabaseError as e: error_str = str(e).lower() - if "file is not a database" in error_str or "database disk image is malformed" in error_str: + # Only treat specific SQLite corruption errors as corruption + corruption_indicators = [ + "file is not a database", + "database disk image is malformed", + "file is encrypted or is not a database", + ] + if any(indicator in error_str for indicator in corruption_indicators): logger.warning(f"Database file is corrupted: {db_path}") if auto_repair: _backup_and_remove_database(db_path) return True raise DatabaseCorruptedError(db_path, e) + # For other DatabaseErrors (locks, busy, etc.), re-raise without deleting raise - except Exception as e: - logger.warning(f"Error checking database {db_path}: {e}") - if auto_repair: - _backup_and_remove_database(db_path) - return True - raise DatabaseCorruptedError(db_path, e) def _backup_and_remove_database(db_path: str) -> None: diff --git a/tests/test_sqlite_hardening.py b/tests/test_sqlite_hardening.py index 301240ff..a4f3f863 100644 --- a/tests/test_sqlite_hardening.py +++ b/tests/test_sqlite_hardening.py @@ -194,6 +194,63 @@ def test_corrupted_file_backup_created(self): assert len(backup_files) == 1 assert "corrupted" in backup_files[0] + def test_transient_errors_do_not_delete_database(self): + """Test that transient I/O errors (like PermissionError) don't trigger database deletion. + + This is a regression test for a bug where the catch-all Exception handler + would delete valid databases on transient errors like PermissionError, + OSError, or temporary lock conflicts. + """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "valid.db") + + # Create a valid database + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, data TEXT)") + conn.execute("INSERT INTO test VALUES (1, 'important data')") + conn.commit() + conn.close() + + # Verify the database is valid first + result = check_and_repair_database(db_path) + assert result is True + assert os.path.exists(db_path) + + # The database should still exist and be valid + conn = sqlite3.connect(db_path) + cursor = conn.execute("SELECT data FROM test WHERE id=1") + row = cursor.fetchone() + conn.close() + assert row[0] == "important data" + + def test_database_error_without_corruption_indicator_is_not_auto_repaired(self): + """Test that DatabaseError without corruption indicators is re-raised, not auto-repaired.""" + from unittest.mock import patch, MagicMock + from peewee import DatabaseError + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "locked.db") + + # Create a valid database first + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE test (id INTEGER)") + conn.close() + + # Mock SqliteDatabase to raise a non-corruption DatabaseError (e.g., database locked) + with patch("eval_protocol.event_bus.sqlite_event_bus_database.SqliteDatabase") as mock_db_class: + mock_db = MagicMock() + mock_db_class.return_value = mock_db + mock_db.connect.side_effect = DatabaseError("database is locked") + + # Should re-raise the error, not delete the database + with pytest.raises(DatabaseError) as exc_info: + check_and_repair_database(db_path, auto_repair=True) + + assert "locked" in str(exc_info.value) + + # Database file should still exist (not deleted) + assert os.path.exists(db_path) + class TestBackupAndRemoveDatabase: """Test the backup and remove database functionality.""" From f05a1ff76a0229b2a813474fa1e081c8c644c986 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Sat, 13 Dec 2025 01:31:40 -0800 Subject: [PATCH 4/4] Add checkpointing functionality to safely copy SQLite databases This update introduces a new function, _checkpoint_and_copy_database, which ensures that all data in the Write-Ahead Logging (WAL) file is flushed to the main database file before copying. This function is now utilized in the SQLResource class for database forking, checkpointing, and restoring operations, enhancing data integrity during these processes. --- eval_protocol/agent/resources/sql_resource.py | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/eval_protocol/agent/resources/sql_resource.py b/eval_protocol/agent/resources/sql_resource.py index ed8c0360..7faa5e59 100644 --- a/eval_protocol/agent/resources/sql_resource.py +++ b/eval_protocol/agent/resources/sql_resource.py @@ -27,6 +27,35 @@ def _apply_hardened_pragmas(conn: sqlite3.Connection) -> None: conn.execute("PRAGMA temp_store=MEMORY") # Store temp tables in memory +def _checkpoint_and_copy_database( + source_path: Path, dest_path: Path, timeout: int = SQLITE_CONNECTION_TIMEOUT +) -> None: + """ + Safely copy a SQLite database by checkpointing WAL first. + + In WAL mode, data may exist in the -wal file that hasn't been written + to the main database file. This function performs a TRUNCATE checkpoint + to flush all WAL data to the main file before copying, ensuring a + complete and consistent copy. + + Args: + source_path: Path to the source database file. + dest_path: Path where the copy should be created. + timeout: Connection timeout in seconds. + """ + # First, checkpoint the WAL to ensure all data is in the main file + conn = sqlite3.connect(str(source_path), timeout=timeout) + try: + # TRUNCATE mode: checkpoint and truncate the WAL file to zero bytes + # This ensures all data is flushed to the main database file + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") + finally: + conn.close() + + # Now safely copy just the main database file + shutil.copyfile(str(source_path), str(dest_path)) + + class SQLResource(ForkableResource): """ A ForkableResource for managing SQL database states, primarily SQLite. @@ -134,7 +163,8 @@ async def fork(self) -> "SQLResource": forked_db_name = f"fork_{uuid.uuid4().hex}.sqlite" forked_resource._db_path = self._temp_dir / forked_db_name - shutil.copyfile(str(self._db_path), str(forked_resource._db_path)) + # Use checkpoint-and-copy to ensure WAL data is flushed before copying + _checkpoint_and_copy_database(self._db_path, forked_resource._db_path) return forked_resource async def checkpoint(self) -> Dict[str, Any]: @@ -148,7 +178,8 @@ async def checkpoint(self) -> Dict[str, Any]: checkpoint_name = f"checkpoint_{self._db_path.stem}_{uuid.uuid4().hex}.sqlite" checkpoint_path = self._temp_dir / checkpoint_name - shutil.copyfile(str(self._db_path), str(checkpoint_path)) + # Use checkpoint-and-copy to ensure WAL data is flushed before copying + _checkpoint_and_copy_database(self._db_path, checkpoint_path) return {"db_type": "sqlite", "checkpoint_path": str(checkpoint_path)} async def restore(self, state_data: Dict[str, Any]) -> None: @@ -170,7 +201,8 @@ async def restore(self, state_data: Dict[str, Any]) -> None: if not self._db_path: self._db_path = self._temp_dir / f"restored_{uuid.uuid4().hex}.sqlite" - shutil.copyfile(str(checkpoint_path), str(self._db_path)) + # Use checkpoint-and-copy to ensure WAL data is flushed before copying + _checkpoint_and_copy_database(checkpoint_path, self._db_path) self._base_db_path = self._db_path # The restored state becomes the new base for future forks async def step(self, action_name: str, action_params: Dict[str, Any]) -> Any: