Skip to content

Commit 6e881ff

Browse files
committed
update
1 parent 163e3a9 commit 6e881ff

File tree

4 files changed

+321
-23
lines changed

4 files changed

+321
-23
lines changed

eval_protocol/pytest/evaluation_test_postprocess.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,17 @@ def postprocess(
2525
num_runs: int,
2626
experiment_duration_seconds: float,
2727
):
28+
valid_results = [
29+
[r for r in result if r.evaluation_result and r.evaluation_result.is_score_valid] for result in all_results
30+
]
31+
2832
if aggregation_method == "bootstrap":
29-
scores = [
30-
r.evaluation_result.score
31-
for result in all_results
32-
for r in result
33-
if r.evaluation_result and r.evaluation_result.is_score_valid
34-
]
33+
scores = [r.evaluation_result.score for result in valid_results for r in result if r.evaluation_result]
3534
else:
3635
scores = [
37-
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result)
38-
for result in all_results
36+
sum(r.evaluation_result.score for r in result if r.evaluation_result) / len(result)
37+
for result in valid_results
38+
if result
3939
]
4040
agg_score = aggregate(scores, aggregation_method)
4141

eval_protocol/quickstart/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,9 @@ def assistant_to_ground_truth(data: list[EvaluationRow]) -> list[EvaluationRow]:
168168
messages = row.messages.copy() # Don't modify original
169169

170170
if messages[-1].role == "assistant":
171+
assistant_message = messages[-1]
171172
messages = messages[:-1]
172-
ground_truth_message = serialize_message(messages[-1])
173+
ground_truth_message = serialize_message(assistant_message)
173174
else:
174175
raise ValueError("Last message is not from assistant")
175176

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""Tests for evaluation postprocess functionality."""
2+
3+
import pytest
4+
from unittest.mock import Mock, patch
5+
6+
from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata
7+
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
8+
9+
10+
class TestPostprocess:
    """Tests for postprocess function."""

    def create_test_row(self, score: float, is_valid: bool = True) -> EvaluationRow:
        """Helper to create a test evaluation row."""
        return EvaluationRow(
            messages=[],
            evaluation_result=EvaluateResult(score=score, is_score_valid=is_valid, reason="test"),
            input_metadata=InputMetadata(completion_params={"model": "test-model"}),
            execution_metadata=ExecutionMetadata(),
            eval_metadata=EvalMetadata(
                name="test",
                description="test",
                version="1.0",
                status=None,
                num_runs=1,
                aggregation_method="mean",
                passed_threshold=None,
                passed=None,
            ),
        )

    @patch.dict("os.environ", {"EP_SUMMARY_JSON": ""})  # Disable uploads
    def test_bootstrap_aggregation_with_valid_scores(self):
        """Test bootstrap aggregation with all valid scores and verify exact scores list."""
        # Create test data: 2 runs with 2 rows each
        all_results = [
            [self.create_test_row(0.8), self.create_test_row(0.6)],  # Run 1
            [self.create_test_row(0.7), self.create_test_row(0.9)],  # Run 2
        ]

        mock_logger = Mock()

        # Mock the aggregate function to capture the exact scores passed to it
        with patch("eval_protocol.pytest.evaluation_test_postprocess.aggregate") as mock_aggregate:
            mock_aggregate.return_value = 0.75  # Mock return value

            postprocess(
                all_results=all_results,
                aggregation_method="bootstrap",
                threshold=None,
                active_logger=mock_logger,
                mode="pointwise",
                completion_params={"model": "test-model"},
                test_func_name="test_bootstrap",
                num_runs=2,
                experiment_duration_seconds=10.0,
            )

        # Check that aggregate was called with all individual scores in order
        mock_aggregate.assert_called_once_with([0.8, 0.6, 0.7, 0.9], "bootstrap")

        # Should call logger.log for each row
        assert mock_logger.log.call_count == 4

    @patch.dict("os.environ", {"EP_SUMMARY_JSON": ""})  # Disable uploads
    def test_bootstrap_aggregation_filters_invalid_scores(self):
        """Test that bootstrap aggregation excludes invalid scores and generates correct scores list."""
        # Create test data with some invalid scores
        all_results = [
            [
                self.create_test_row(0.8, is_valid=True),
                self.create_test_row(0.0, is_valid=False),  # Invalid - should be excluded
            ],
            [
                self.create_test_row(0.7, is_valid=True),
                self.create_test_row(0.0, is_valid=False),  # Invalid - should be excluded
            ],
        ]

        mock_logger = Mock()

        # Mock the aggregate function to capture the scores passed to it
        with patch("eval_protocol.pytest.evaluation_test_postprocess.aggregate") as mock_aggregate:
            mock_aggregate.return_value = 0.75  # Mock return value

            postprocess(
                all_results=all_results,
                aggregation_method="bootstrap",
                threshold=None,
                active_logger=mock_logger,
                mode="pointwise",
                completion_params={"model": "test-model"},
                test_func_name="test_bootstrap_invalid",
                num_runs=2,
                experiment_duration_seconds=10.0,
            )

        # Check that aggregate was called with only valid scores
        mock_aggregate.assert_called_once_with([0.8, 0.7], "bootstrap")

        # Should still call logger.log for all rows (including invalid ones)
        assert mock_logger.log.call_count == 4

    @patch.dict("os.environ", {"EP_SUMMARY_JSON": ""})  # Disable uploads
    def test_mean_aggregation_with_valid_scores(self):
        """Test mean aggregation with all valid scores."""
        all_results = [
            [self.create_test_row(0.8), self.create_test_row(0.6)],  # Run 1: mean = 0.7
            [self.create_test_row(0.4), self.create_test_row(0.8)],  # Run 2: mean = 0.6
        ]

        mock_logger = Mock()

        postprocess(
            all_results=all_results,
            aggregation_method="mean",
            threshold=None,
            active_logger=mock_logger,
            mode="pointwise",
            completion_params={"model": "test-model"},
            test_func_name="test_mean",
            num_runs=2,
            experiment_duration_seconds=10.0,
        )

        # Should call logger.log for each row
        assert mock_logger.log.call_count == 4

    @patch.dict("os.environ", {"EP_SUMMARY_JSON": ""})  # Disable uploads
    def test_mean_aggregation_filters_invalid_scores(self):
        """Test that mean aggregation excludes invalid scores from run averages."""
        all_results = [
            [
                self.create_test_row(0.8, is_valid=True),
                self.create_test_row(0.0, is_valid=False),  # Invalid - excluded from run average
            ],
            [
                self.create_test_row(0.6, is_valid=True),
                self.create_test_row(0.4, is_valid=True),
            ],
        ]

        mock_logger = Mock()

        postprocess(
            all_results=all_results,
            aggregation_method="mean",
            threshold=None,
            active_logger=mock_logger,
            mode="pointwise",
            completion_params={"model": "test-model"},
            test_func_name="test_mean_invalid",
            num_runs=2,
            experiment_duration_seconds=10.0,
        )

        # Should call logger.log for all rows
        assert mock_logger.log.call_count == 4

    @patch.dict("os.environ", {"EP_SUMMARY_JSON": ""})  # Disable uploads
    def test_empty_runs_are_skipped(self):
        """Test that runs with no valid scores are skipped."""
        all_results = [
            [self.create_test_row(0.8, is_valid=True)],  # Run 1: has valid score
            [self.create_test_row(0.0, is_valid=False)],  # Run 2: no valid scores - should be skipped
        ]

        mock_logger = Mock()

        postprocess(
            all_results=all_results,
            aggregation_method="mean",
            threshold=None,
            active_logger=mock_logger,
            mode="pointwise",
            completion_params={"model": "test-model"},
            test_func_name="test_empty_runs",
            num_runs=2,
            experiment_duration_seconds=10.0,
        )

        # Should still call logger.log for all rows
        assert mock_logger.log.call_count == 2

    @patch.dict("os.environ", {"EP_SUMMARY_JSON": ""})  # Disable uploads
    def test_all_invalid_scores(self):
        """Test behavior when all scores are invalid."""
        all_results = [
            [self.create_test_row(0.0, is_valid=False), self.create_test_row(0.0, is_valid=False)],
        ]

        mock_logger = Mock()

        postprocess(
            all_results=all_results,
            aggregation_method="bootstrap",
            threshold=None,
            active_logger=mock_logger,
            mode="pointwise",
            completion_params={"model": "test-model"},
            test_func_name="test_all_invalid",
            num_runs=1,
            experiment_duration_seconds=10.0,
        )

        # Should still call logger.log for all rows
        assert mock_logger.log.call_count == 2

0 commit comments

Comments (0)