Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 60 additions & 5 deletions eval_protocol/agent/resources/sql_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,50 @@
from ..resource_abc import ForkableResource


# SQLite connection settings for hardened concurrency safety
SQLITE_CONNECTION_TIMEOUT = 30 # 30 seconds


def _apply_hardened_pragmas(conn: sqlite3.Connection) -> None:
"""Apply hardened SQLite pragmas for concurrency safety."""
conn.execute("PRAGMA journal_mode=WAL") # Write-Ahead Logging
conn.execute("PRAGMA synchronous=NORMAL") # Balance safety and performance
conn.execute("PRAGMA busy_timeout=30000") # 30 second timeout
conn.execute("PRAGMA wal_autocheckpoint=1000") # Checkpoint every 1000 pages
conn.execute("PRAGMA cache_size=-64000") # 64MB cache
conn.execute("PRAGMA foreign_keys=ON") # Enable foreign key constraints
conn.execute("PRAGMA temp_store=MEMORY") # Store temp tables in memory


def _checkpoint_and_copy_database(
source_path: Path, dest_path: Path, timeout: int = SQLITE_CONNECTION_TIMEOUT
) -> None:
"""
Safely copy a SQLite database by checkpointing WAL first.

In WAL mode, data may exist in the -wal file that hasn't been written
to the main database file. This function performs a TRUNCATE checkpoint
to flush all WAL data to the main file before copying, ensuring a
complete and consistent copy.

Args:
source_path: Path to the source database file.
dest_path: Path where the copy should be created.
timeout: Connection timeout in seconds.
"""
# First, checkpoint the WAL to ensure all data is in the main file
conn = sqlite3.connect(str(source_path), timeout=timeout)
try:
# TRUNCATE mode: checkpoint and truncate the WAL file to zero bytes
# This ensures all data is flushed to the main database file
conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
finally:
conn.close()

# Now safely copy just the main database file
shutil.copyfile(str(source_path), str(dest_path))


class SQLResource(ForkableResource):
"""
A ForkableResource for managing SQL database states, primarily SQLite.
Expand All @@ -20,6 +64,8 @@ class SQLResource(ForkableResource):
and seed data, forked (by copying the DB file), checkpointed (by copying),
and restored.

Uses hardened SQLite settings for concurrency safety.

Attributes:
_config (Dict[str, Any]): Configuration for the resource.
_db_path (Optional[Path]): Path to the current SQLite database file.
Expand All @@ -38,8 +84,14 @@ def __init__(self) -> None:
def _get_db_connection(self) -> sqlite3.Connection:
if not self._db_path:
raise ConnectionError("Database path not set. Call setup() or fork() first.")
# Set timeout to prevent indefinite hangs
return sqlite3.connect(str(self._db_path), timeout=10)
# Set timeout to prevent indefinite hangs with hardened settings
conn = sqlite3.connect(
str(self._db_path),
timeout=SQLITE_CONNECTION_TIMEOUT,
isolation_level="DEFERRED", # Better for concurrent access
)
_apply_hardened_pragmas(conn)
return conn

async def setup(self, config: Dict[str, Any]) -> None:
"""
Expand Down Expand Up @@ -111,7 +163,8 @@ async def fork(self) -> "SQLResource":
forked_db_name = f"fork_{uuid.uuid4().hex}.sqlite"
forked_resource._db_path = self._temp_dir / forked_db_name

shutil.copyfile(str(self._db_path), str(forked_resource._db_path))
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
_checkpoint_and_copy_database(self._db_path, forked_resource._db_path)
return forked_resource

async def checkpoint(self) -> Dict[str, Any]:
Expand All @@ -125,7 +178,8 @@ async def checkpoint(self) -> Dict[str, Any]:

checkpoint_name = f"checkpoint_{self._db_path.stem}_{uuid.uuid4().hex}.sqlite"
checkpoint_path = self._temp_dir / checkpoint_name
shutil.copyfile(str(self._db_path), str(checkpoint_path))
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
_checkpoint_and_copy_database(self._db_path, checkpoint_path)
return {"db_type": "sqlite", "checkpoint_path": str(checkpoint_path)}

async def restore(self, state_data: Dict[str, Any]) -> None:
Expand All @@ -147,7 +201,8 @@ async def restore(self, state_data: Dict[str, Any]) -> None:
if not self._db_path:
self._db_path = self._temp_dir / f"restored_{uuid.uuid4().hex}.sqlite"

shutil.copyfile(str(checkpoint_path), str(self._db_path))
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
_checkpoint_and_copy_database(checkpoint_path, self._db_path)
self._base_db_path = self._db_path # The restored state becomes the new base for future forks

async def step(self, action_name: str, action_params: Dict[str, Any]) -> Any:
Expand Down
119 changes: 104 additions & 15 deletions eval_protocol/cli_commands/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,81 @@

import os
from ..utils.logs_server import serve_logs
from ..event_bus.sqlite_event_bus_database import DatabaseCorruptedError, _backup_and_remove_database


def _handle_database_corruption(db_path: str) -> bool:
    """
    Prompt the user to repair a corrupted logs database.

    Prints a banner explaining the corruption and its likely causes, then
    asks whether the corrupted file should be backed up and replaced with a
    fresh database. Declining, EOF, or Ctrl-C leaves the file untouched.

    Args:
        db_path: Path to the corrupted database

    Returns:
        True if user chose to fix and database was reset, False otherwise
    """
    banner = "=" * 60
    divider = "-" * 60
    # Emit the whole notice as one write; output is identical to
    # line-by-line printing.
    print(
        "\n".join(
            [
                "\n" + banner,
                "⚠️ DATABASE CORRUPTION DETECTED",
                banner,
                f"\nThe database file at:\n {db_path}\n",
                "appears to be corrupted or is not a valid SQLite database.",
                "\nThis can happen due to:",
                " • Incomplete writes during a crash",
                " • Concurrent access issues",
                " • File system errors",
                "\n" + divider,
                "Would you like to automatically fix this?",
                " • The corrupted file will be backed up",
                " • A fresh database will be created",
                " • You will lose existing log data, but can continue using the tool",
                divider,
            ]
        )
    )

    try:
        answer = input("\nFix database automatically? [Y/n]: ").strip().lower()
        if answer not in ("", "y", "yes"):
            # User declined; tell them how to repair manually.
            print("\n❌ Database repair cancelled.")
            print(f" You can manually delete the corrupted file: {db_path}")
            return False
        _backup_and_remove_database(db_path)
        print("\n✅ Database has been reset. Restarting server...")
        return True
    except (EOFError, KeyboardInterrupt):
        # No interactive stdin, or the user interrupted the prompt/repair.
        print("\n❌ Database repair cancelled.")
        return False


def _is_database_corruption_error(error: Exception) -> tuple[bool, str]:
    """
    Check if an exception is related to database corruption.

    Args:
        error: The exception raised while starting the logs server.

    Returns:
        Tuple of (is_corruption_error, db_path). ``db_path`` is an empty
        string when the database location could not be determined.
    """
    # A typed corruption error carries the authoritative database path, so
    # check it before the fuzzy message matching below -- otherwise a
    # DatabaseCorruptedError whose message contains an indicator string
    # would return a guessed path instead of error.db_path.
    if isinstance(error, DatabaseCorruptedError):
        return True, error.db_path

    error_str = str(error).lower()
    corruption_indicators = (
        "file is not a database",
        "database disk image is malformed",
        "unable to open database file",
    )

    if any(indicator in error_str for indicator in corruption_indicators):
        # Best-effort guess at the database location. The import sits inside
        # the try so that a failed import still reports the corruption
        # (with an unknown path) rather than raising.
        try:
            from ..directory_utils import find_eval_protocol_dir

            eval_protocol_dir = find_eval_protocol_dir()
            return True, os.path.join(eval_protocol_dir, "logs.db")
        except Exception:
            return True, ""

    return False, ""


def logs_command(args):
Expand Down Expand Up @@ -40,18 +115,32 @@ def logs_command(args):
or "https://tracing.fireworks.ai"
)

try:
serve_logs(
port=args.port,
elasticsearch_config=elasticsearch_config,
debug=args.debug,
backend="fireworks" if use_fireworks else "elasticsearch",
fireworks_base_url=fireworks_base_url if use_fireworks else None,
)
return 0
except KeyboardInterrupt:
print("\n🛑 Server stopped by user")
return 0
except Exception as e:
print(f"❌ Error starting server: {e}")
return 1
max_retries = 2
for attempt in range(max_retries):
try:
serve_logs(
port=args.port,
elasticsearch_config=elasticsearch_config,
debug=args.debug,
backend="fireworks" if use_fireworks else "elasticsearch",
fireworks_base_url=fireworks_base_url if use_fireworks else None,
)
return 0
except KeyboardInterrupt:
print("\n🛑 Server stopped by user")
return 0
except Exception as e:
is_corruption, db_path = _is_database_corruption_error(e)

if is_corruption and db_path and attempt < max_retries - 1:
if _handle_database_corruption(db_path):
# User chose to fix, retry
continue
else:
# User declined fix
return 1

print(f"❌ Error starting server: {e}")
return 1

return 1
21 changes: 17 additions & 4 deletions eval_protocol/dataset_logger/sqlite_evaluation_row_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import os
from typing import List, Optional

from peewee import CharField, Model, SqliteDatabase
from peewee import CharField, DatabaseError, Model, SqliteDatabase
from playhouse.sqlite_ext import JSONField

from eval_protocol.event_bus.sqlite_event_bus_database import (
SQLITE_HARDENED_PRAGMAS,
DatabaseCorruptedError,
check_and_repair_database,
)
from eval_protocol.models import EvaluationRow


Expand All @@ -12,12 +17,20 @@ class SqliteEvaluationRowStore:
Lightweight reusable SQLite store for evaluation rows.

Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
Uses hardened SQLite settings for concurrency safety.
"""

def __init__(self, db_path: str):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
def __init__(self, db_path: str, auto_repair: bool = True):
db_dir = os.path.dirname(db_path)
if db_dir:
os.makedirs(db_dir, exist_ok=True)
self._db_path = db_path
self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"})

# Check and optionally repair corrupted database
check_and_repair_database(db_path, auto_repair=auto_repair)

# Use hardened pragmas for concurrency safety
self._db = SqliteDatabase(self._db_path, pragmas=SQLITE_HARDENED_PRAGMAS)

class BaseModel(Model):
class Meta:
Expand Down
5 changes: 5 additions & 0 deletions eval_protocol/event_bus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Global event bus instance - uses SqliteEventBus for cross-process functionality
from typing import Any, Callable
from eval_protocol.event_bus.event_bus import EventBus
from eval_protocol.event_bus.sqlite_event_bus_database import (
DatabaseCorruptedError,
check_and_repair_database,
SQLITE_HARDENED_PRAGMAS,
)


def _get_default_event_bus():
Expand Down
Loading
Loading