Skip to content

Commit 0e60f2a

Browse files
committed
code refactor + batch
1 parent c20f33c commit 0e60f2a

4 files changed

Lines changed: 432 additions & 344 deletions

File tree

src/understudy/check.py

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Check: validate a trace against scene expectations."""
22

3-
from concurrent.futures import ThreadPoolExecutor, as_completed
43
from dataclasses import dataclass, field
54
from pathlib import Path
65
from typing import Any
76

7+
from .batch import BatchExecutor
88
from .metrics import MetricRegistry, MetricResult
99
from .models import Expectations
1010
from .trace import Trace
@@ -214,6 +214,51 @@ def passed(self) -> bool:
214214
return self.error is None and self.check_result.passed
215215

216216

217+
@dataclass
218+
class _EvaluationTask:
219+
"""Internal task descriptor for batch evaluation."""
220+
221+
trace_id: str
222+
trace: Trace
223+
expectations: Expectations
224+
225+
226+
class _EvaluationExecutor(BatchExecutor[_EvaluationTask, EvaluationResult]):
227+
"""Executor for running evaluations in parallel."""
228+
229+
def __init__(
230+
self,
231+
metrics: list[str] | None,
232+
judge_model: str | None,
233+
result_storage: Any,
234+
parallel: int = 1,
235+
):
236+
super().__init__(parallel)
237+
self.metrics = metrics
238+
self.judge_model = judge_model
239+
self.result_storage = result_storage
240+
241+
def execute_one(self, item: _EvaluationTask) -> EvaluationResult:
242+
try:
243+
check_result = evaluate(
244+
trace=item.trace,
245+
expectations=item.expectations,
246+
metrics=self.metrics,
247+
judge_model=self.judge_model,
248+
)
249+
if self.result_storage:
250+
self.result_storage.save(trace_id=item.trace_id, check_result=check_result)
251+
return EvaluationResult(trace_id=item.trace_id, check_result=check_result)
252+
except Exception as e:
253+
import traceback
254+
255+
return EvaluationResult(
256+
trace_id=item.trace_id,
257+
check_result=CheckResult(),
258+
error=f"{type(e).__name__}: {e}\n{traceback.format_exc()}",
259+
)
260+
261+
217262
def evaluate_batch(
218263
traces: list[Trace] | str | Path,
219264
expectations: Expectations | None = None,
@@ -237,7 +282,7 @@ def evaluate_batch(
237282
"""
238283
from .storage import EvaluationStorage, TraceStorage
239284

240-
trace_list: list[tuple[str, Trace, Expectations]] = []
285+
tasks: list[_EvaluationTask] = []
241286

242287
if isinstance(traces, (str, Path)):
243288
path = Path(traces)
@@ -247,50 +292,22 @@ def evaluate_batch(
247292
trace = data["trace"]
248293
scene = data["scene"]
249294
exp = expectations if expectations else scene.expectations
250-
trace_list.append((trace_id, trace, exp))
295+
tasks.append(_EvaluationTask(trace_id=trace_id, trace=trace, expectations=exp))
251296
else:
252297
for i, trace in enumerate(traces):
253298
trace_id = f"trace_{i}"
254299
exp = expectations if expectations else Expectations()
255-
trace_list.append((trace_id, trace, exp))
300+
tasks.append(_EvaluationTask(trace_id=trace_id, trace=trace, expectations=exp))
256301

257302
result_storage = None
258303
if output:
259304
result_storage = EvaluationStorage(path=Path(output))
260305

261-
results: list[EvaluationResult] = []
262-
263-
def evaluate_single(trace_id: str, trace: Trace, exp: Expectations) -> EvaluationResult:
264-
try:
265-
check_result = evaluate(
266-
trace=trace,
267-
expectations=exp,
268-
metrics=metrics,
269-
judge_model=judge_model,
270-
)
271-
if result_storage:
272-
result_storage.save(trace_id=trace_id, check_result=check_result)
273-
return EvaluationResult(trace_id=trace_id, check_result=check_result)
274-
except Exception as e:
275-
import traceback
306+
executor = _EvaluationExecutor(
307+
metrics=metrics,
308+
judge_model=judge_model,
309+
result_storage=result_storage,
310+
parallel=parallel,
311+
)
276312

277-
return EvaluationResult(
278-
trace_id=trace_id,
279-
check_result=CheckResult(),
280-
error=f"{type(e).__name__}: {e}\n{traceback.format_exc()}",
281-
)
282-
283-
if parallel <= 1:
284-
for trace_id, trace, exp in trace_list:
285-
result = evaluate_single(trace_id, trace, exp)
286-
results.append(result)
287-
else:
288-
with ThreadPoolExecutor(max_workers=parallel) as executor:
289-
futures = {
290-
executor.submit(evaluate_single, trace_id, trace, exp): trace_id
291-
for trace_id, trace, exp in trace_list
292-
}
293-
for future in as_completed(futures):
294-
results.append(future.result())
295-
296-
return results
313+
return executor.run(tasks)

src/understudy/runner.py

Lines changed: 57 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import logging
44
import time
5-
from concurrent.futures import ThreadPoolExecutor, as_completed
65
from datetime import UTC, datetime
76
from pathlib import Path
87
from typing import TYPE_CHECKING, Any, Protocol
98

9+
from .batch import BatchExecutor
1010
from .mocks import MockToolkit
1111
from .models import Scene
1212
from .trace import StateSnapshot, ToolCall, Trace, Turn, TurnMetrics
@@ -253,6 +253,50 @@ def simulate(
253253
)
254254

255255

256+
class _SimulationTask:
257+
"""Internal task descriptor for batch simulation."""
258+
259+
def __init__(self, scene: Scene, sim_index: int):
260+
self.scene = scene
261+
self.sim_index = sim_index
262+
263+
264+
class _SimulationExecutor(BatchExecutor[_SimulationTask, Trace]):
265+
"""Executor for running simulations in parallel."""
266+
267+
def __init__(
268+
self,
269+
app: AgentApp,
270+
mocks: MockToolkit | None,
271+
simulator_model: str,
272+
storage: Any,
273+
tags: dict[str, str] | None,
274+
parallel: int = 1,
275+
):
276+
super().__init__(parallel)
277+
self.app = app
278+
self.mocks = mocks
279+
self.simulator_model = simulator_model
280+
self.storage = storage
281+
self.tags = tags
282+
283+
def execute_one(self, item: _SimulationTask) -> Trace:
284+
trace = simulate(
285+
app=self.app,
286+
scene=item.scene,
287+
mocks=self.mocks,
288+
simulator_model=self.simulator_model,
289+
)
290+
if self.storage:
291+
self.storage.save(
292+
trace=trace,
293+
scene=item.scene,
294+
sim_index=item.sim_index,
295+
tags=self.tags,
296+
)
297+
return trace
298+
299+
256300
def simulate_batch(
257301
app: AgentApp,
258302
scenes: list[Scene] | str | Path,
@@ -298,40 +342,17 @@ def simulate_batch(
298342

299343
storage = TraceStorage(path=Path(output))
300344

301-
sim_tasks = []
302-
for scene in scene_list:
303-
for sim_index in range(n_sims):
304-
sim_tasks.append((scene, sim_index))
305-
306-
traces: list[Trace] = []
345+
tasks = [
346+
_SimulationTask(scene, sim_index) for scene in scene_list for sim_index in range(n_sims)
347+
]
307348

308-
def run_single_sim(scene: Scene, sim_index: int) -> Trace:
309-
trace = simulate(
310-
app=app,
311-
scene=scene,
312-
mocks=mocks,
313-
simulator_model=simulator_model,
314-
)
315-
if storage:
316-
storage.save(
317-
trace=trace,
318-
scene=scene,
319-
sim_index=sim_index,
320-
tags=tags,
321-
)
322-
return trace
349+
executor = _SimulationExecutor(
350+
app=app,
351+
mocks=mocks,
352+
simulator_model=simulator_model,
353+
storage=storage,
354+
tags=tags,
355+
parallel=parallel,
356+
)
323357

324-
if parallel <= 1:
325-
for scene, sim_index in sim_tasks:
326-
trace = run_single_sim(scene, sim_index)
327-
traces.append(trace)
328-
else:
329-
with ThreadPoolExecutor(max_workers=parallel) as executor:
330-
futures = {
331-
executor.submit(run_single_sim, scene, sim_index): (scene, sim_index)
332-
for scene, sim_index in sim_tasks
333-
}
334-
for future in as_completed(futures):
335-
traces.append(future.result())
336-
337-
return traces
358+
return executor.run(tasks)

0 commit comments

Comments
 (0)