
Commit e8675b9: improve code

1 parent 866c727

7 files changed: 647 additions & 13 deletions


src/understudy/check.py

Lines changed: 8 additions & 5 deletions
@@ -160,19 +160,20 @@ def evaluate(
     judge_model: str | None = None,
     judges: dict[str, Any] | None = None,
 ) -> CheckResult:
-    """Evaluate a trace against expectations.
+    """Evaluate a trace against expectations with optional judge evaluations.
 
-    This is an alias for check() with clearer naming for the evaluate workflow.
+    Extends check() with support for overriding metrics and adding LLM judge
+    evaluations.
 
     Args:
         trace: The execution trace to evaluate.
         expectations: The expectations to check against.
         metrics: Override metrics to compute (defaults to scene's metrics).
         judge_model: Optional LLM model for judge evaluations.
-        judges: Optional pre-configured judge objects.
+        judges: Optional pre-configured judge objects (dict of name -> Judge).
 
     Returns:
-        A CheckResult with check outcomes and metrics.
+        A CheckResult with check outcomes, metrics, and judge results.
     """
     if metrics:
         expectations = Expectations(

@@ -271,10 +272,12 @@ def evaluate_single(trace_id: str, trace: Trace, exp: Expectations) -> EvaluationResult:
             result_storage.save(trace_id=trace_id, check_result=check_result)
             return EvaluationResult(trace_id=trace_id, check_result=check_result)
         except Exception as e:
+            import traceback
+
             return EvaluationResult(
                 trace_id=trace_id,
                 check_result=CheckResult(),
-                error=str(e),
+                error=f"{type(e).__name__}: {e}\n{traceback.format_exc()}",
             )
 
     if parallel <= 1:
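Note on the error-string change: the previous `str(e)` often loses the exception type and all location information, which made batch failures hard to debug. A standalone sketch (not this repo's code) of what each form surfaces:

import traceback

def risky() -> None:
    {}["missing"]  # raises KeyError somewhere down the call stack

try:
    risky()
except Exception as e:
    print(str(e))                      # 'missing'            (no type, no location)
    print(f"{type(e).__name__}: {e}")  # KeyError: 'missing'  (type restored)
    print(traceback.format_exc())      # full stack trace, including risky()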

src/understudy/langgraph/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -137,8 +137,8 @@ def send(self, message: str) -> AgentResponse:
                 state_values = graph_state.values
                 if isinstance(state_values, dict):
                     state_snapshot = {k: v for k, v in state_values.items() if k != "messages"}
-        except Exception:
-            pass
+        except Exception as e:
+            logger.debug("Failed to capture graph state: %s", e)
 
         return AgentResponse(
             content=final_content,
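Replacing the bare `pass` keeps the graceful degradation but leaves a breadcrumb when state capture fails. A minimal standalone sketch of the pattern; `capture_state` is a hypothetical stand-in for the `graph_state.values` access, and the new line implies a module-level `logger` already exists in this file:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

def capture_state() -> dict:
    # Hypothetical helper standing in for reading graph_state.values.
    raise RuntimeError("graph state unavailable")

state_snapshot: dict = {}
try:
    state_snapshot = capture_state()
except Exception as e:
    # Degrade gracefully, but record why the snapshot is empty
    # instead of silently swallowing the failure.
    logger.debug("Failed to capture graph state: %s", e)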

src/understudy/metrics/builtins.py

Lines changed: 2 additions & 4 deletions
@@ -1,11 +1,9 @@
 """Built-in metrics: standard evaluation metrics."""
 
+from ..models import Expectations
+from ..trace import Trace
 from .registry import MetricRegistry, MetricResult
 
-if True:
-    from ..models import Expectations
-    from ..trace import Trace
-
 
 @MetricRegistry.register("efficiency", description="Token and latency efficiency metrics")
 def compute_efficiency(trace: "Trace", expectations: "Expectations") -> MetricResult:
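The deleted `if True:` block was a no-op wrapper (plausibly a leftover `TYPE_CHECKING` guard), so hoisting the imports to module level is behavior-preserving. Had the two imports been needed only for annotations, the idiomatic guard would look like this sketch (illustrative package-module code, not the repo's):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, which is the
    # usual way to break import cycles between sibling modules.
    from ..models import Expectations
    from ..trace import Trace

def compute_efficiency(trace: Trace, expectations: Expectations) -> dict:
    ...

With the imports now unconditionally at module level, the quoted "Trace"/"Expectations" annotations on compute_efficiency are redundant and could be unquoted in a follow-up.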

src/understudy/reports.py

Lines changed: 1 addition & 0 deletions
@@ -277,4 +277,5 @@ def log_message(self, format: str, *args: Any) -> None:
         server.serve_forever()
     except KeyboardInterrupt:
         print("\nShutting down...")
+    finally:
         server.shutdown()
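Moving `server.shutdown()` into a `finally:` block fixes a subtle leak: previously it ran only on KeyboardInterrupt, so any other exception escaping serve_forever() left the server open. A standalone sketch of the corrected shape, using http.server as a stand-in (the server_close() call is this sketch's addition for releasing the socket, not part of the diff):

from http.server import HTTPServer, SimpleHTTPRequestHandler

server = HTTPServer(("127.0.0.1", 0), SimpleHTTPRequestHandler)
try:
    server.serve_forever()
except KeyboardInterrupt:
    print("\nShutting down...")
finally:
    # Runs on Ctrl-C, on any other exception, and on normal return.
    server.shutdown()
    server.server_close()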

src/understudy/storage.py

Lines changed: 2 additions & 2 deletions
@@ -322,8 +322,8 @@ def save(
         trace_file = self.path / f"{trace_id}.json"
 
         data = {
-            "trace": json.loads(trace.model_dump_json()),
-            "scene": json.loads(scene.model_dump_json()),
+            "trace": trace.model_dump(mode="json"),
+            "scene": scene.model_dump(mode="json"),
             "metadata": {
                 "trace_id": trace_id,
                 "scene_id": trace.scene_id,

tests/test_core.py

Lines changed: 265 additions & 0 deletions
@@ -1024,3 +1024,268 @@ def test_run_multiple_scenes_with_n_sims(self, tmp_path):
         results = suite.run(app, storage=storage, n_sims=2)
 
         assert len(results.results) == 4
+
+
+# --- simulate_batch Tests ---
+
+
+class TestSimulateBatch:
+    def test_simulate_batch_from_list(self, tmp_path):
+        from understudy import simulate_batch
+
+        scenes = [
+            Scene(
+                id=f"batch_scene_{i}",
+                starting_prompt="hello",
+                conversation_plan="greet",
+                persona=Persona(description="friendly"),
+            )
+            for i in range(2)
+        ]
+
+        app = MockAgentApp()
+        traces = simulate_batch(app, scenes, n_sims=1, parallel=1)
+
+        assert len(traces) == 2
+        assert all(t.scene_id.startswith("batch_scene_") for t in traces)
+
+    def test_simulate_batch_with_n_sims(self, tmp_path):
+        from understudy import simulate_batch
+
+        scene = Scene(
+            id="nsims_batch_scene",
+            starting_prompt="hello",
+            conversation_plan="greet",
+            persona=Persona(description="friendly"),
+        )
+
+        app = MockAgentApp()
+        traces = simulate_batch(app, [scene], n_sims=3, parallel=1)
+
+        assert len(traces) == 3
+        assert all(t.scene_id == "nsims_batch_scene" for t in traces)
+
+    def test_simulate_batch_with_output(self, tmp_path):
+        from understudy import simulate_batch
+
+        scene = Scene(
+            id="output_scene",
+            starting_prompt="hello",
+            conversation_plan="greet",
+            persona=Persona(description="friendly"),
+        )
+
+        output_path = tmp_path / "traces"
+        app = MockAgentApp()
+        traces = simulate_batch(app, [scene], n_sims=2, output=output_path)
+
+        assert len(traces) == 2
+
+        storage = TraceStorage(path=output_path)
+        saved_traces = storage.list_traces()
+        assert len(saved_traces) == 2
+
+    def test_simulate_batch_with_tags(self, tmp_path):
+        from understudy import simulate_batch
+
+        scene = Scene(
+            id="tagged_scene",
+            starting_prompt="hello",
+            conversation_plan="greet",
+            persona=Persona(description="friendly"),
+        )
+
+        output_path = tmp_path / "traces"
+        app = MockAgentApp()
+        simulate_batch(app, [scene], n_sims=1, output=output_path, tags={"version": "v1"})
+
+        storage = TraceStorage(path=output_path)
+        trace_id = storage.list_traces()[0]
+        data = storage.load(trace_id)
+        assert data["metadata"]["tags"] == {"version": "v1"}
+
+    def test_simulate_batch_from_directory(self, tmp_path):
+        import yaml
+
+        from understudy import simulate_batch
+
+        scenes_dir = tmp_path / "scenes"
+        scenes_dir.mkdir()
+
+        for i in range(2):
+            scene_data = {
+                "id": f"file_scene_{i}",
+                "starting_prompt": "hello",
+                "conversation_plan": "greet",
+                "persona": "cooperative",
+            }
+            (scenes_dir / f"scene_{i}.yaml").write_text(yaml.dump(scene_data))
+
+        app = MockAgentApp()
+        traces = simulate_batch(app, scenes_dir, n_sims=1, parallel=1)
+
+        assert len(traces) == 2
+
+    def test_simulate_batch_from_single_file(self, tmp_path):
+        import yaml
+
+        from understudy import simulate_batch
+
+        scene_data = {
+            "id": "single_file_scene",
+            "starting_prompt": "hello",
+            "conversation_plan": "greet",
+            "persona": "cooperative",
+        }
+        scene_file = tmp_path / "scene.yaml"
+        scene_file.write_text(yaml.dump(scene_data))
+
+        app = MockAgentApp()
+        traces = simulate_batch(app, scene_file, n_sims=1)
+
+        assert len(traces) == 1
+        assert traces[0].scene_id == "single_file_scene"
+
+    def test_simulate_batch_parallel(self, tmp_path):
+        from understudy import simulate_batch
+
+        scenes = [
+            Scene(
+                id=f"parallel_scene_{i}",
+                starting_prompt="hello",
+                conversation_plan="greet",
+                persona=Persona(description="friendly"),
+            )
+            for i in range(3)
+        ]
+
+        app = MockAgentApp()
+        traces = simulate_batch(app, scenes, n_sims=1, parallel=2)
+
+        assert len(traces) == 3
+
+    def test_simulate_batch_with_mocks(self, tmp_path):
+        from understudy import MockToolkit, simulate_batch
+
+        scene = Scene(
+            id="mock_scene",
+            starting_prompt="hello",
+            conversation_plan="greet",
+            persona=Persona(description="friendly"),
+        )
+
+        mocks = MockToolkit()
+
+        @mocks.handle("test_tool")
+        def test_tool():
+            return "mocked result"
+
+        app = MockAgentApp()
+        traces = simulate_batch(app, [scene], n_sims=1, mocks=mocks)
+
+        assert len(traces) == 1
+
+
+# --- evaluate_batch comprehensive tests ---
+
+
+class TestEvaluateBatchComprehensive:
+    def test_evaluate_batch_parallel(self, tmp_path):
+        traces = [
+            Trace(
+                scene_id=f"parallel_eval_{i}",
+                turns=[Turn(role="agent", content="done")],
+                terminal_state="completed",
+            )
+            for i in range(4)
+        ]
+        expectations = Expectations(expected_resolution="completed")
+
+        results = evaluate_batch(traces, expectations=expectations, parallel=2)
+
+        assert len(results) == 4
+        assert all(r.passed for r in results)
+
+    def test_evaluate_batch_with_output(self, tmp_path):
+        traces = [
+            Trace(
+                scene_id="output_eval",
+                turns=[Turn(role="agent", content="done")],
+                terminal_state="completed",
+            )
+        ]
+        expectations = Expectations(expected_resolution="completed")
+        output_path = tmp_path / "results"
+
+        results = evaluate_batch(traces, expectations=expectations, output=output_path)
+
+        assert len(results) == 1
+        result_storage = EvaluationStorage(path=output_path)
+        assert len(result_storage.list_results()) == 1
+
+    def test_evaluate_batch_with_metrics(self, tmp_path):
+        from understudy.trace import TraceMetrics, TurnMetrics
+
+        traces = [
+            Trace(
+                scene_id="metrics_eval",
+                turns=[Turn(role="agent", content="done")],
+                terminal_state="completed",
+                metrics=TraceMetrics(
+                    turns=[TurnMetrics(input_tokens=100, output_tokens=50, latency_ms=500)]
+                ),
+            )
+        ]
+        expectations = Expectations()
+
+        results = evaluate_batch(traces, expectations=expectations, metrics=["efficiency"])
+
+        assert len(results) == 1
+        assert "efficiency" in results[0].check_result.metrics
+
+    def test_evaluate_batch_handles_exceptions(self, tmp_path):
+        class BadTrace:
+            scene_id = "bad"
+
+        traces = [BadTrace()]  # type: ignore
+        expectations = Expectations()
+
+        results = evaluate_batch(traces, expectations=expectations)  # type: ignore
+
+        assert len(results) == 1
+        assert results[0].error is not None
+        assert "Traceback" in results[0].error
+
+    def test_evaluate_batch_mixed_results(self, tmp_path):
+        traces = [
+            Trace(
+                scene_id="pass",
+                turns=[Turn(role="agent", content="done")],
+                terminal_state="completed",
+            ),
+            Trace(
+                scene_id="fail",
+                turns=[Turn(role="agent", content="fail")],
+                terminal_state="failed",
+            ),
+        ]
+        expectations = Expectations(expected_resolution="completed")
+
+        results = evaluate_batch(traces, expectations=expectations)
+
+        assert len(results) == 2
+        passed = sum(1 for r in results if r.passed)
+        assert passed == 1
+
+    def test_evaluate_batch_without_expectations(self, tmp_path):
+        traces = [
+            Trace(
+                scene_id="no_exp",
+                turns=[Turn(role="agent", content="done")],
+            )
+        ]
+
+        results = evaluate_batch(traces)
+
+        assert len(results) == 1
+        assert results[0].passed
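Read together, these tests pin down the batch API surface. A hedged reconstruction of the simulate_batch signature they imply (an inference from the tests above, not the repo's source):

from pathlib import Path
from typing import Any

def simulate_batch(
    app: Any,                    # agent app under test
    scenes: Any,                 # list of Scene objects, or a Path pointing to
                                 # a YAML file or a directory of YAML files
    *,
    n_sims: int = 1,             # simulations per scene
    parallel: int = 1,           # worker count for concurrent simulation
    output: Path | None = None,  # if set, traces are persisted via TraceStorage
    tags: dict | None = None,    # stamped into saved trace metadata
    mocks: Any = None,           # optional MockToolkit intercepting tool calls
) -> list:                       # flat list of traces, len == len(scenes) * n_sims
    ...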
