Skip to content

Commit ca8b2e8

Browse files
authored
assert for evaluation_row (#339)
* assert error if evaluation_result not set
* adjust comment
* update
* fix test
* update more tests
1 parent f409213 commit ca8b2e8

20 files changed

+97
-18
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,13 @@ async def _collect_result(config, lst):
620620

621621
experiment_duration_seconds = time.perf_counter() - experiment_start_time
622622

623-
# for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them
623+
if not all(r.evaluation_result is not None for run_results in all_results for r in run_results):
624+
raise AssertionError(
625+
"Some EvaluationRow instances are missing evaluation_result. "
626+
"Your @evaluation_test function must set `row.evaluation_result`"
627+
)
628+
629+
# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
624630
# rollout_id is used to differentiate the result from different completion_params
625631
if mode == "groupwise":
626632
results_by_group = [

tests/data_loader/test_dynamic_data_loader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from eval_protocol.data_loader import DynamicDataLoader
2-
from eval_protocol.models import EvaluationRow, Message
2+
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
33
from eval_protocol.pytest import evaluation_test
44

55

@@ -27,6 +27,7 @@ def test_dynamic_data_loader(row: EvaluationRow) -> EvaluationRow:
2727
== "Factory function that generates evaluation rows dynamically."
2828
)
2929
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
30+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
3031
return row
3132

3233

@@ -45,6 +46,7 @@ def test_dynamic_data_loader_lambda(row: EvaluationRow) -> EvaluationRow:
4546
assert row.input_metadata.dataset_info.get("data_loader_num_rows_after_preprocessing") == 1
4647
assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
4748
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
49+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
4850
return row
4951

5052

@@ -72,4 +74,5 @@ def test_dynamic_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationR
7274
assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
7375
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
7476

77+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
7578
return row

tests/data_loader/test_inline_data_loader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from eval_protocol.data_loader.inline_data_loader import InlineDataLoader
2-
from eval_protocol.models import EvaluationRow, Message
2+
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
33
from eval_protocol.pytest import evaluation_test
44
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
55

@@ -20,6 +20,7 @@ def test_inline_data_loader(row: EvaluationRow) -> EvaluationRow:
2020
assert row.input_metadata.dataset_info.get("data_loader_type") == "InlineDataLoader"
2121
assert row.input_metadata.dataset_info.get("data_loader_variant_description") is None
2222
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
23+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
2324
return row
2425

2526

@@ -41,4 +42,5 @@ def test_inline_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationRo
4142
assert row.input_metadata.dataset_info.get("data_loader_type") == "InlineDataLoader"
4243
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
4344

45+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
4446
return row

tests/pytest/test_get_metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22

33
from eval_protocol.pytest import evaluation_test
4-
from eval_protocol.models import EvaluationRow, Message
4+
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
55

66

77
@evaluation_test(
@@ -22,6 +22,8 @@
2222
)
2323
def test_pytest_async(rows: list[EvaluationRow]) -> list[EvaluationRow]:
2424
"""Run math evaluation on sample dataset using pytest interface."""
25+
for row in rows:
26+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
2527
return rows
2628

2729

tests/pytest/test_pydantic_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pydantic_ai.models.openai import OpenAIChatModel
33
import pytest
44

5-
from eval_protocol.models import EvaluationRow, Message, Status
5+
from eval_protocol.models import EvaluationRow, Message, Status, EvaluateResult
66
from eval_protocol.pytest import evaluation_test
77

88
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
@@ -28,4 +28,5 @@ async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:
2828
Super simple hello world test for Pydantic AI.
2929
"""
3030
assert row.rollout_status.code == Status.Code.FINISHED
31+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
3132
return row

tests/pytest/test_pydantic_multi_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pydantic_ai.models.openai import OpenAIChatModel
1111
import pytest
1212

13-
from eval_protocol.models import EvaluationRow, Message
13+
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
1414
from eval_protocol.pytest import evaluation_test
1515
from pydantic_ai import Agent
1616

@@ -82,4 +82,5 @@ async def test_pydantic_multi_agent(row: EvaluationRow) -> EvaluationRow:
8282
"""
8383
Super simple hello world test for Pydantic AI.
8484
"""
85+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
8586
return row

tests/pytest/test_pytest_async.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from eval_protocol.models import EvaluationRow, Message
3+
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
44
from eval_protocol.pytest import evaluation_test
55

66

@@ -20,6 +20,8 @@
2020
)
2121
async def test_pytest_async(rows: list[EvaluationRow]) -> list[EvaluationRow]:
2222
"""Run math evaluation on sample dataset using pytest interface."""
23+
for row in rows:
24+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
2325
return rows
2426

2527

@@ -36,6 +38,7 @@ async def test_pytest_async(rows: list[EvaluationRow]) -> list[EvaluationRow]:
3638
)
3739
async def test_pytest_async_pointwise(row: EvaluationRow) -> EvaluationRow:
3840
"""Run pointwise evaluation on sample dataset using pytest interface."""
41+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
3942
return row
4043

4144

tests/pytest/test_pytest_default_agent_rollout_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import datetime
22
from typing import List
33

4-
from eval_protocol.models import EvaluationRow, Message
4+
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
55
from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test
66

77

@@ -24,4 +24,6 @@
2424
)
2525
def test_pytest_default_agent_rollout_processor(rows: List[EvaluationRow]) -> List[EvaluationRow]:
2626
"""Run math evaluation on sample dataset using pytest interface."""
27+
for row in rows:
28+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
2729
return rows

tests/pytest/test_pytest_ensure_logging.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def test_ensure_logging(monkeypatch):
2626
with patch(
2727
"eval_protocol.dataset_logger.sqlite_dataset_logger_adapter.SqliteEvaluationRowStore", return_value=mock_store
2828
):
29-
from eval_protocol.models import EvaluationRow
29+
from eval_protocol.models import EvaluationRow, EvaluateResult
3030
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
3131
from eval_protocol.pytest.evaluation_test import evaluation_test
3232
from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
@@ -44,6 +44,9 @@ async def test_ensure_logging(monkeypatch):
4444
# Don't pass logger parameter - let it use the default_logger (which we've replaced)
4545
)
4646
def eval_fn(row: EvaluationRow) -> EvaluationRow:
47+
# This test is only about logging behavior; attach a dummy evaluation_result
48+
# so that evaluation_test's invariant about evaluation_result is satisfied.
49+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
4750
return row
4851

4952
await eval_fn(

tests/pytest/test_pytest_env_overwrite.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import atexit
22
import shutil
33
import tempfile
4-
from eval_protocol.models import EvaluationRow, Message
4+
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
55
from eval_protocol.pytest import evaluation_test
66
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
77
from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
@@ -23,6 +23,7 @@ def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow:
2323
"""Run math evaluation on sample dataset using pytest interface."""
2424
assert row.messages[0].content == "What is the capital of France?"
2525
assert row.execution_metadata.invocation_id == "test-invocation-123"
26+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
2627
return row
2728

2829

@@ -38,6 +39,7 @@ def test_input_messages_in_env(row: EvaluationRow) -> EvaluationRow:
3839
"""Run math evaluation on sample dataset using pytest interface."""
3940
assert row.messages[0].content == "What is 5 * 6?"
4041
assert row.input_metadata.completion_params["model"] == "gpt-40"
42+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
4143
return row
4244

4345

@@ -60,6 +62,7 @@ def test_input_messages_in_env(row: EvaluationRow) -> EvaluationRow:
6062
)
6163
def test_input_override(row: EvaluationRow) -> EvaluationRow:
6264
assert row.messages[0].content == "What is 10 / 2?"
65+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
6366
return row
6467

6568

@@ -79,6 +82,7 @@ def test_no_op_rollout_processor_override_from_none(row: EvaluationRow) -> Evalu
7982
# Verify that no actual model call was made (NoOpRolloutProcessor doesn't modify messages)
8083
assert len(row.messages) == 1
8184
assert row.messages[0].role == "user"
85+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
8286
return row
8387

8488
@evaluation_test(
@@ -96,6 +100,7 @@ def test_no_op_rollout_processor_override_from_other(row: EvaluationRow) -> Eval
96100
assert row.messages[0].role == "user"
97101
# Verify the original message content is preserved (no assistant response added)
98102
assert row.messages[0].content == "Test override"
103+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
99104
return row
100105

101106
@evaluation_test(
@@ -115,6 +120,7 @@ def test_no_op_rollout_processor_override_multiple_rows(row: EvaluationRow) -> E
115120
# Verify rows pass through unchanged
116121
assert len(row.messages) == 1
117122
assert row.messages[0].role == "user"
123+
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
118124
return row
119125

120126

0 commit comments

Comments
 (0)