From 0ba23061f5c4aa4c1787c5e0cdc45c455db43acb Mon Sep 17 00:00:00 2001 From: Sugaria0427 <1768802831@qq.com> Date: Sun, 24 May 2026 20:46:32 +0800 Subject: [PATCH 1/2] feat: add SQLite persistent storage for benchmark results (#26) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces flat JSON/CSV log storage with a queryable SQLite database while maintaining full backward compatibility. New files: - utils/storage.py — Storage class wrapping Python stdlib sqlite3. Schema: runs (run_id, timestamp, config_json) + results (run_id FK, backend, turn, metric, value). API: save_run, get_run, list_runs, compare_runs. - utils/migrate_legacy_logs.py — one-shot idempotent migration script that imports existing experiment_logs/*.json into SQLite. Modified files: - evaluation/logger.py — log_run() now calls Storage().save_run() alongside existing JSON+CSV writes. list_runs() queries SQLite first, falls back to filesystem scan. Fixes pre-existing has_llm_eval bug in _append_csv_summary. - tests/test_pipeline.py — 6 new tests (4 Storage CRUD + 2 logger integration). - CHANGELOG.md — documented the new feature. - .gitignore — added experiment_logs/memorylens.db. Closes #26 --- .gitignore | 1 + CHANGELOG.md | 12 ++ evaluation/logger.py | 36 +++++- tests/test_pipeline.py | 177 +++++++++++++++++++++++++++++ utils/migrate_legacy_logs.py | 75 ++++++++++++ utils/storage.py | 214 +++++++++++++++++++++++++++++++++++ 6 files changed, 511 insertions(+), 4 deletions(-) create mode 100644 utils/migrate_legacy_logs.py create mode 100644 utils/storage.py diff --git a/.gitignore b/.gitignore index 22e75aa..df0943d 100644 --- a/.gitignore +++ b/.gitignore @@ -19,5 +19,6 @@ env/ .DS_Store Thumbs.db results.json +experiment_logs/memorylens.db *.log .streamlit/secrets.toml diff --git a/CHANGELOG.md b/CHANGELOG.md index c30d55f..9a1318d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,18 @@ Format follows [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] +### Added + +- SQLite persistent storage (`utils/storage.py`) — queryable database replacing flat JSON/CSV logs +- Migration script (`utils/migrate_legacy_logs.py`) — one-shot import of existing JSON logs into SQLite +- `Storage.compare_runs()` — cross-run recall comparison API +- `log_run()` now writes to SQLite alongside existing JSON/CSV output (backward compatible) +- `list_runs()` queries SQLite first, falls back to filesystem scan + +### Fixed + +- `_append_csv_summary` now properly filters `has_llm_eval` from display_data (pre-existing bug where `has_llm_eval: True` caused `TypeError` when iterating display_data) + ### Added — Research-Grade Fixes (`feat/research-grade-fixes`) **Fix 1 — Multi-seed statistical validation** diff --git a/evaluation/logger.py b/evaluation/logger.py index 6d98307..e23d8db 100644 --- a/evaluation/logger.py +++ b/evaluation/logger.py @@ -9,6 +9,8 @@ from datetime import datetime from typing import Any, Dict, Optional +from utils.storage import Storage + LOG_DIR = os.path.join(os.path.dirname(__file__), "..", "experiment_logs") @@ -20,7 +22,7 @@ def _ensure_dir() -> str: def log_run(display_data: Dict, config: Dict[str, Any], run_id: Optional[str] = None) -> str: """ - Persist a benchmark run to disk. + Persist a benchmark run to disk (JSON + CSV + SQLite). Returns the path to the saved JSON file. """ _ensure_dir() @@ -33,6 +35,14 @@ def log_run(display_data: Dict, config: Dict[str, Any], run_id: Optional[str] = json.dump(payload, fh, indent=2) _append_csv_summary(display_data, config, run_id) + + # ── SQLite persistence (non-blocking on failure) ──────────────────────── + try: + Storage().save_run(run_id, config, display_data) + except Exception as exc: + import warnings + warnings.warn(f"SQLite write failed for run {run_id}: {exc}") + return json_path @@ -41,7 +51,7 @@ def _append_csv_summary(display_data: Dict, config: Dict, run_id: str) -> None: file_exists = os.path.exists(csv_path) checkpoints = display_data.get("checkpoints", []) - backends = [k for k in display_data if k != "checkpoints"] + backends = [k for k in display_data if k not in ("checkpoints", "has_llm_eval")] rows = [] for backend in backends: @@ -67,7 +77,24 @@ def _append_csv_summary(display_data: Dict, config: Dict, run_id: str) -> None: def list_runs() -> list: - """Return metadata for all logged runs, newest first.""" + """Return metadata for all logged runs, newest first. + + Queries SQLite database first; falls back to filesystem scan + when the database doesn't exist yet. + + Returns a list of dicts, each with ``run_id``, ``timestamp``, + and ``config`` keys (unified schema for both storage backends). + """ + # ── Try SQLite first ──────────────────────────────────────────────────── + try: + store = Storage() + runs = store.list_runs(limit=50) + if runs: + return runs + except Exception: + pass + + # ── Fallback: scan filesystem (legacy) ───────────────────────────────── log_dir = _ensure_dir() runs = [] for fname in sorted(os.listdir(log_dir), reverse=True): @@ -77,8 +104,9 @@ def list_runs() -> list: data = json.load(fh) runs.append({ "run_id": data.get("run_id"), + "timestamp": datetime.fromtimestamp(os.path.getmtime(fpath)) + .strftime("%Y-%m-%dT%H:%M:%SZ"), "config": data.get("config", {}), - "path": fpath, }) return runs diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c17c08a..b98fe2c 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -287,6 +287,175 @@ def test_persona_pool_structure(): print(f"PASS: persona pool structure ({len(PERSONA_POOL)} personas, {len(expected_keys)} keys each)") +# ── SQLite Storage tests ────────────────────────────────────────────────────── + +def test_storage_save_and_get_run(): + from utils.storage import Storage + import tempfile, os + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + try: + store = Storage(db_path) + display = { + "checkpoints": [10, 25], + "naive": { + "recall": [1.0, 0.8], "precision": [0.9, 0.7], + "drift": [0.0, 0.1], "noise": [0.5, 0.8], + "tokens": [100, 500], + } + } + store.save_run("test_run", {"total_turns": 25, "backends": ["naive"]}, display) + loaded = store.get_run("test_run") + assert loaded is not None + assert loaded["checkpoints"] == [10, 25] + assert loaded["naive"]["recall"] == [1.0, 0.8] + assert loaded["naive"]["tokens"] == [100, 500] + finally: + store.close() + os.unlink(db_path) + print("PASS: storage save and get run") + + +def test_storage_list_runs(): + from utils.storage import Storage + import tempfile, os + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + try: + store = Storage(db_path) + display = {"checkpoints": [10], "naive": {"recall": [0.5], "precision": [0.5], + "drift": [0], "noise": [0], "tokens": [100]}} + store.save_run("run_b", {"total_turns": 10}, display) + store.save_run("run_a", {"total_turns": 10}, display) + runs = store.list_runs(limit=10) + assert len(runs) >= 2 + ids = [r["run_id"] for r in runs] + assert "run_a" in ids and "run_b" in ids, f"Missing runs in {ids}" + finally: + store.close() + os.unlink(db_path) + print("PASS: storage list runs") + + +def test_storage_compare_runs(): + from utils.storage import Storage + import tempfile, os + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + try: + store = Storage(db_path) + display = {"checkpoints": [10], "naive": {"recall": [0.8], "precision": [0.8], + "drift": [0], "noise": [0], "tokens": [100]}} + store.save_run("run_a", {}, display) + display["naive"]["recall"] = [0.9] + store.save_run("run_b", {}, display) + comp = store.compare_runs("run_a", "run_b") + assert comp["run_a"]["run_id"] == "run_a" + assert comp["run_b"]["run_id"] == "run_b" + assert comp["run_a"]["backends"]["naive"] == [0.8] + assert comp["run_b"]["backends"]["naive"] == [0.9] + finally: + store.close() + os.unlink(db_path) + print("PASS: storage compare runs") + + +def test_storage_get_run_not_found(): + from utils.storage import Storage + import tempfile, os + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + try: + store = Storage(db_path) + assert store.get_run("nonexistent") is None + finally: + store.close() + os.unlink(db_path) + print("PASS: storage get_run returns None for missing run") + + +# ── Logger + SQLite integration tests ──────────────────────────────────────── + +def _clean_csv_row(run_id: str) -> None: + """Remove a test run_id from runs_summary.csv to avoid accumulation.""" + import csv + csv_path = os.path.join( + os.path.dirname(__file__), "..", "experiment_logs", "runs_summary.csv" + ) + if not os.path.exists(csv_path): + return + rows = [] + with open(csv_path, newline="") as fh: + reader = csv.DictReader(fh) + for row in reader: + if row.get("run_id") != run_id: + rows.append(row) + if rows: + with open(csv_path, "w", newline="") as fh: + writer = csv.DictWriter(fh, fieldnames=rows[0].keys()) + writer.writeheader() + writer.writerows(rows) + else: + os.unlink(csv_path) + + +def test_logger_writes_sqlite(): + """log_run must write to SQLite, not just JSON.""" + from evaluation.logger import log_run + from utils.storage import Storage + import os + + display = { + "checkpoints": [10], + "naive": {"recall": [0.5], "precision": [0.5], + "drift": [0], "noise": [0], "tokens": [100]}, + } + config = {"total_turns": 10, "backends": ["naive"]} + + run_id = "_test_sqlite_logger" + json_path = log_run(display, config, run_id=run_id) + assert os.path.exists(json_path), "JSON file must exist (backward compat)" + + # Verify SQLite has the data + store = Storage() + loaded = store.get_run(run_id) + assert loaded is not None, "SQLite must contain the run" + assert loaded["naive"]["recall"] == [0.5] + + # Cleanup test artifacts (JSON, SQLite, CSV) + os.unlink(json_path) + store.conn.execute("DELETE FROM results WHERE run_id = ?", (run_id,)) + store.conn.execute("DELETE FROM runs WHERE run_id = ?", (run_id,)) + store.conn.commit() + _clean_csv_row(run_id) + print("PASS: logger writes to SQLite") + + +def test_list_runs_returns_sqlite_runs(): + """list_runs must return SQLite-backed runs, not just filesystem scans.""" + from evaluation.logger import list_runs + from utils.storage import Storage + + store = Storage() + display = {"checkpoints": [10], "naive": {"recall": [0.6], "precision": [0.6], + "drift": [0], "noise": [0], "tokens": [100]}} + store.save_run("_test_list_runs", {"total_turns": 10}, display) + + runs = list_runs() + ids = [r["run_id"] for r in runs] + assert "_test_list_runs" in ids, "list_runs must include SQLite runs" + + # Cleanup + store.conn.execute("DELETE FROM results WHERE run_id = ?", ("_test_list_runs",)) + store.conn.execute("DELETE FROM runs WHERE run_id = ?", ("_test_list_runs",)) + store.conn.commit() + print("PASS: list_runs returns SQLite runs") + + if __name__ == "__main__": tests = [ test_conversation_generator, @@ -316,6 +485,14 @@ def test_persona_pool_structure(): # Stats / multi-seed test_stats_aggregate_metric, test_persona_pool_structure, + # SQLite Storage + test_storage_save_and_get_run, + test_storage_list_runs, + test_storage_compare_runs, + test_storage_get_run_not_found, + # Logger + SQLite integration + test_logger_writes_sqlite, + test_list_runs_returns_sqlite_runs, ] failed = 0 for t in tests: diff --git a/utils/migrate_legacy_logs.py b/utils/migrate_legacy_logs.py new file mode 100644 index 0000000..f033f0f --- /dev/null +++ b/utils/migrate_legacy_logs.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +""" +One-shot migration: import existing JSON log files into SQLite. + +Usage: + python utils/migrate_legacy_logs.py + +Scans experiment_logs/*.json, parses each file, and inserts +into the SQLite database at experiment_logs/memorylens.db. + +Safe to run multiple times — skips already-migrated run_ids. +""" + +import json +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from utils.storage import Storage + + +LOG_DIR = os.path.join(os.path.dirname(__file__), "..", "experiment_logs") + + +def migrate() -> int: + """Import legacy JSON logs into SQLite. Returns count of migrated runs.""" + store = Storage() + count = 0 + + if not os.path.isdir(LOG_DIR): + print(f"Log directory not found: {LOG_DIR}") + return 0 + + existing_ids = { + row["run_id"] + for row in store.conn.execute("SELECT run_id FROM runs").fetchall() + } + + for fname in sorted(os.listdir(LOG_DIR)): + if not fname.endswith(".json") or fname == "runs_summary.csv": + continue + + fpath = os.path.join(LOG_DIR, fname) + with open(fpath) as fh: + try: + data = json.load(fh) + except (json.JSONDecodeError, OSError) as exc: + print(f" SKIP {fname}: {exc}") + continue + + run_id = data.get("run_id") + if not run_id: + print(f" SKIP {fname}: no run_id") + continue + + if run_id in existing_ids: + print(f" SKIP {fname}: already migrated") + continue + + config = data.get("config", {}) + results = data.get("results", data.get("display_data", {})) + + store.save_run(run_id, config, results) + print(f" OK {fname} -> run_id={run_id}") + count += 1 + + return count + + +if __name__ == "__main__": + print("Migrating legacy JSON logs to SQLite...") + total = migrate() + print(f"Done. {total} run(s) migrated.") + sys.exit(0 if total >= 0 else 1) diff --git a/utils/storage.py b/utils/storage.py new file mode 100644 index 0000000..569a0c9 --- /dev/null +++ b/utils/storage.py @@ -0,0 +1,214 @@ +""" +SQLite-backed persistent storage for benchmark runs. + +Replaces flat JSON/CSV file I/O with a queryable relational store. +Python stdlib only — zero additional dependencies. + +Schema +------ +runs: run_id TEXT PK, timestamp TEXT, config_json TEXT +results: id INTEGER PK AUTOINCREMENT, run_id TEXT FK, backend TEXT, + turn INTEGER, metric TEXT, value REAL +""" + +import json +import os +import sqlite3 +from typing import Any, Dict, List, Optional + + +_LOG_DIR = os.path.join(os.path.dirname(__file__), "..", "experiment_logs") + + +class Storage: + """SQLite-backed storage for MemoryLens benchmark runs.""" + + def __init__(self, db_path: Optional[str] = None): + if db_path is None: + os.makedirs(_LOG_DIR, exist_ok=True) + db_path = os.path.join(_LOG_DIR, "memorylens.db") + self._db_path = db_path + self._conn: Optional[sqlite3.Connection] = None + + # ── Connection management ────────────────────────────────────────────── + + @property + def conn(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = sqlite3.connect(self._db_path) + self._conn.row_factory = sqlite3.Row + self._ensure_schema() + return self._conn + + def close(self) -> None: + if self._conn is not None: + self._conn.close() + self._conn = None + + def _ensure_schema(self) -> None: + self._conn.execute(""" + CREATE TABLE IF NOT EXISTS runs ( + run_id TEXT PRIMARY KEY, + timestamp TEXT NOT NULL, + config_json TEXT NOT NULL + ) + """) + self._conn.execute(""" + CREATE TABLE IF NOT EXISTS results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT NOT NULL REFERENCES runs(run_id), + backend TEXT NOT NULL, + turn INTEGER NOT NULL, + metric TEXT NOT NULL, + value REAL + ) + """) + self._conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_results_run_id ON results(run_id) + """) + self._conn.commit() + + # ── Write ────────────────────────────────────────────────────────────── + + def save_run(self, run_id: str, config: Dict[str, Any], display_data: Dict[str, Any]) -> None: + """Insert a benchmark run into the database.""" + # Capture per-backend metadata (provider, decay) into config so + # get_run() can faithfully reconstruct the original dict. + config = dict(config) + backend_meta = {} + for k in display_data: + if k in ("checkpoints", "has_llm_eval"): + continue + d = display_data[k] + meta = {} + if "provider" in d: + meta["provider"] = d["provider"] + if "decay" in d: + meta["decay"] = d["decay"] + if meta: + backend_meta[k] = meta + if backend_meta: + config["_backend_meta"] = backend_meta + + with self.conn: + self.conn.execute( + "INSERT OR REPLACE INTO runs (run_id, timestamp, config_json) VALUES (?, ?, ?)", + (run_id, _now_iso(), json.dumps(config)), + ) + + # Unfold display_data into metric rows + checkpoints: List[int] = display_data.get("checkpoints", []) + backends = [k for k in display_data if k != "checkpoints" and k != "has_llm_eval"] + + rows: List[tuple] = [] + metric_keys = ["recall", "precision", "drift", "noise", "tokens", + "cascade_eff", "llm_recall", "llm_drift"] + for backend in backends: + backend_data = display_data[backend] + for i, cp in enumerate(checkpoints): + for metric in metric_keys: + values = backend_data.get(metric, []) + if i < len(values): + v = values[i] + if v is not None: + rows.append((run_id, backend, cp, metric, float(v))) + + self.conn.executemany( + "INSERT INTO results (run_id, backend, turn, metric, value) VALUES (?, ?, ?, ?, ?)", + rows, + ) + + # ── Read ─────────────────────────────────────────────────────────────── + + def get_run(self, run_id: str) -> Optional[Dict[str, Any]]: + """Return the full display_data dict for a run, or None if not found.""" + cur = self.conn.execute( + "SELECT config_json FROM runs WHERE run_id = ?", (run_id,) + ) + row = cur.fetchone() + if row is None: + return None + + config = json.loads(row["config_json"]) + + cur = self.conn.execute( + "SELECT backend, turn, metric, value FROM results WHERE run_id = ? ORDER BY turn", + (run_id,) + ) + rows = cur.fetchall() + + # Group by backend, then by turn + backends: Dict[str, Dict[int, Dict[str, float]]] = {} + checkpoints: set = set() + for r in rows: + backend, turn, metric, value = r["backend"], r["turn"], r["metric"], r["value"] + checkpoints.add(turn) + backends.setdefault(backend, {}).setdefault(turn, {})[metric] = value + + cps = sorted(checkpoints) + display: Dict[str, Any] = {"checkpoints": cps, "has_llm_eval": False} + + metric_keys = ["recall", "precision", "drift", "noise", "tokens", + "cascade_eff", "llm_recall", "llm_drift"] + + backend_meta = config.get("_backend_meta", {}) + + for backend, turn_map in backends.items(): + display[backend] = {} + for metric in metric_keys: + display[backend][metric] = [ + turn_map[t].get(metric) if metric in turn_map.get(t, {}) else None + for t in cps + ] + if any( + v is not None + for m in ["llm_recall", "llm_drift"] + for v in display[backend].get(m, []) + ): + display["has_llm_eval"] = True + + # Restore scalar metadata fields (provider, decay) + meta = backend_meta.get(backend, {}) + if "provider" in meta: + display[backend]["provider"] = meta["provider"] + if "decay" in meta: + display[backend]["decay"] = meta["decay"] + + return display + + def list_runs(self, limit: int = 50) -> List[Dict[str, Any]]: + """Return run metadata, newest first.""" + cur = self.conn.execute( + "SELECT run_id, timestamp, config_json FROM runs ORDER BY timestamp DESC LIMIT ?", + (limit,) + ) + runs = [] + for row in cur: + runs.append({ + "run_id": row["run_id"], + "timestamp": row["timestamp"], + "config": json.loads(row["config_json"]), + }) + return runs + + def compare_runs(self, run_id_a: str, run_id_b: str) -> Dict[str, Any]: + """Return recall arrays for two runs, keyed by backend and run_id.""" + def _extract_recall(rid: str) -> Dict[str, List[float]]: + data = self.get_run(rid) + if data is None: + return {} + return { + k: list(v["recall"]) + for k, v in data.items() + if k not in ("checkpoints", "has_llm_eval") + } + + return { + "run_a": {"run_id": run_id_a, "backends": _extract_recall(run_id_a)}, + "run_b": {"run_id": run_id_b, "backends": _extract_recall(run_id_b)}, + } + + +def _now_iso() -> str: + from datetime import datetime, timezone + return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") From 6fcc8db858fe73783bd2f0e880b113c63c668531 Mon Sep 17 00:00:00 2001 From: Sugaria0427 <1768802831@qq.com> Date: Wed, 3 Jun 2026 23:27:48 +0800 Subject: [PATCH 2/2] =?UTF-8?q?fix:=20address=20PR=20#34=20review=20?= =?UTF-8?q?=E2=80=94=203=20critical=20bugs=20+=202=20improvements?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Bug 1: DELETE stale results before INSERT in save_run() to prevent duplicate rows on repeated run_id (fixes get_run() corruption) - Bug 2: Add empty-rows guard in _append_csv_summary() to prevent IndexError - Bug 3: Close Storage() connections in log_run() and list_runs() via try/finally to prevent file handle leaks on Windows - Issue 4: Document None-handling contract in get_run() docstring; filter None from compare_runs() recall arrays - Issue 5: Close Storage() in test cleanup to match the Storage unit tests Adds test_storage_save_run_idempotent to verify fix for Bug 1. --- evaluation/logger.py | 14 +++++++++++++- tests/test_pipeline.py | 26 ++++++++++++++++++++++++++ utils/storage.py | 31 ++++++++++++++++++++++++------- 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/evaluation/logger.py b/evaluation/logger.py index e23d8db..0bed980 100644 --- a/evaluation/logger.py +++ b/evaluation/logger.py @@ -37,11 +37,16 @@ def log_run(display_data: Dict, config: Dict[str, Any], run_id: Optional[str] = _append_csv_summary(display_data, config, run_id) # ── SQLite persistence (non-blocking on failure) ──────────────────────── + store: Optional[Storage] = None try: - Storage().save_run(run_id, config, display_data) + store = Storage() + store.save_run(run_id, config, display_data) except Exception as exc: import warnings warnings.warn(f"SQLite write failed for run {run_id}: {exc}") + finally: + if store is not None: + store.close() return json_path @@ -69,6 +74,9 @@ def _append_csv_summary(display_data: Dict, config: Dict, run_id: str) -> None: "total_turns": config.get("total_turns", ""), }) + if not rows: + return + with open(csv_path, "a", newline="") as fh: writer = csv.DictWriter(fh, fieldnames=rows[0].keys()) if not file_exists: @@ -86,6 +94,7 @@ def list_runs() -> list: and ``config`` keys (unified schema for both storage backends). """ # ── Try SQLite first ──────────────────────────────────────────────────── + store: Optional[Storage] = None try: store = Storage() runs = store.list_runs(limit=50) @@ -93,6 +102,9 @@ def list_runs() -> list: return runs except Exception: pass + finally: + if store is not None: + store.close() # ── Fallback: scan filesystem (legacy) ───────────────────────────────── log_dir = _ensure_dir() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index b98fe2c..b0ae013 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -378,6 +378,29 @@ def test_storage_get_run_not_found(): print("PASS: storage get_run returns None for missing run") +def test_storage_save_run_idempotent(): + """Calling save_run twice with the same run_id must not duplicate rows.""" + from utils.storage import Storage + import tempfile, os + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + try: + store = Storage(db_path) + display = {"checkpoints": [10], "naive": {"recall": [0.8], "precision": [0.8], + "drift": [0], "noise": [0], "tokens": [100]}} + store.save_run("dup_test", {}, display) + store.save_run("dup_test", {}, display) # same run_id again + loaded = store.get_run("dup_test") + assert loaded is not None + assert len(loaded["naive"]["recall"]) == 1, "Duplicate rows detected!" + assert loaded["naive"]["recall"] == [0.8] + finally: + store.close() + os.unlink(db_path) + print("PASS: storage save_run idempotent") + + # ── Logger + SQLite integration tests ──────────────────────────────────────── def _clean_csv_row(run_id: str) -> None: @@ -431,6 +454,7 @@ def test_logger_writes_sqlite(): store.conn.execute("DELETE FROM results WHERE run_id = ?", (run_id,)) store.conn.execute("DELETE FROM runs WHERE run_id = ?", (run_id,)) store.conn.commit() + store.close() _clean_csv_row(run_id) print("PASS: logger writes to SQLite") @@ -453,6 +477,7 @@ def test_list_runs_returns_sqlite_runs(): store.conn.execute("DELETE FROM results WHERE run_id = ?", ("_test_list_runs",)) store.conn.execute("DELETE FROM runs WHERE run_id = ?", ("_test_list_runs",)) store.conn.commit() + store.close() print("PASS: list_runs returns SQLite runs") @@ -490,6 +515,7 @@ def test_list_runs_returns_sqlite_runs(): test_storage_list_runs, test_storage_compare_runs, test_storage_get_run_not_found, + test_storage_save_run_idempotent, # Logger + SQLite integration test_logger_writes_sqlite, test_list_runs_returns_sqlite_runs, diff --git a/utils/storage.py b/utils/storage.py index 569a0c9..9237c81 100644 --- a/utils/storage.py +++ b/utils/storage.py @@ -96,6 +96,11 @@ def save_run(self, run_id: str, config: Dict[str, Any], display_data: Dict[str, (run_id, _now_iso(), json.dumps(config)), ) + # Delete stale results before inserting fresh ones to prevent + # duplicate rows on repeated run_id (INSERT OR REPLACE on runs + # does not cascade to results). + self.conn.execute("DELETE FROM results WHERE run_id = ?", (run_id,)) + # Unfold display_data into metric rows checkpoints: List[int] = display_data.get("checkpoints", []) backends = [k for k in display_data if k != "checkpoints" and k != "has_llm_eval"] @@ -121,7 +126,14 @@ def save_run(self, run_id: str, config: Dict[str, Any], display_data: Dict[str, # ── Read ─────────────────────────────────────────────────────────────── def get_run(self, run_id: str) -> Optional[Dict[str, Any]]: - """Return the full display_data dict for a run, or None if not found.""" + """Return the full display_data dict for a run, or None if not found. + + Notes + ----- + Metric arrays may contain ``None`` for checkpoints where the metric + was not collected (e.g. ``llm_recall`` when LLM eval was off). + Callers should handle ``None`` before arithmetic. + """ cur = self.conn.execute( "SELECT config_json FROM runs WHERE run_id = ?", (run_id,) ) @@ -192,16 +204,21 @@ def list_runs(self, limit: int = 50) -> List[Dict[str, Any]]: return runs def compare_runs(self, run_id_a: str, run_id_b: str) -> Dict[str, Any]: - """Return recall arrays for two runs, keyed by backend and run_id.""" + """Return recall arrays for two runs, keyed by backend and run_id. + + ``None`` values (missing checkpoints) are filtered out so + downstream arithmetic / plotting consumers receive clean arrays. + """ def _extract_recall(rid: str) -> Dict[str, List[float]]: data = self.get_run(rid) if data is None: return {} - return { - k: list(v["recall"]) - for k, v in data.items() - if k not in ("checkpoints", "has_llm_eval") - } + result: Dict[str, List[float]] = {} + for k, v in data.items(): + if k in ("checkpoints", "has_llm_eval"): + continue + result[k] = [x for x in v["recall"] if x is not None] + return result return { "run_a": {"run_id": run_id_a, "backends": _extract_recall(run_id_a)},