Skip to content

Commit c2f8e24

Browse files
author
Dylan Huang
authored
Harden SQLite (#367)
* add tests to pyright
* sqlite hardening
* Refactor database error handling to prevent deletion on transient errors and improve corruption detection. Add tests to ensure valid databases are not deleted on non-corruption DatabaseErrors.
* Add checkpointing functionality to safely copy SQLite databases. This update introduces a new function, _checkpoint_and_copy_database, which ensures that all data in the Write-Ahead Logging (WAL) file is flushed to the main database file before copying. This function is now utilized in the SQLResource class for database forking, checkpointing, and restoring operations, enhancing data integrity during these processes.
1 parent bf8f228 commit c2f8e24

File tree

7 files changed

+774
-30
lines changed

7 files changed

+774
-30
lines changed

eval_protocol/agent/resources/sql_resource.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,50 @@
1212
from ..resource_abc import ForkableResource
1313

1414

15+
# SQLite connection settings for hardened concurrency safety
16+
SQLITE_CONNECTION_TIMEOUT = 30 # 30 seconds
17+
18+
19+
def _apply_hardened_pragmas(conn: sqlite3.Connection) -> None:
20+
"""Apply hardened SQLite pragmas for concurrency safety."""
21+
conn.execute("PRAGMA journal_mode=WAL") # Write-Ahead Logging
22+
conn.execute("PRAGMA synchronous=NORMAL") # Balance safety and performance
23+
conn.execute("PRAGMA busy_timeout=30000") # 30 second timeout
24+
conn.execute("PRAGMA wal_autocheckpoint=1000") # Checkpoint every 1000 pages
25+
conn.execute("PRAGMA cache_size=-64000") # 64MB cache
26+
conn.execute("PRAGMA foreign_keys=ON") # Enable foreign key constraints
27+
conn.execute("PRAGMA temp_store=MEMORY") # Store temp tables in memory
28+
29+
30+
def _checkpoint_and_copy_database(
31+
source_path: Path, dest_path: Path, timeout: int = SQLITE_CONNECTION_TIMEOUT
32+
) -> None:
33+
"""
34+
Safely copy a SQLite database by checkpointing WAL first.
35+
36+
In WAL mode, data may exist in the -wal file that hasn't been written
37+
to the main database file. This function performs a TRUNCATE checkpoint
38+
to flush all WAL data to the main file before copying, ensuring a
39+
complete and consistent copy.
40+
41+
Args:
42+
source_path: Path to the source database file.
43+
dest_path: Path where the copy should be created.
44+
timeout: Connection timeout in seconds.
45+
"""
46+
# First, checkpoint the WAL to ensure all data is in the main file
47+
conn = sqlite3.connect(str(source_path), timeout=timeout)
48+
try:
49+
# TRUNCATE mode: checkpoint and truncate the WAL file to zero bytes
50+
# This ensures all data is flushed to the main database file
51+
conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
52+
finally:
53+
conn.close()
54+
55+
# Now safely copy just the main database file
56+
shutil.copyfile(str(source_path), str(dest_path))
57+
58+
1559
class SQLResource(ForkableResource):
1660
"""
1761
A ForkableResource for managing SQL database states, primarily SQLite.
@@ -20,6 +64,8 @@ class SQLResource(ForkableResource):
2064
and seed data, forked (by copying the DB file), checkpointed (by copying),
2165
and restored.
2266
67+
Uses hardened SQLite settings for concurrency safety.
68+
2369
Attributes:
2470
_config (Dict[str, Any]): Configuration for the resource.
2571
_db_path (Optional[Path]): Path to the current SQLite database file.
@@ -38,8 +84,14 @@ def __init__(self) -> None:
3884
def _get_db_connection(self) -> sqlite3.Connection:
    """Open a hardened connection to this resource's SQLite database.

    Returns:
        A sqlite3.Connection with the hardened pragmas applied.

    Raises:
        ConnectionError: If no database path has been set yet; setup() or
            fork() must run first.
    """
    if not self._db_path:
        raise ConnectionError("Database path not set. Call setup() or fork() first.")
    connection = sqlite3.connect(
        str(self._db_path),
        timeout=SQLITE_CONNECTION_TIMEOUT,  # bounded wait instead of hanging indefinitely
        isolation_level="DEFERRED",  # take locks lazily; better for concurrent access
    )
    _apply_hardened_pragmas(connection)
    return connection
4395

4496
async def setup(self, config: Dict[str, Any]) -> None:
4597
"""
@@ -111,7 +163,8 @@ async def fork(self) -> "SQLResource":
111163
forked_db_name = f"fork_{uuid.uuid4().hex}.sqlite"
112164
forked_resource._db_path = self._temp_dir / forked_db_name
113165

114-
shutil.copyfile(str(self._db_path), str(forked_resource._db_path))
166+
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
167+
_checkpoint_and_copy_database(self._db_path, forked_resource._db_path)
115168
return forked_resource
116169

117170
async def checkpoint(self) -> Dict[str, Any]:
@@ -125,7 +178,8 @@ async def checkpoint(self) -> Dict[str, Any]:
125178

126179
checkpoint_name = f"checkpoint_{self._db_path.stem}_{uuid.uuid4().hex}.sqlite"
127180
checkpoint_path = self._temp_dir / checkpoint_name
128-
shutil.copyfile(str(self._db_path), str(checkpoint_path))
181+
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
182+
_checkpoint_and_copy_database(self._db_path, checkpoint_path)
129183
return {"db_type": "sqlite", "checkpoint_path": str(checkpoint_path)}
130184

131185
async def restore(self, state_data: Dict[str, Any]) -> None:
@@ -147,7 +201,8 @@ async def restore(self, state_data: Dict[str, Any]) -> None:
147201
if not self._db_path:
148202
self._db_path = self._temp_dir / f"restored_{uuid.uuid4().hex}.sqlite"
149203

150-
shutil.copyfile(str(checkpoint_path), str(self._db_path))
204+
# Use checkpoint-and-copy to ensure WAL data is flushed before copying
205+
_checkpoint_and_copy_database(checkpoint_path, self._db_path)
151206
self._base_db_path = self._db_path # The restored state becomes the new base for future forks
152207

153208
async def step(self, action_name: str, action_params: Dict[str, Any]) -> Any:

eval_protocol/cli_commands/logs.py

Lines changed: 104 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,81 @@
77

88
import os
99
from ..utils.logs_server import serve_logs
10+
from ..event_bus.sqlite_event_bus_database import DatabaseCorruptedError, _backup_and_remove_database
11+
12+
13+
def _handle_database_corruption(db_path: str) -> bool:
    """
    Handle database corruption by prompting user to fix it.

    Prints an explanation of the corruption, then asks whether the
    corrupted file should be backed up and replaced with a fresh database.

    Args:
        db_path: Path to the corrupted database

    Returns:
        True if user chose to fix and database was reset, False otherwise
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    explanation = (
        "\n" + heavy_rule,
        "⚠️ DATABASE CORRUPTION DETECTED",
        heavy_rule,
        f"\nThe database file at:\n {db_path}\n",
        "appears to be corrupted or is not a valid SQLite database.",
        "\nThis can happen due to:",
        " • Incomplete writes during a crash",
        " • Concurrent access issues",
        " • File system errors",
        "\n" + light_rule,
        "Would you like to automatically fix this?",
        " • The corrupted file will be backed up",
        " • A fresh database will be created",
        " • You will lose existing log data, but can continue using the tool",
        light_rule,
    )
    for line in explanation:
        print(line)

    try:
        answer = input("\nFix database automatically? [Y/n]: ").strip().lower()
        if answer not in ("", "y", "yes"):
            # User declined — leave the corrupted file in place.
            print("\n❌ Database repair cancelled.")
            print(f" You can manually delete the corrupted file: {db_path}")
            return False
        _backup_and_remove_database(db_path)
        print("\n✅ Database has been reset. Restarting server...")
        return True
    except (EOFError, KeyboardInterrupt):
        # No interactive stdin, or user aborted mid-prompt.
        print("\n❌ Database repair cancelled.")
        return False
52+
53+
54+
def _is_database_corruption_error(error: Exception) -> tuple[bool, str]:
    """
    Check if an exception is related to database corruption.

    Args:
        error: The exception raised while starting the server.

    Returns:
        Tuple of (is_corruption_error, db_path). db_path may be "" when the
        database location could not be determined.
    """
    # A DatabaseCorruptedError carries the authoritative path, so check the
    # type FIRST. Previously the message sniffing below ran first, so a
    # DatabaseCorruptedError whose message contained an indicator returned a
    # guessed path (possibly "") instead of error.db_path.
    if isinstance(error, DatabaseCorruptedError):
        return True, error.db_path

    error_str = str(error).lower()
    corruption_indicators = (
        "file is not a database",
        "database disk image is malformed",
        # NOTE(review): this can also mean a permissions/path problem rather
        # than corruption — confirm it is intended to trigger auto-repair.
        "unable to open database file",
    )
    if any(indicator in error_str for indicator in corruption_indicators):
        # Best-effort guess at the database location.
        from ..directory_utils import find_eval_protocol_dir

        try:
            return True, os.path.join(find_eval_protocol_dir(), "logs.db")
        except Exception:
            # Couldn't locate the logs dir; still report corruption.
            return True, ""

    return False, ""
1085

1186

1287
def logs_command(args):
@@ -40,18 +115,32 @@ def logs_command(args):
40115
or "https://tracing.fireworks.ai"
41116
)
42117

43-
try:
44-
serve_logs(
45-
port=args.port,
46-
elasticsearch_config=elasticsearch_config,
47-
debug=args.debug,
48-
backend="fireworks" if use_fireworks else "elasticsearch",
49-
fireworks_base_url=fireworks_base_url if use_fireworks else None,
50-
)
51-
return 0
52-
except KeyboardInterrupt:
53-
print("\n🛑 Server stopped by user")
54-
return 0
55-
except Exception as e:
56-
print(f"❌ Error starting server: {e}")
57-
return 1
118+
max_retries = 2
119+
for attempt in range(max_retries):
120+
try:
121+
serve_logs(
122+
port=args.port,
123+
elasticsearch_config=elasticsearch_config,
124+
debug=args.debug,
125+
backend="fireworks" if use_fireworks else "elasticsearch",
126+
fireworks_base_url=fireworks_base_url if use_fireworks else None,
127+
)
128+
return 0
129+
except KeyboardInterrupt:
130+
print("\n🛑 Server stopped by user")
131+
return 0
132+
except Exception as e:
133+
is_corruption, db_path = _is_database_corruption_error(e)
134+
135+
if is_corruption and db_path and attempt < max_retries - 1:
136+
if _handle_database_corruption(db_path):
137+
# User chose to fix, retry
138+
continue
139+
else:
140+
# User declined fix
141+
return 1
142+
143+
print(f"❌ Error starting server: {e}")
144+
return 1
145+
146+
return 1

eval_protocol/dataset_logger/sqlite_evaluation_row_store.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import os
22
from typing import List, Optional
33

4-
from peewee import CharField, Model, SqliteDatabase
4+
from peewee import CharField, DatabaseError, Model, SqliteDatabase
55
from playhouse.sqlite_ext import JSONField
66

7+
from eval_protocol.event_bus.sqlite_event_bus_database import (
8+
SQLITE_HARDENED_PRAGMAS,
9+
DatabaseCorruptedError,
10+
check_and_repair_database,
11+
)
712
from eval_protocol.models import EvaluationRow
813

914

@@ -12,12 +17,20 @@ class SqliteEvaluationRowStore:
1217
Lightweight reusable SQLite store for evaluation rows.
1318
1419
Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
20+
Uses hardened SQLite settings for concurrency safety.
1521
"""
1622

17-
def __init__(self, db_path: str, auto_repair: bool = True):
    """Create the store backed by the SQLite file at *db_path*.

    Args:
        db_path: Filesystem path of the SQLite database file; the parent
            directory is created on demand when one is given.
        auto_repair: Forwarded to check_and_repair_database — controls
            whether a corrupted file is repaired automatically.
    """
    parent_dir = os.path.dirname(db_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    self._db_path = db_path

    # Refuse to build on a corrupted file; optionally repair it first.
    check_and_repair_database(db_path, auto_repair=auto_repair)

    # Hardened pragmas (WAL, busy timeout, ...) for concurrency safety.
    self._db = SqliteDatabase(self._db_path, pragmas=SQLITE_HARDENED_PRAGMAS)
2134

2235
class BaseModel(Model):
2336
class Meta:

eval_protocol/event_bus/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Global event bus instance - uses SqliteEventBus for cross-process functionality
22
from typing import Any, Callable
33
from eval_protocol.event_bus.event_bus import EventBus
4+
from eval_protocol.event_bus.sqlite_event_bus_database import (
5+
DatabaseCorruptedError,
6+
check_and_repair_database,
7+
SQLITE_HARDENED_PRAGMAS,
8+
)
49

510

611
def _get_default_event_bus():

0 commit comments

Comments
 (0)