-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy path: sqlite_evaluation_row_store.py
More file actions
125 lines (98 loc) · 4.75 KB
/
sqlite_evaluation_row_store.py
File metadata and controls
125 lines (98 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
from typing import List, Optional
from peewee import CharField, Model, SqliteDatabase, fn, SQL
from playhouse.sqlite_ext import JSONField
from eval_protocol.event_bus.sqlite_event_bus_database import (
SQLITE_HARDENED_PRAGMAS,
check_and_repair_database,
execute_with_sqlite_retry,
)
from eval_protocol.models import EvaluationRow
class SqliteEvaluationRowStore:
    """
    Lightweight reusable SQLite store for evaluation rows.

    Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
    Uses hardened SQLite settings for concurrency safety.
    """

    def __init__(self, db_path: str, auto_repair: bool = True):
        """Open (creating if necessary) the store at ``db_path``.

        Args:
            db_path: Filesystem path of the SQLite database file. Parent
                directories are created if missing.
            auto_repair: When True, attempt to repair a corrupted database
                file before connecting.
        """
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self._db_path = db_path

        # Check and optionally repair a corrupted database before connecting.
        check_and_repair_database(db_path, auto_repair=auto_repair)

        # Use hardened pragmas for concurrency safety.
        self._db = SqliteDatabase(self._db_path, pragmas=SQLITE_HARDENED_PRAGMAS)

        class BaseModel(Model):
            class Meta:
                database = self._db

        class EvaluationRow(BaseModel):  # type: ignore
            # Unique key for upserts; mirrors execution_metadata.rollout_id.
            rollout_id = CharField(unique=True)
            # Arbitrary row payload stored as JSON text.
            data = JSONField()

        self._EvaluationRow = EvaluationRow
        self._db.connect()
        # safe=True avoids errors when tables/indexes already exist.
        self._db.create_tables([EvaluationRow], safe=True)

    @property
    def db_path(self) -> str:
        """Path of the underlying SQLite database file."""
        return self._db_path

    def upsert_row(self, data: dict) -> None:
        """Insert or update the row keyed by ``execution_metadata.rollout_id``.

        Args:
            data: Row payload; must contain a non-None
                ``execution_metadata.rollout_id``.

        Raises:
            ValueError: If ``execution_metadata`` or ``rollout_id`` is
                missing or None.
        """
        # Use .get() chaining so a missing "execution_metadata" or
        # "rollout_id" key raises the intended ValueError below instead of
        # an opaque KeyError (the original indexed directly, making the
        # None-check unreachable for the missing-key case).
        execution_metadata = data.get("execution_metadata") or {}
        rollout_id = execution_metadata.get("rollout_id")
        if rollout_id is None:
            raise ValueError("execution_metadata.rollout_id is required to upsert a row")
        execute_with_sqlite_retry(lambda: self._do_upsert(rollout_id, data))

    def _do_upsert(self, rollout_id: str, data: dict) -> None:
        """Internal method to perform the actual upsert within a transaction."""
        # Use IMMEDIATE instead of EXCLUSIVE for better concurrency:
        # IMMEDIATE acquires a reserved lock immediately but allows
        # concurrent reads.
        with self._db.atomic("IMMEDIATE"):
            if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
                self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
            else:
                self._EvaluationRow.create(rollout_id=rollout_id, data=data)

    def read_rows(
        self,
        rollout_id: Optional[str] = None,
        invocation_ids: Optional[List[str]] = None,
        limit: Optional[int] = None,
    ) -> List[dict]:
        """
        Read evaluation rows from the database with optional filtering.

        Args:
            rollout_id: Filter by a specific rollout_id (exact match)
            invocation_ids: Filter by a list of invocation_ids (rows matching any)
            limit: Maximum number of rows to return (most recent first)

        Returns:
            List of evaluation row data dictionaries
        """
        from functools import reduce
        from operator import or_

        query = self._EvaluationRow.select()

        if rollout_id is not None:
            query = query.where(self._EvaluationRow.rollout_id == rollout_id)

        # Filter rows whose JSON payload has a matching invocation_id, using
        # SQLite JSON extraction:
        # json_extract(data, '$.execution_metadata.invocation_id')
        if invocation_ids:
            conditions = [
                fn.json_extract(self._EvaluationRow.data, "$.execution_metadata.invocation_id") == inv_id
                for inv_id in invocation_ids
            ]
            # reduce(or_, ...) handles both the single- and multi-id cases,
            # replacing the original's explicit one-vs-many branch.
            query = query.where(reduce(or_, conditions))

        # Order by rowid descending to get most recent rows first.
        query = query.order_by(SQL("rowid DESC"))
        if limit is not None:
            query = query.limit(limit)

        return [result["data"] for result in query.dicts()]

    def delete_row(self, rollout_id: str) -> int:
        """Delete the row with the given rollout_id; returns rows deleted."""
        return self._EvaluationRow.delete().where(self._EvaluationRow.rollout_id == rollout_id).execute()

    def delete_all_rows(self) -> int:
        """Delete every row in the store; returns the number deleted."""
        return self._EvaluationRow.delete().execute()