-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathablation.py
More file actions
98 lines (82 loc) · 3.18 KB
/
ablation.py
File metadata and controls
98 lines (82 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Reward component ablation (no LLM, no API keys).
Runs the deterministic fallback action per task and recomputes the total
score with GradeMask variants to show how much each component contributes.
Usage:
python scripts/ablation.py
python scripts/ablation.py --quick # single task (CI-friendly)
"""
from __future__ import annotations
import argparse
import os
import sys
from collections import defaultdict
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
from baseline_runner import FALLBACK_SOLUTIONS, TASK_IDS # noqa: E402
from graders import GradeMask, grade # noqa: E402
from models import Action # noqa: E402
from tasks import TASKS # noqa: E402
# Ablation variants: "full" enables every reward component; each "no_*"
# entry disables one component (or, for no_duckdb_signal, a group of two).
# Declared as a plain name -> kwargs table, then materialized into GradeMask
# instances; dict insertion order is preserved, so iteration order matches.
_VARIANT_KWARGS: dict[str, dict[str, bool]] = {
    "full": {},
    "no_execution_speedup": {"execution_speedup": False},
    "no_result_correctness": {"result_correctness": False},
    "no_duckdb_signal": {"execution_speedup": False, "result_correctness": False},
    "no_issue_detection": {"issue_detection": False},
    "no_approval": {"approval_correctness": False},
    "no_summary": {"summary_quality": False},
    "no_severity": {"severity_labels": False},
}
VARIANTS: dict[str, GradeMask] = {
    name: GradeMask(**kwargs) for name, kwargs in _VARIANT_KWARGS.items()
}
def _build_action(sol: dict) -> Action:
    """Build an Action from one FALLBACK_SOLUTIONS entry.

    Extracted because the original constructed the identical Action twice
    (once in the per-task loop and once in the aggregation loop).
    """
    return Action(
        suggestions=sol["suggestions"],
        optimized_query=sol["optimized_query"],
        summary=sol["summary"],
        estimated_improvement=sol["estimated_improvement"],
        approved=sol["approved"],
    )


def main() -> None:
    """Run the reward-component ablation over deterministic fallback actions.

    For each selected task, grade the fallback action once under every
    GradeMask variant, print the per-task score deltas versus the full
    score, then print the mean score per variant across all tasks.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--quick",
        action="store_true",
        help="Only task_1 (faster for CI)",
    )
    args = ap.parse_args()
    task_ids = ["task_1_basic_antipatterns"] if args.quick else list(TASK_IDS)

    print("SQL-optim-env — reward component ablation (fallback actions)\n")

    # Grade each task exactly once per variant. The original graded every
    # task twice (per-task report loop + separate aggregation loop); here the
    # same scores feed both the report and the means.
    acc: dict[str, list[float]] = defaultdict(list)
    for task_id in task_ids:
        td = TASKS[task_id]
        action = _build_action(FALLBACK_SOLUTIONS[task_id])
        scores = {
            name: grade(td, action, mask=mask).score
            for name, mask in VARIANTS.items()
        }
        # NOTE(review): the per-task "full" score previously used mask=None
        # while the aggregate used VARIANTS["full"] (GradeMask()); these are
        # assumed equivalent — the original mean/delta comparison already
        # relied on that. Confirm against graders.grade's default.
        full = scores["full"]
        print(f"=== {task_id} ({td['difficulty']}) — full score {full:.4f} ===")
        for name, s in scores.items():
            if name == "full":
                continue
            print(f" {name:24s} score={s:.4f} (Δ {s - full:+.4f})")
        print()
        for name, s in scores.items():
            acc[name].append(s)

    print("--- Mean score across all tasks ---")
    full_mean = sum(acc["full"]) / len(acc["full"])
    for name in VARIANTS:
        mean_v = sum(acc[name]) / len(acc[name])
        if name == "full":
            print(f" {name:24s} {mean_v:.4f}")
        else:
            print(f" {name:24s} {mean_v:.4f} (Δ {mean_v - full_mean:+.4f} vs full)")
# Entry-point guard: run the ablation only when executed as a script,
# so importing this module has no side effects.
if __name__ == "__main__":
    main()