Skip to content

Commit 46b3925

Browse files
committed
assert error if evaluation_result not set
1 parent f409213 commit 46b3925

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,13 @@ async def _collect_result(config, lst):
620620

621621
experiment_duration_seconds = time.perf_counter() - experiment_start_time
622622

623-
# for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them
623+
if not all(r.evaluation_result is not None for run_results in all_results for r in run_results):
624+
raise AssertionError(
625+
"Some EvaluationRow instances are missing evaluation_result. "
626+
"Your @evaluation_test function (or its rollout processor) must set `row.evaluation_result`"
627+
)
628+
629+
# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
624630
# rollout_id is used to differentiate the result from different completion_params
625631
if mode == "groupwise":
626632
results_by_group = [
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
from eval_protocol.models import EvaluationRow, Message
4+
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
5+
from eval_protocol.pytest.evaluation_test import evaluation_test
6+
7+
8+
@pytest.mark.asyncio
async def test_missing_evaluation_result_raises_assertion_error() -> None:
    """Check that evaluation_test raises when a returned row has no evaluation_result."""
    # A single one-turn conversation is enough to drive one row through the pipeline.
    conversations = [[Message(role="user", content="Test message")]]

    @evaluation_test(
        input_messages=[conversations],
        rollout_processor=NoOpRolloutProcessor(),
        mode="pointwise",
        num_runs=1,
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        # Deliberately return the row untouched so evaluation_result stays unset.
        return row

    # Running the decorated evaluation is expected to trip the missing-result check.
    with pytest.raises(AssertionError) as raised:
        await eval_fn(input_messages=conversations)  # pyright: ignore[reportCallIssue]

    # Both halves of the error text must be present: what went wrong and how to fix it.
    message = str(raised.value)
    expected_fragments = (
        "Some EvaluationRow instances are missing evaluation_result",
        "must set `row.evaluation_result`",
    )
    for fragment in expected_fragments:
        assert fragment in message

0 commit comments

Comments
 (0)