diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index 672e8fe4..763e3081 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -24,7 +24,7 @@ )
 from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
 from eval_protocol.pytest.evaluation_test_postprocess import postprocess
-from eval_protocol.pytest.execution import execute_pytest
+from eval_protocol.pytest.execution import execute_pytest, execute_pytest_with_exception_handling
 from eval_protocol.pytest.generate_parameter_combinations import (
     ParameterizedTestKwargs,
     generate_parameter_combinations,
 )
@@ -434,23 +434,11 @@ async def _execute_pointwise_eval_with_semaphore(
         experiment_id=experiment_id,
         run_id=run_id,
     ):
-        try:
-            result = await execute_pytest(
-                test_func,
-                processed_row=row,
-                evaluation_test_kwargs=evaluation_test_kwargs,
-            )
-        except Exception as e:
-            result = row
-            result.evaluation_result = EvaluateResult(
-                score=0.0,
-                is_score_valid=False,
-                reason=f"Error during evaluation: {type(e).__name__}: {e}",
-            )
-            if result.eval_metadata is not None:
-                result.eval_metadata.status = Status.error(
-                    f"Error during evaluation: {type(e).__name__}: {e}",
-                )
+        result = await execute_pytest_with_exception_handling(
+            test_func=test_func,
+            evaluation_test_kwargs=evaluation_test_kwargs,
+            processed_row=row,
+        )
         if not isinstance(result, EvaluationRow):
             raise ValueError(
                 f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
@@ -472,24 +460,11 @@ async def _execute_groupwise_eval_with_semaphore(
         run_id=run_id,
         rollout_ids=group_rollout_ids or None,
     ):
-        try:
-            results = await execute_pytest(
-                test_func,
-                processed_dataset=rows,
-                evaluation_test_kwargs=evaluation_test_kwargs,
-            )
-        except Exception as e:
-            results = rows
-            for row in results:
-                row.evaluation_result = EvaluateResult(
-                    score=0.0,
-                    is_score_valid=False,
-                    reason=f"Error during evaluation: {type(e).__name__}: {e}",
-                )
-                if row.eval_metadata is not None:
-                    row.eval_metadata.status = Status.error(
-                        f"Error during evaluation: {type(e).__name__}: {e}",
-                    )
+        results = await execute_pytest_with_exception_handling(
+            test_func=test_func,
+            evaluation_test_kwargs=evaluation_test_kwargs,
+            processed_dataset=rows,
+        )
         if not isinstance(results, list):
             raise ValueError(
                 f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
@@ -580,10 +555,10 @@ async def _collect_result(config, lst):
         run_id=run_id,
         rollout_ids=group_rollout_ids or None,
     ):
-        results = await execute_pytest(
-            test_func,
-            processed_dataset=input_dataset,
+        results = await execute_pytest_with_exception_handling(
+            test_func=test_func,
             evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
+            processed_dataset=input_dataset,
         )
         if (
             results is None
diff --git a/eval_protocol/pytest/execution.py b/eval_protocol/pytest/execution.py
index fa572ee0..dabe08e4 100644
--- a/eval_protocol/pytest/execution.py
+++ b/eval_protocol/pytest/execution.py
@@ -1,7 +1,8 @@
 import asyncio
+import os
 from collections.abc import Awaitable, Callable
-from typing import cast
-from eval_protocol.models import EvaluationRow
+from typing import Any, cast
+from eval_protocol.models import EvaluationRow, EvaluateResult, Status
 from eval_protocol.pytest.types import Dataset, EvaluationInputParam, TestFunction
 
 
@@ -41,3 +42,70 @@ async def execute_pytest(
         return test_func(processed_dataset, **evaluation_test_kwargs)
     test_func = cast(Callable[[], EvaluationRow], test_func)
     return test_func(**evaluation_test_kwargs)
+
+
+async def execute_pytest_with_exception_handling(
+    test_func: TestFunction,
+    evaluation_test_kwargs: dict[str, Any],
+    processed_row: EvaluationRow | None = None,
+    processed_dataset: list[EvaluationRow] | None = None,
+) -> EvaluationRow | list[EvaluationRow]:
+    """Helper function to execute pytest with consistent exception handling.
+
+    Args:
+        test_func: The test function to execute
+        evaluation_test_kwargs: Kwargs for the evaluation function
+        processed_row: Single row for pointwise evaluation (mutually exclusive with processed_dataset)
+        processed_dataset: Dataset for groupwise/all evaluation (mutually exclusive with processed_row)
+
+    Returns:
+        The result of execute_pytest, or the input data with error results on exception
+    """
+    try:
+        if processed_row is not None:
+            return await execute_pytest(
+                test_func,
+                processed_row=processed_row,
+                evaluation_test_kwargs=evaluation_test_kwargs,
+            )
+        else:
+            return await execute_pytest(
+                test_func,
+                processed_dataset=processed_dataset,
+                evaluation_test_kwargs=evaluation_test_kwargs,
+            )
+    except Exception as e:
+        if os.getenv("EP_RAISE_EVAL_EXCEPTIONS", "true").strip().lower() == "false":
+            # Handle single row case
+            if processed_row is not None:
+                result = processed_row
+                result.evaluation_result = EvaluateResult(
+                    score=0.0,
+                    is_score_valid=False,
+                    reason=f"Error during evaluation: {type(e).__name__}: {e}",
+                )
+                if result.eval_metadata is not None:
+                    result.eval_metadata.status = Status.error(
+                        f"Error during evaluation: {type(e).__name__}: {e}",
+                    )
+                return result
+            # Handle list of rows case
+            elif processed_dataset is not None:
+                results = processed_dataset
+                for row in results:
+                    row.evaluation_result = EvaluateResult(
+                        score=0.0,
+                        is_score_valid=False,
+                        reason=f"Error during evaluation: {type(e).__name__}: {e}",
+                    )
+                    if row.eval_metadata is not None:
+                        row.eval_metadata.status = Status.error(
+                            f"Error during evaluation: {type(e).__name__}: {e}",
+                        )
+                return results
+            else:
+                # This should never happen since one of processed_row/processed_dataset must be provided
+                raise ValueError("Neither processed_row nor processed_dataset was provided")
+        # Default: raise exceptions unless explicitly disabled
+        else:
+            raise
diff --git a/tests/pytest/test_pytest_evaluator_error_handling.py b/tests/pytest/test_pytest_evaluator_error_handling.py
index 70861679..5a412c5d 100644
--- a/tests/pytest/test_pytest_evaluator_error_handling.py
+++ b/tests/pytest/test_pytest_evaluator_error_handling.py
@@ -25,6 +25,15 @@
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 
 
+@pytest.fixture(autouse=True)
+def _force_catch_eval_exceptions(monkeypatch: pytest.MonkeyPatch):
+    """
+    These tests validate the behavior when evaluation exceptions are caught and converted
+    into evaluation_result/status fields. Ensure the env var is set to disable raising.
+    """
+    monkeypatch.setenv("EP_RAISE_EVAL_EXCEPTIONS", "false")
+
+
 class TrackingLogger(DatasetLogger):
     """Custom logger that tracks all logged rows for testing."""
 