From 05a5f9f20c39fd3d231baf98840df39973f088ac Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Wed, 3 Sep 2025 22:39:16 -0700 Subject: [PATCH] Ground truth now a json serializable type --- eval_protocol/benchmarks/test_aime25.py | 2 +- eval_protocol/benchmarks/test_livebench_data_analysis.py | 2 +- eval_protocol/models.py | 6 ++++-- tests/pytest/test_apps_coding.py | 2 +- tests/pytest/test_markdown_highlighting.py | 3 ++- tests/pytest/test_pytest_function_calling.py | 2 +- 6 files changed, 10 insertions(+), 7 deletions(-) diff --git a/eval_protocol/benchmarks/test_aime25.py b/eval_protocol/benchmarks/test_aime25.py index 3df32cec..91a67f77 100644 --- a/eval_protocol/benchmarks/test_aime25.py +++ b/eval_protocol/benchmarks/test_aime25.py @@ -99,7 +99,7 @@ def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow: extracted_text = _extract_boxed_text(content_str) extracted_int = _normalize_to_int_or_none(extracted_text) - gt_int = _normalize_to_int_or_none(row.ground_truth or "") + gt_int = _normalize_to_int_or_none(str(row.ground_truth)) is_valid = extracted_int is not None and gt_int is not None score = 1.0 if (is_valid and extracted_int == gt_int) else 0.0 diff --git a/eval_protocol/benchmarks/test_livebench_data_analysis.py b/eval_protocol/benchmarks/test_livebench_data_analysis.py index 8c8c5e3c..6baa7b89 100644 --- a/eval_protocol/benchmarks/test_livebench_data_analysis.py +++ b/eval_protocol/benchmarks/test_livebench_data_analysis.py @@ -407,7 +407,7 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]: if row.ground_truth is None: return {"ground_truth": None, "release": None} try: - payload = json.loads(row.ground_truth) + payload = json.loads(str(row.ground_truth)) if isinstance(payload, dict): return payload except Exception: diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 4e526263..e9a6ca39 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -3,6 +3,8 @@ from enum import Enum from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union +JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None] + from openai.types import CompletionUsage from openai.types.chat.chat_completion_message import ( FunctionCall, @@ -598,8 +600,8 @@ class EvaluationRow(BaseModel): ) # Ground truth reference (moved from EvaluateResult to top level) - ground_truth: Optional[str] = Field( - default=None, description="Optional ground truth reference for this evaluation." + ground_truth: Optional[JSONType] = Field( + default=None, description="JSON-serializable ground truth reference for this evaluation." ) # Unified evaluation result diff --git a/tests/pytest/test_apps_coding.py b/tests/pytest/test_apps_coding.py index 9350a381..ef195791 100644 --- a/tests/pytest/test_apps_coding.py +++ b/tests/pytest/test_apps_coding.py @@ -47,7 +47,7 @@ def test_apps_code_evaluation(row: EvaluationRow) -> EvaluationRow: # Use evaluate_apps_solution directly result = evaluate_apps_solution( messages=row.messages, - ground_truth=row.ground_truth, + ground_truth=str(row.ground_truth), ) # Set the evaluation result on the row diff --git a/tests/pytest/test_markdown_highlighting.py b/tests/pytest/test_markdown_highlighting.py index c393ee60..cff328f0 100644 --- a/tests/pytest/test_markdown_highlighting.py +++ b/tests/pytest/test_markdown_highlighting.py @@ -42,12 +42,13 @@ def test_markdown_highlighting_evaluation(row: EvaluationRow) -> EvaluationRow: """ assistant_response = row.messages[-1].content + assistant_response = str(assistant_response or "") if not assistant_response: row.evaluation_result = EvaluateResult(score=0.0, reason="❌ No assistant response found") return row - required_highlights = int(row.ground_truth) + required_highlights = int(str(row.ground_truth)) # Check if the response contains the required number of formatted sections # e.g. **bold** or *italic* diff --git a/tests/pytest/test_pytest_function_calling.py b/tests/pytest/test_pytest_function_calling.py index 60f38b0d..0936a135 100644 --- a/tests/pytest/test_pytest_function_calling.py +++ b/tests/pytest/test_pytest_function_calling.py @@ -27,7 +27,7 @@ def function_calling_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[Evalu ) async def test_pytest_function_calling(row: EvaluationRow) -> EvaluationRow: """Run pointwise evaluation on sample dataset using pytest interface.""" - ground_truth = json.loads(row.ground_truth) + ground_truth = json.loads(str(row.ground_truth)) result = exact_tool_match_reward(row.messages, ground_truth) row.evaluation_result = result print(result)