-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenv.py
More file actions
156 lines (136 loc) · 5.4 KB
/
env.py
File metadata and controls
156 lines (136 loc) · 5.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
env.py — SQLOptimEnv: Core OpenEnv Environment Class
"""
from typing import Any, Dict, Optional
from executor import get_executor
from graders import grade
from leaderboard import record as lb_record
from models import (
Action,
EnvironmentState,
Observation,
Reward,
StepResult,
)
from tasks import TASKS
class SQLOptimEnv:
    """
    OpenEnv-compliant environment for SQL query optimization.

    Each episode presents a SQL query plus schema context. The agent
    replies with an Action carrying a list of optimization suggestions
    and a rewritten ``optimized_query``. Grading executes both the
    original and rewritten queries against real DuckDB data, measuring
    the actual speedup and checking result correctness — all of which
    feeds the reward function.

    Multi-step behaviour:
      * ``issues_found_so_far`` accumulates flagged issue types across steps.
      * ``last_execution`` hands execution metrics back to the agent so the
        ``optimized_query`` can be refined on subsequent steps.
    """

    def __init__(self) -> None:
        """Create an environment with no active episode."""
        self._task_data: Optional[Dict[str, Any]] = None    # spec of the active task
        self._step_count: int = 0
        self._done: bool = False
        self._cumulative_reward: float = 0.0
        self._issues_found: list = []                       # first-seen order, deduped
        self._last_execution: Optional[Dict[str, Any]] = None

    # ── OpenEnv interface ─────────────────────────────────────────────

    def reset(
        self, task_id: str = "task_1_basic_antipatterns"
    ) -> Observation:
        """Start a fresh episode for *task_id* and return its first observation.

        Raises:
            ValueError: if *task_id* is not a known task.
        """
        if task_id not in TASKS:
            raise ValueError(
                f"Unknown task_id '{task_id}'. "
                f"Valid: {list(TASKS.keys())}"
            )
        self._task_data = TASKS[task_id]
        self._step_count = 0
        self._done = False
        self._cumulative_reward = 0.0
        self._issues_found = []
        self._last_execution = None
        return self._make_obs()

    def step(self, action: Action) -> StepResult:
        """Grade one optimization attempt and advance the episode.

        Raises:
            RuntimeError: if called before ``reset()`` or after the
                episode has finished.
        """
        if self._task_data is None:
            raise RuntimeError("No active episode — call reset() first.")
        if self._done:
            raise RuntimeError("Episode finished — call reset() to start a new one.")

        self._step_count += 1

        # Grade the action (runs DuckDB internally).
        reward: Reward = grade(self._task_data, action)
        self._cumulative_reward += reward.score

        # Refresh execution metrics and the running issue list for the
        # next observation.
        self._refresh_execution(action)
        self._track_issues(action)

        # Episode ends on step budget exhaustion or a near-perfect score.
        finished = (
            self._step_count >= self._task_data["max_steps"]
            or reward.score >= 0.95
        )
        self._done = finished

        # Update leaderboard (defaults apply when no execution data exists).
        exec_info = self._last_execution or {}
        lb_record(
            task_id=self._task_data["task_id"],
            speedup=exec_info.get("speedup", 1.0),
            score=reward.score,
            results_match=exec_info.get("results_match", False),
            steps=self._step_count,
        )

        return StepResult(
            observation=self._make_obs(),
            reward=reward,
            done=finished,
            info={
                "step": self._step_count,
                "cumulative_reward": round(self._cumulative_reward, 4),
                "issues_found": len(self._issues_found),
                "execution": self._last_execution,
            },
        )

    def state(self) -> EnvironmentState:
        """Snapshot the current episode state; safe to call at any time."""
        if self._task_data is None:
            # No reset() yet — report an inert, already-done placeholder.
            return EnvironmentState(
                task_id="none", step_count=0, max_steps=0,
                episode_done=True, cumulative_reward=0.0,
                current_task="No active episode",
            )
        return EnvironmentState(
            task_id=self._task_data["task_id"],
            step_count=self._step_count,
            max_steps=self._task_data["max_steps"],
            episode_done=self._done,
            cumulative_reward=round(self._cumulative_reward, 4),
            current_task=self._task_data["task_name"],
        )

    # ── Internal ──────────────────────────────────────────────────────

    def _refresh_execution(self, action: Action) -> None:
        """Re-compare original vs. rewritten query for the next observation.

        Keeps the previous metrics when the action supplies no rewritten
        query; resets to None if the comparison itself fails.
        """
        candidate = (action.optimized_query or "").strip()
        if not candidate:
            return
        try:
            self._last_execution = get_executor().compare(
                self._task_data["sql_query"], candidate
            )
        except Exception:
            # Best-effort: execution metrics are advisory, never fatal.
            self._last_execution = None

    def _track_issues(self, action: Action) -> None:
        """Accumulate newly flagged issue types, preserving first-seen order."""
        for suggestion in action.suggestions:
            issue = suggestion.get("issue_type", "")
            if issue and issue not in self._issues_found:
                self._issues_found.append(issue)

    def _make_obs(self) -> Observation:
        """Build the observation for the current task and step."""
        task = self._task_data
        return Observation(
            task_id=task["task_id"],
            task_name=task["task_name"],
            task_description=task["task_description"],
            sql_query=task["sql_query"],
            schema_info=task["schema_info"],
            dialect=task.get("dialect", "duckdb/postgresql"),
            difficulty=task["difficulty"],
            step_count=self._step_count,
            max_steps=task["max_steps"],
            issues_found_so_far=list(self._issues_found),
            last_execution=self._last_execution,
        )