Skip to content

Commit 64025a4

Browse files
author
Dylan Huang
committed
sqlite hardening
1 parent 50ea6db commit 64025a4

File tree

6 files changed

+680
-25
lines changed

6 files changed

+680
-25
lines changed

eval_protocol/agent/resources/sql_resource.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@
1212
from ..resource_abc import ForkableResource
1313

1414

15+
# SQLite connection settings for hardened concurrency safety
16+
SQLITE_CONNECTION_TIMEOUT = 30 # 30 seconds
17+
18+
19+
def _apply_hardened_pragmas(conn: sqlite3.Connection) -> None:
20+
"""Apply hardened SQLite pragmas for concurrency safety."""
21+
conn.execute("PRAGMA journal_mode=WAL") # Write-Ahead Logging
22+
conn.execute("PRAGMA synchronous=NORMAL") # Balance safety and performance
23+
conn.execute("PRAGMA busy_timeout=30000") # 30 second timeout
24+
conn.execute("PRAGMA wal_autocheckpoint=1000") # Checkpoint every 1000 pages
25+
conn.execute("PRAGMA cache_size=-64000") # 64MB cache
26+
conn.execute("PRAGMA foreign_keys=ON") # Enable foreign key constraints
27+
conn.execute("PRAGMA temp_store=MEMORY") # Store temp tables in memory
28+
29+
1530
class SQLResource(ForkableResource):
1631
"""
1732
A ForkableResource for managing SQL database states, primarily SQLite.
@@ -20,6 +35,8 @@ class SQLResource(ForkableResource):
2035
and seed data, forked (by copying the DB file), checkpointed (by copying),
2136
and restored.
2237
38+
Uses hardened SQLite settings for concurrency safety.
39+
2340
Attributes:
2441
_config (Dict[str, Any]): Configuration for the resource.
2542
_db_path (Optional[Path]): Path to the current SQLite database file.
@@ -38,8 +55,14 @@ def __init__(self) -> None:
3855
def _get_db_connection(self) -> sqlite3.Connection:
3956
if not self._db_path:
4057
raise ConnectionError("Database path not set. Call setup() or fork() first.")
41-
# Set timeout to prevent indefinite hangs
42-
return sqlite3.connect(str(self._db_path), timeout=10)
58+
# Set timeout to prevent indefinite hangs with hardened settings
59+
conn = sqlite3.connect(
60+
str(self._db_path),
61+
timeout=SQLITE_CONNECTION_TIMEOUT,
62+
isolation_level="DEFERRED", # Better for concurrent access
63+
)
64+
_apply_hardened_pragmas(conn)
65+
return conn
4366

4467
async def setup(self, config: Dict[str, Any]) -> None:
4568
"""

eval_protocol/cli_commands/logs.py

Lines changed: 105 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,82 @@
77

88
import os
99
from ..utils.logs_server import serve_logs
10+
from ..event_bus.sqlite_event_bus_database import DatabaseCorruptedError, _backup_and_remove_database
11+
12+
13+
def _handle_database_corruption(db_path: str) -> bool:
14+
"""
15+
Handle database corruption by prompting user to fix it.
16+
17+
Args:
18+
db_path: Path to the corrupted database
19+
20+
Returns:
21+
True if user chose to fix and database was reset, False otherwise
22+
"""
23+
print("\n" + "=" * 60)
24+
print("⚠️ DATABASE CORRUPTION DETECTED")
25+
print("=" * 60)
26+
print(f"\nThe database file at:\n {db_path}\n")
27+
print("appears to be corrupted or is not a valid SQLite database.")
28+
print("\nThis can happen due to:")
29+
print(" • Incomplete writes during a crash")
30+
print(" • Concurrent access issues")
31+
print(" • File system errors")
32+
print("\n" + "-" * 60)
33+
print("Would you like to automatically fix this?")
34+
print(" • The corrupted file will be backed up")
35+
print(" • A fresh database will be created")
36+
print(" • You will lose existing log data, but can continue using the tool")
37+
print("-" * 60)
38+
39+
try:
40+
response = input("\nFix database automatically? [Y/n]: ").strip().lower()
41+
if response in ("", "y", "yes"):
42+
_backup_and_remove_database(db_path)
43+
print("\n✅ Database has been reset. Restarting server...")
44+
return True
45+
else:
46+
print("\n❌ Database repair cancelled.")
47+
print(f" You can manually delete the corrupted file: {db_path}")
48+
return False
49+
except (EOFError, KeyboardInterrupt):
50+
print("\n❌ Database repair cancelled.")
51+
return False
52+
53+
54+
def _is_database_corruption_error(error: Exception) -> tuple[bool, str]:
55+
"""
56+
Check if an exception is related to database corruption.
57+
58+
Returns:
59+
Tuple of (is_corruption_error, db_path)
60+
"""
61+
error_str = str(error).lower()
62+
corruption_indicators = [
63+
"file is not a database",
64+
"database disk image is malformed",
65+
"database is locked",
66+
"unable to open database file",
67+
]
68+
69+
for indicator in corruption_indicators:
70+
if indicator in error_str:
71+
# Try to find the database path
72+
from ..directory_utils import find_eval_protocol_dir
73+
74+
try:
75+
eval_protocol_dir = find_eval_protocol_dir()
76+
db_path = os.path.join(eval_protocol_dir, "logs.db")
77+
return True, db_path
78+
except Exception:
79+
return True, ""
80+
81+
# Check if it's a DatabaseCorruptedError
82+
if isinstance(error, DatabaseCorruptedError):
83+
return True, error.db_path
84+
85+
return False, ""
1086

1187

1288
def logs_command(args):
@@ -40,18 +116,32 @@ def logs_command(args):
40116
or "https://tracing.fireworks.ai"
41117
)
42118

43-
try:
44-
serve_logs(
45-
port=args.port,
46-
elasticsearch_config=elasticsearch_config,
47-
debug=args.debug,
48-
backend="fireworks" if use_fireworks else "elasticsearch",
49-
fireworks_base_url=fireworks_base_url if use_fireworks else None,
50-
)
51-
return 0
52-
except KeyboardInterrupt:
53-
print("\n🛑 Server stopped by user")
54-
return 0
55-
except Exception as e:
56-
print(f"❌ Error starting server: {e}")
57-
return 1
119+
max_retries = 2
120+
for attempt in range(max_retries):
121+
try:
122+
serve_logs(
123+
port=args.port,
124+
elasticsearch_config=elasticsearch_config,
125+
debug=args.debug,
126+
backend="fireworks" if use_fireworks else "elasticsearch",
127+
fireworks_base_url=fireworks_base_url if use_fireworks else None,
128+
)
129+
return 0
130+
except KeyboardInterrupt:
131+
print("\n🛑 Server stopped by user")
132+
return 0
133+
except Exception as e:
134+
is_corruption, db_path = _is_database_corruption_error(e)
135+
136+
if is_corruption and db_path and attempt < max_retries - 1:
137+
if _handle_database_corruption(db_path):
138+
# User chose to fix, retry
139+
continue
140+
else:
141+
# User declined fix
142+
return 1
143+
144+
print(f"❌ Error starting server: {e}")
145+
return 1
146+
147+
return 1

eval_protocol/dataset_logger/sqlite_evaluation_row_store.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import os
22
from typing import List, Optional
33

4-
from peewee import CharField, Model, SqliteDatabase
4+
from peewee import CharField, DatabaseError, Model, SqliteDatabase
55
from playhouse.sqlite_ext import JSONField
66

7+
from eval_protocol.event_bus.sqlite_event_bus_database import (
8+
SQLITE_HARDENED_PRAGMAS,
9+
DatabaseCorruptedError,
10+
check_and_repair_database,
11+
)
712
from eval_protocol.models import EvaluationRow
813

914

@@ -12,12 +17,20 @@ class SqliteEvaluationRowStore:
1217
Lightweight reusable SQLite store for evaluation rows.
1318
1419
Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
20+
Uses hardened SQLite settings for concurrency safety.
1521
"""
1622

17-
def __init__(self, db_path: str):
18-
os.makedirs(os.path.dirname(db_path), exist_ok=True)
23+
def __init__(self, db_path: str, auto_repair: bool = True):
24+
db_dir = os.path.dirname(db_path)
25+
if db_dir:
26+
os.makedirs(db_dir, exist_ok=True)
1927
self._db_path = db_path
20-
self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"})
28+
29+
# Check and optionally repair corrupted database
30+
check_and_repair_database(db_path, auto_repair=auto_repair)
31+
32+
# Use hardened pragmas for concurrency safety
33+
self._db = SqliteDatabase(self._db_path, pragmas=SQLITE_HARDENED_PRAGMAS)
2134

2235
class BaseModel(Model):
2336
class Meta:

eval_protocol/event_bus/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Global event bus instance - uses SqliteEventBus for cross-process functionality
22
from typing import Any, Callable
33
from eval_protocol.event_bus.event_bus import EventBus
4+
from eval_protocol.event_bus.sqlite_event_bus_database import (
5+
DatabaseCorruptedError,
6+
check_and_repair_database,
7+
SQLITE_HARDENED_PRAGMAS,
8+
)
49

510

611
def _get_default_event_bus():

eval_protocol/event_bus/sqlite_event_bus_database.py

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,125 @@
1+
import os
12
import time
23
from typing import Any, List
34
from uuid import uuid4
45

5-
from peewee import BooleanField, CharField, DateTimeField, Model, SqliteDatabase
6+
from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, SqliteDatabase
67
from playhouse.sqlite_ext import JSONField
78

89
from eval_protocol.event_bus.logger import logger
910

1011

12+
# SQLite pragmas for hardened concurrency safety
13+
SQLITE_HARDENED_PRAGMAS = {
14+
"journal_mode": "wal", # Write-Ahead Logging for concurrent reads/writes
15+
"synchronous": "normal", # Balance between safety and performance
16+
"busy_timeout": 30000, # 30 second timeout for locked database
17+
"wal_autocheckpoint": 1000, # Checkpoint every 1000 pages
18+
"cache_size": -64000, # 64MB cache (negative = KB)
19+
"foreign_keys": 1, # Enable foreign key constraints
20+
"temp_store": "memory", # Store temp tables in memory
21+
}
22+
23+
24+
class DatabaseCorruptedError(Exception):
25+
"""Raised when the database file is corrupted or not a valid SQLite database."""
26+
27+
def __init__(self, db_path: str, original_error: Exception):
28+
self.db_path = db_path
29+
self.original_error = original_error
30+
super().__init__(f"Database file is corrupted: {db_path}. Original error: {original_error}")
31+
32+
33+
def check_and_repair_database(db_path: str, auto_repair: bool = False) -> bool:
34+
"""
35+
Check if a database file is valid and optionally repair it.
36+
37+
Args:
38+
db_path: Path to the database file
39+
auto_repair: If True, automatically delete and recreate corrupted database
40+
41+
Returns:
42+
True if database is valid or was repaired, False otherwise
43+
44+
Raises:
45+
DatabaseCorruptedError: If database is corrupted and auto_repair is False
46+
"""
47+
if not os.path.exists(db_path):
48+
return True # New database, nothing to check
49+
50+
try:
51+
# Try to open the database and run an integrity check
52+
test_db = SqliteDatabase(db_path, pragmas={"busy_timeout": 5000})
53+
test_db.connect()
54+
cursor = test_db.execute_sql("PRAGMA integrity_check")
55+
result = cursor.fetchone()
56+
test_db.close()
57+
58+
if result and result[0] == "ok":
59+
return True
60+
else:
61+
logger.warning(f"Database integrity check failed for {db_path}: {result}")
62+
if auto_repair:
63+
_backup_and_remove_database(db_path)
64+
return True
65+
raise DatabaseCorruptedError(db_path, Exception(f"Integrity check failed: {result}"))
66+
67+
except DatabaseError as e:
68+
error_str = str(e).lower()
69+
if "file is not a database" in error_str or "database disk image is malformed" in error_str:
70+
logger.warning(f"Database file is corrupted: {db_path}")
71+
if auto_repair:
72+
_backup_and_remove_database(db_path)
73+
return True
74+
raise DatabaseCorruptedError(db_path, e)
75+
raise
76+
except Exception as e:
77+
logger.warning(f"Error checking database {db_path}: {e}")
78+
if auto_repair:
79+
_backup_and_remove_database(db_path)
80+
return True
81+
raise DatabaseCorruptedError(db_path, e)
82+
83+
84+
def _backup_and_remove_database(db_path: str) -> None:
85+
"""Backup a corrupted database file and remove it."""
86+
backup_path = f"{db_path}.corrupted.{int(time.time())}"
87+
try:
88+
os.rename(db_path, backup_path)
89+
logger.info(f"Backed up corrupted database to: {backup_path}")
90+
except OSError as e:
91+
logger.warning(f"Failed to backup corrupted database, removing: {e}")
92+
try:
93+
os.remove(db_path)
94+
except OSError:
95+
pass
96+
97+
# Also try to remove WAL and SHM files if they exist
98+
for suffix in ["-wal", "-shm"]:
99+
wal_file = f"{db_path}{suffix}"
100+
if os.path.exists(wal_file):
101+
try:
102+
os.remove(wal_file)
103+
except OSError:
104+
pass
105+
106+
11107
class SqliteEventBusDatabase:
12108
"""SQLite database for cross-process event communication."""
13109

14-
def __init__(self, db_path: str):
110+
def __init__(self, db_path: str, auto_repair: bool = True):
15111
self._db_path = db_path
16-
self._db = SqliteDatabase(db_path)
112+
113+
# Ensure directory exists
114+
db_dir = os.path.dirname(db_path)
115+
if db_dir:
116+
os.makedirs(db_dir, exist_ok=True)
117+
118+
# Check and optionally repair corrupted database
119+
check_and_repair_database(db_path, auto_repair=auto_repair)
120+
121+
# Initialize database with hardened concurrency settings
122+
self._db = SqliteDatabase(db_path, pragmas=SQLITE_HARDENED_PRAGMAS)
17123

18124
class BaseModel(Model):
19125
class Meta:
@@ -29,7 +135,8 @@ class Event(BaseModel): # type: ignore
29135

30136
self._Event = Event
31137
self._db.connect()
32-
self._db.create_tables([Event])
138+
# Use safe=True to avoid errors when tables already exist
139+
self._db.create_tables([Event], safe=True)
33140

34141
def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
35142
"""Publish an event to the database."""

0 commit comments

Comments
 (0)