Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions eval_protocol/pytest/evaluation_test_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,17 @@ def postprocess(
]
agg_score = aggregate(scores, aggregation_method)

# Calculate raw score (total score / total rows, including invalid scores)
all_scores = [r.evaluation_result.score for sublist in all_results for r in sublist if r.evaluation_result]
raw_score = sum(all_scores) / len(all_scores) if all_scores else 0.0

# Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
ci_low: float | None = None
ci_high: float | None = None
standard_error: float | None = None
if aggregation_method == "mean":
try:
result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist])
result_ci = compute_fixed_set_mu_ci([item for sublist in valid_results for item in sublist])
_, mu_ci_low, mu_ci_high, se = result_ci
if mu_ci_low is not None and mu_ci_high is not None and se is not None:
ci_low = float(mu_ci_low)
Expand Down Expand Up @@ -140,12 +144,17 @@ def postprocess(
if should_print:
if ci_low is not None and ci_high is not None and standard_error is not None:
print(
f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} se={summary_obj['standard_error']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}",
f"EP Summary | suite={suite_name} model={model_used} runs={num_runs} rows={total_rows}\n"
f" agg_score={summary_obj['agg_score']:.3f} (valid scores only)\n"
f" raw_score={raw_score:.3f} (invalid scores as 0)\n"
f" se={summary_obj['standard_error']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}]",
file=sys.__stderr__,
)
else:
print(
f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}",
f"EP Summary | suite={suite_name} model={model_used} runs={num_runs} rows={total_rows}\n"
f" agg_score={summary_obj['agg_score']:.3f} (valid scores only)\n"
f" raw_score={raw_score:.3f} (invalid scores as 0)",
file=sys.__stderr__,
)
# As per project convention, avoid printing per-metric CI lines to reduce noise
Expand Down
236 changes: 235 additions & 1 deletion tests/test_evaluation_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import pytest
from unittest.mock import Mock, patch

from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata
from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata, Message
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci


class TestPostprocess:
Expand Down Expand Up @@ -205,3 +206,236 @@ def test_all_invalid_scores(self):

# Should still call logger.log for all rows
assert mock_logger.log.call_count == 2


class TestComputeFixedSetMuCi:
    """Tests for compute_fixed_set_mu_ci function."""

    @staticmethod
    def _make_row(row_id: str, question: str, score: float, valid: bool, reason: str) -> EvaluationRow:
        """Build one EvaluationRow fixture, factoring out the metadata boilerplate shared by every row."""
        return EvaluationRow(
            messages=[Message(role="user", content=question)],
            evaluation_result=EvaluateResult(score=score, is_score_valid=valid, reason=reason),
            input_metadata=InputMetadata(row_id=row_id, completion_params={"model": "test"}),
            execution_metadata=ExecutionMetadata(),
            eval_metadata=EvalMetadata(
                name="test",
                description="test",
                version="1.0",
                status=None,
                num_runs=3,
                aggregation_method="mean",
                passed_threshold=None,
                passed=None,
            ),
        )

    @patch.dict("os.environ", {"EP_NO_UPLOAD": "1"})  # Disable uploads
    def test_compute_fixed_set_mu_ci_with_flattened_results(self):
        """Test that postprocess correctly calls compute_fixed_set_mu_ci with flattened all_results structure.

        BUG FIX over the previous version: it set ``mock_ci.side_effect`` and then read
        ``mock_ci.return_value``. With a side_effect installed, ``return_value`` is just
        an unrelated MagicMock whose ``len()`` is 0, so the guarded assertions
        (``if result and len(result) == 4:``) never ran and the test passed without
        checking anything. We now capture the real tuples returned by
        ``compute_fixed_set_mu_ci`` and assert on them unconditionally.
        """
        make = self._make_row

        # Three questions (q1, q2, q3) x three repeats each; reasons mirror the originals.
        q1_run1 = make("q1", "What is 2+2?", 0.5, True, "correct")
        q1_run2 = make("q1", "What is 2+2?", 0.4, True, "incorrect")
        q1_run3 = make("q1", "What is 2+2?", 0.45, True, "incorrect")
        q2_run1 = make("q2", "What is 3+3?", 0.8, True, "incorrect")
        q2_run2 = make("q2", "What is 3+3?", 0.9, True, "correct")
        q2_run3 = make("q2", "What is 3+3?", 0.95, True, "correct")
        q3_run1 = make("q3", "What is 4+4?", 0.1, True, "incorrect")
        q3_run2 = make("q3", "What is 4+4?", 0.2, True, "correct")
        q3_run3_valid = make("q3", "What is 4+4?", 0.3, True, "correct")
        q3_run3_invalid = make("q3", "What is 4+4?", 0.3, False, "correct")

        # Sublists deliberately mix questions across runs to exercise the flattening.
        rows = [[q1_run1, q2_run1, q3_run1], [q1_run2, q2_run2, q1_run3], [q2_run3, q3_run2, q3_run3_valid]]
        rows_with_invalid_score = [
            [q1_run1, q2_run1, q3_run1],
            [q1_run2, q2_run2, q1_run3],
            [q2_run3, q3_run2, q3_run3_invalid],
        ]

        def _run_postprocess(all_results, test_name):
            """Run postprocess with a recording wrapper around compute_fixed_set_mu_ci.

            Returns the last real (mu_hat, ci_low, ci_high, se) tuple the wrapper produced.
            """
            captured = []

            def _recording_ci(input_rows, **kwargs):
                # Delegate to the real implementation but remember its result,
                # since mock.return_value does not reflect side_effect returns.
                result = compute_fixed_set_mu_ci(input_rows, **kwargs)
                captured.append(result)
                return result

            with patch(
                "eval_protocol.pytest.evaluation_test_postprocess.compute_fixed_set_mu_ci",
                side_effect=_recording_ci,
            ):
                postprocess(
                    all_results=all_results,
                    aggregation_method="mean",
                    threshold=None,
                    active_logger=Mock(),
                    mode="pointwise",
                    completion_params={"model": "test-model"},
                    test_func_name=test_name,
                    num_runs=3,
                    experiment_duration_seconds=10.0,
                )

            # If this fires, postprocess never reached the CI computation at all.
            assert captured, "compute_fixed_set_mu_ci was never called by postprocess"
            return captured[-1]

        first_result = _run_postprocess(rows, "test_ci_flattened")
        second_result = _run_postprocess(rows_with_invalid_score, "test_ci_flattened_invalid")

        # First case (all valid scores):
        # expected (0.5111111111111111, 0.18101430525778583, 0.8412079169644363, 0.168416737680268)
        mu_hat1, ci_low1, ci_high1, se1 = first_result
        assert abs(mu_hat1 - 0.5111111111111111) < 1e-10
        assert abs(ci_low1 - 0.18101430525778583) < 1e-10
        assert abs(ci_high1 - 0.8412079169644363) < 1e-10
        assert abs(se1 - 0.168416737680268) < 1e-10

        # Second case (one invalid score filtered out before the CI computation):
        # expected (0.49444444444444446, 0.13494616580367125, 0.8539427230852177, 0.18341748910243533)
        mu_hat2, ci_low2, ci_high2, se2 = second_result
        assert abs(mu_hat2 - 0.49444444444444446) < 1e-10
        assert abs(ci_low2 - 0.13494616580367125) < 1e-10
        assert abs(ci_high2 - 0.8539427230852177) < 1e-10
        assert abs(se2 - 0.18341748910243533) < 1e-10
Loading