Skip to content

Commit 4b31337

Browse files
committed
better abstraction
1 parent d73e558 commit 4b31337

File tree

2 files changed

+86
-74
lines changed

2 files changed

+86
-74
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 16 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
2626
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
27-
from eval_protocol.pytest.execution import execute_pytest
27+
from eval_protocol.pytest.execution import execute_pytest, execute_pytest_with_exception_handling
2828
from eval_protocol.pytest.generate_parameter_combinations import (
2929
ParameterizedTestKwargs,
3030
generate_parameter_combinations,
@@ -434,29 +434,11 @@ async def _execute_pointwise_eval_with_semaphore(
434434
experiment_id=experiment_id,
435435
run_id=run_id,
436436
):
437-
try:
438-
result = await execute_pytest(
439-
test_func,
440-
processed_row=row,
441-
evaluation_test_kwargs=evaluation_test_kwargs,
442-
)
443-
except AssertionError:
444-
raise
445-
except Exception as e:
446-
# Default: capture non-assert exceptions unless explicitly disabled
447-
if os.getenv("EP_RAISE_EVAL_EXCEPTIONS", "false").strip() == "false":
448-
result = row
449-
result.evaluation_result = EvaluateResult(
450-
score=0.0,
451-
is_score_valid=False,
452-
reason=f"Error during evaluation: {type(e).__name__}: {e}",
453-
)
454-
if result.eval_metadata is not None:
455-
result.eval_metadata.status = Status.error(
456-
f"Error during evaluation: {type(e).__name__}: {e}",
457-
)
458-
else:
459-
raise
437+
result = await execute_pytest_with_exception_handling(
438+
test_func=test_func,
439+
evaluation_test_kwargs=evaluation_test_kwargs,
440+
processed_row=row,
441+
)
460442
if not isinstance(result, EvaluationRow):
461443
raise ValueError(
462444
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
@@ -478,30 +460,11 @@ async def _execute_groupwise_eval_with_semaphore(
478460
run_id=run_id,
479461
rollout_ids=group_rollout_ids or None,
480462
):
481-
try:
482-
results = await execute_pytest(
483-
test_func,
484-
processed_dataset=rows,
485-
evaluation_test_kwargs=evaluation_test_kwargs,
486-
)
487-
except AssertionError:
488-
raise
489-
except Exception as e:
490-
# Default: capture non-assert exceptions unless explicitly disabled
491-
if os.getenv("EP_RAISE_EVAL_EXCEPTIONS", "false").strip() == "false":
492-
results = rows
493-
for row in results:
494-
row.evaluation_result = EvaluateResult(
495-
score=0.0,
496-
is_score_valid=False,
497-
reason=f"Error during evaluation: {type(e).__name__}: {e}",
498-
)
499-
if row.eval_metadata is not None:
500-
row.eval_metadata.status = Status.error(
501-
f"Error during evaluation: {type(e).__name__}: {e}",
502-
)
503-
else:
504-
raise
463+
results = await execute_pytest_with_exception_handling(
464+
test_func=test_func,
465+
evaluation_test_kwargs=evaluation_test_kwargs,
466+
processed_dataset=rows,
467+
)
505468
if not isinstance(results, list):
506469
raise ValueError(
507470
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
@@ -592,30 +555,11 @@ async def _collect_result(config, lst):
592555
run_id=run_id,
593556
rollout_ids=group_rollout_ids or None,
594557
):
595-
try:
596-
results = await execute_pytest(
597-
test_func,
598-
processed_dataset=input_dataset,
599-
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
600-
)
601-
except AssertionError:
602-
raise
603-
except Exception as e:
604-
# Default: capture non-assert exceptions unless explicitly disabled
605-
if os.getenv("EP_RAISE_EVAL_EXCEPTIONS", "false").strip() == "false":
606-
results = input_dataset
607-
for row in results:
608-
row.evaluation_result = EvaluateResult(
609-
score=0.0,
610-
is_score_valid=False,
611-
reason=f"Error during evaluation: {type(e).__name__}: {e}",
612-
)
613-
if row.eval_metadata is not None:
614-
row.eval_metadata.status = Status.error(
615-
f"Error during evaluation: {type(e).__name__}: {e}",
616-
)
617-
else:
618-
raise
558+
results = await execute_pytest_with_exception_handling(
559+
test_func=test_func,
560+
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
561+
processed_dataset=input_dataset,
562+
)
619563
if (
620564
results is None
621565
or not isinstance(results, list)

eval_protocol/pytest/execution.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
2+
import os
23
from collections.abc import Awaitable, Callable
3-
from typing import cast
4-
from eval_protocol.models import EvaluationRow
4+
from typing import Any, cast
5+
from eval_protocol.models import EvaluationRow, EvaluateResult, Status
56
from eval_protocol.pytest.types import Dataset, EvaluationInputParam, TestFunction
67

78

@@ -41,3 +42,70 @@ async def execute_pytest(
4142
return test_func(processed_dataset, **evaluation_test_kwargs)
4243
test_func = cast(Callable[[], EvaluationRow], test_func)
4344
return test_func(**evaluation_test_kwargs)
45+
46+
47+
async def execute_pytest_with_exception_handling(
    test_func: TestFunction,
    evaluation_test_kwargs: dict[str, Any],
    processed_row: EvaluationRow | None = None,
    processed_dataset: list[EvaluationRow] | None = None,
) -> EvaluationRow | list[EvaluationRow]:
    """Helper function to execute pytest with consistent exception handling.

    Args:
        test_func: The test function to execute
        evaluation_test_kwargs: Kwargs for the evaluation function
        processed_row: Single row for pointwise evaluation (mutually exclusive with processed_dataset)
        processed_dataset: Dataset for groupwise/all evaluation (mutually exclusive with processed_row)

    Returns:
        The result of execute_pytest, or the input data annotated with
        zero-score error results when a non-assert exception is captured.

    Raises:
        AssertionError: Always re-raised — assertion failures are genuine test
            failures and must surface in pytest.
        Exception: Re-raised only when EP_RAISE_EVAL_EXCEPTIONS is set to a
            value other than "false".
        ValueError: If neither processed_row nor processed_dataset was provided.
    """
    try:
        if processed_row is not None:
            return await execute_pytest(
                test_func,
                processed_row=processed_row,
                evaluation_test_kwargs=evaluation_test_kwargs,
            )
        else:
            return await execute_pytest(
                test_func,
                processed_dataset=processed_dataset,
                evaluation_test_kwargs=evaluation_test_kwargs,
            )
    except AssertionError:
        # The inlined handlers this helper replaced always re-raised
        # AssertionError before the broad Exception handler; keep that so
        # real test failures are never converted into error-scored rows.
        raise
    except Exception as e:
        # Default: capture non-assert exceptions unless explicitly disabled.
        # NOTE: the default must be "false" (capture) to match the behavior
        # of the pre-refactor inlined handlers.
        if os.getenv("EP_RAISE_EVAL_EXCEPTIONS", "false").strip() == "false":
            reason = f"Error during evaluation: {type(e).__name__}: {e}"
            # Handle single row case
            if processed_row is not None:
                result = processed_row
                result.evaluation_result = EvaluateResult(
                    score=0.0,
                    is_score_valid=False,
                    reason=reason,
                )
                if result.eval_metadata is not None:
                    result.eval_metadata.status = Status.error(reason)
                return result
            # Handle list of rows case
            elif processed_dataset is not None:
                results = processed_dataset
                for row in results:
                    row.evaluation_result = EvaluateResult(
                        score=0.0,
                        is_score_valid=False,
                        reason=reason,
                    )
                    if row.eval_metadata is not None:
                        row.eval_metadata.status = Status.error(reason)
                return results
            else:
                # This should never happen since one of processed_row/processed_dataset must be provided
                raise ValueError("Neither processed_row nor processed_dataset was provided")
        # Raising was explicitly requested via the env var
        else:
            raise

0 commit comments

Comments
 (0)