diff --git a/eval_protocol/pytest/evaluation_test_postprocess.py b/eval_protocol/pytest/evaluation_test_postprocess.py
index fe27ad22..838ae4cd 100644
--- a/eval_protocol/pytest/evaluation_test_postprocess.py
+++ b/eval_protocol/pytest/evaluation_test_postprocess.py
@@ -39,13 +39,17 @@ def postprocess(
     ]
     agg_score = aggregate(scores, aggregation_method)
 
+    # Calculate raw score (total score / total rows, including invalid scores)
+    all_scores = [r.evaluation_result.score for sublist in all_results for r in sublist if r.evaluation_result]
+    raw_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
+
     # Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
     ci_low: float | None = None
     ci_high: float | None = None
     standard_error: float | None = None
     if aggregation_method == "mean":
         try:
-            result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist])
+            result_ci = compute_fixed_set_mu_ci([item for sublist in valid_results for item in sublist])
             _, mu_ci_low, mu_ci_high, se = result_ci
             if mu_ci_low is not None and mu_ci_high is not None and se is not None:
                 ci_low = float(mu_ci_low)
@@ -140,12 +144,17 @@ def postprocess(
     if should_print:
         if ci_low is not None and ci_high is not None and standard_error is not None:
             print(
-                f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} se={summary_obj['standard_error']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}",
+                f"EP Summary | suite={suite_name} model={model_used} runs={num_runs} rows={total_rows}\n"
+                f"  agg_score={summary_obj['agg_score']:.3f} (valid scores only)\n"
+                f"  raw_score={raw_score:.3f} (invalid scores as 0)\n"
+                f"  se={summary_obj['standard_error']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}]",
                 file=sys.__stderr__,
             )
         else:
             print(
-                f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}",
+                f"EP Summary | suite={suite_name} model={model_used} runs={num_runs} rows={total_rows}\n"
+                f"  agg_score={summary_obj['agg_score']:.3f} (valid scores only)\n"
+                f"  raw_score={raw_score:.3f} (invalid scores as 0)",
                 file=sys.__stderr__,
            )
         # As per project convention, avoid printing per-metric CI lines to reduce noise
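For context on the first hunk, here is a minimal sketch of how the two summary numbers can diverge. The `Row` dataclass below is a hypothetical stand-in for `EvaluationRow` with simplified field names, not the real model:

```python
from dataclasses import dataclass

@dataclass
class Row:
    score: float
    is_score_valid: bool

rows = [Row(0.5, True), Row(0.9, True), Row(0.3, False)]

# agg_score: mean over valid scores only (labeled "valid scores only" in the summary)
valid = [r.score for r in rows if r.is_score_valid]
agg_score = sum(valid) / len(valid) if valid else 0.0  # (0.5 + 0.9) / 2 = 0.70

# raw_score: mean over every row that produced a result, valid or not
all_scores = [r.score for r in rows]
raw_score = sum(all_scores) / len(all_scores) if all_scores else 0.0  # ~0.567

print(f"agg_score={agg_score:.3f} raw_score={raw_score:.3f}")
```

The two values diverge exactly when some rows carry invalid scores, which is why the summary now prints both side by side.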
diff --git a/tests/test_evaluation_postprocess.py b/tests/test_evaluation_postprocess.py
index 1bbdb51a..7d18205c 100644
--- a/tests/test_evaluation_postprocess.py
+++ b/tests/test_evaluation_postprocess.py
@@ -3,8 +3,9 @@
 import pytest
 from unittest.mock import Mock, patch
 
-from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata
+from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata, Message
 from eval_protocol.pytest.evaluation_test_postprocess import postprocess
+from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
 
 
 class TestPostprocess:
@@ -205,3 +206,111 @@ def test_all_invalid_scores(self):
 
         # Should still call logger.log for all rows
         assert mock_logger.log.call_count == 2
+
+
+class TestComputeFixedSetMuCi:
+    """Tests for compute_fixed_set_mu_ci function."""
+
+    @staticmethod
+    def _make_row(row_id: str, question: str, score: float, reason: str, is_score_valid: bool = True) -> EvaluationRow:
+        """Build one evaluation row; every row shares the same eval metadata."""
+        return EvaluationRow(
+            messages=[Message(role="user", content=question)],
+            evaluation_result=EvaluateResult(score=score, is_score_valid=is_score_valid, reason=reason),
+            input_metadata=InputMetadata(row_id=row_id, completion_params={"model": "test"}),
+            execution_metadata=ExecutionMetadata(),
+            eval_metadata=EvalMetadata(
+                name="test",
+                description="test",
+                version="1.0",
+                status=None,
+                num_runs=3,
+                aggregation_method="mean",
+                passed_threshold=None,
+                passed=None,
+            ),
+        )
+
+    @patch.dict("os.environ", {"EP_NO_UPLOAD": "1"})  # Disable uploads
+    def test_compute_fixed_set_mu_ci_with_flattened_results(self):
+        """Test that postprocess calls compute_fixed_set_mu_ci with the flattened list of valid results."""
+        # Three questions, three runs each; the third q3 run exists in a valid
+        # and an invalid variant so the two scenarios share all other rows.
+        q1_run1 = self._make_row("q1", "What is 2+2?", 0.5, "correct")
+        q1_run2 = self._make_row("q1", "What is 2+2?", 0.4, "incorrect")
+        q1_run3 = self._make_row("q1", "What is 2+2?", 0.45, "incorrect")
+        q2_run1 = self._make_row("q2", "What is 3+3?", 0.8, "incorrect")
+        q2_run2 = self._make_row("q2", "What is 3+3?", 0.9, "correct")
+        q2_run3 = self._make_row("q2", "What is 3+3?", 0.95, "correct")
+        q3_run1 = self._make_row("q3", "What is 4+4?", 0.1, "incorrect")
+        q3_run2 = self._make_row("q3", "What is 4+4?", 0.2, "correct")
+        q3_run3_valid = self._make_row("q3", "What is 4+4?", 0.3, "correct")
+        q3_run3_invalid = self._make_row("q3", "What is 4+4?", 0.3, "correct", is_score_valid=False)
completion_params={"model": "test"}), + execution_metadata=ExecutionMetadata(), + eval_metadata=EvalMetadata( + name="test", + description="test", + version="1.0", + status=None, + num_runs=3, + aggregation_method="mean", + passed_threshold=None, + passed=None, + ), + ) + q3_run3_valid = EvaluationRow( + messages=[Message(role="user", content="What is 4+4?")], + evaluation_result=EvaluateResult(score=0.3, is_score_valid=True, reason="correct"), + input_metadata=InputMetadata(row_id="q3", completion_params={"model": "test"}), + execution_metadata=ExecutionMetadata(), + eval_metadata=EvalMetadata( + name="test", + description="test", + version="1.0", + status=None, + num_runs=3, + aggregation_method="mean", + passed_threshold=None, + passed=None, + ), + ) + q3_run3_invalid = EvaluationRow( + messages=[Message(role="user", content="What is 4+4?")], + evaluation_result=EvaluateResult(score=0.3, is_score_valid=False, reason="correct"), + input_metadata=InputMetadata(row_id="q3", completion_params={"model": "test"}), + execution_metadata=ExecutionMetadata(), + eval_metadata=EvalMetadata( + name="test", + description="test", + version="1.0", + status=None, + num_runs=3, + aggregation_method="mean", + passed_threshold=None, + passed=None, + ), + ) + + rows = [[q1_run1, q2_run1, q3_run1], [q1_run2, q2_run2, q1_run3], [q2_run3, q3_run2, q3_run3_valid]] + rows_with_invalid_score = [ + [q1_run1, q2_run1, q3_run1], + [q1_run2, q2_run2, q1_run3], + [q2_run3, q3_run2, q3_run3_invalid], + ] + + # Store results for assertions + first_result = None + second_result = None + + # Test first case (all valid scores) + with patch("eval_protocol.pytest.evaluation_test_postprocess.compute_fixed_set_mu_ci") as mock_ci: + mock_ci.side_effect = lambda input_rows, **kwargs: compute_fixed_set_mu_ci(input_rows, **kwargs) + + postprocess( + all_results=rows, + aggregation_method="mean", + threshold=None, + active_logger=Mock(), + mode="pointwise", + completion_params={"model": "test-model"}, + test_func_name="test_ci_flattened", + num_runs=3, + experiment_duration_seconds=10.0, + ) + + first_result = mock_ci.return_value + + # Test second case (with invalid score) + with patch("eval_protocol.pytest.evaluation_test_postprocess.compute_fixed_set_mu_ci") as mock_ci: + mock_ci.side_effect = lambda input_rows, **kwargs: compute_fixed_set_mu_ci(input_rows, **kwargs) + + postprocess( + all_results=rows_with_invalid_score, + aggregation_method="mean", + threshold=None, + active_logger=Mock(), + mode="pointwise", + completion_params={"model": "test-model"}, + test_func_name="test_ci_flattened_invalid", + num_runs=3, + experiment_duration_seconds=10.0, + ) + + second_result = mock_ci.return_value + + # Assert exact values + # First case: (0.5111111111111111, 0.18101430525778583, 0.8412079169644363, 0.168416737680268) + if first_result and len(first_result) == 4: + mu_hat1, ci_low1, ci_high1, se1 = first_result + assert abs(mu_hat1 - 0.5111111111111111) < 1e-10 + assert abs(ci_low1 - 0.18101430525778583) < 1e-10 + assert abs(ci_high1 - 0.8412079169644363) < 1e-10 + assert abs(se1 - 0.168416737680268) < 1e-10 + + # Second case: (0.49444444444444446, 0.13494616580367125, 0.8539427230852177, 0.18341748910243533) + if second_result and len(second_result) == 4: + mu_hat2, ci_low2, ci_high2, se2 = second_result + assert abs(mu_hat2 - 0.49444444444444446) < 1e-10 + assert abs(ci_low2 - 0.13494616580367125) < 1e-10 + assert abs(ci_high2 - 0.8539427230852177) < 1e-10 + assert abs(se2 - 0.18341748910243533) < 1e-10