Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
parse_ep_max_concurrent_rollouts,
parse_ep_num_runs,
parse_ep_completion_params,
parse_ep_passed_threshold,
rollout_processor_with_retry,
sanitize_filename,
)
Expand Down Expand Up @@ -344,6 +345,7 @@ def evaluation_test( # noqa: C901
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
completion_params = parse_ep_completion_params(completion_params)
original_completion_params = completion_params
passed_threshold = parse_ep_passed_threshold(passed_threshold)

def decorator(
test_func: TestFunction,
Expand Down
79 changes: 72 additions & 7 deletions eval_protocol/pytest/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import logging
import os
from typing import Optional
import json
import pathlib


def pytest_addoption(parser) -> None:
Expand Down Expand Up @@ -87,6 +89,21 @@ def pytest_addoption(parser) -> None:
"Default: true (fail on permanent failures). Set to 'false' to continue with remaining rollouts."
),
)
group.addoption(
"--ep-success-threshold",
action="store",
default=None,
help=("Override the success threshold for evaluation_test. Pass a float between 0.0 and 1.0 (e.g., 0.8)."),
)
group.addoption(
"--ep-se-threshold",
action="store",
default=None,
help=(
"Override the standard error threshold for evaluation_test. "
"Pass a float >= 0.0 (e.g., 0.05). If only this is set, success threshold defaults to 0.0."
),
)


def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -117,6 +134,49 @@ def _normalize_number(val: Optional[str]) -> Optional[str]:
return None


def _normalize_success_threshold(val: Optional[str]) -> Optional[float]:
"""Normalize success threshold value as float between 0.0 and 1.0."""
if val is None:
return None

try:
threshold_float = float(val.strip())
if 0.0 <= threshold_float <= 1.0:
return threshold_float
else:
return None # threshold must be between 0 and 1
except ValueError:
return None


def _normalize_se_threshold(val: Optional[str]) -> Optional[float]:
"""Normalize standard error threshold value as float >= 0.0."""
if val is None:
return None

try:
threshold_float = float(val.strip())
if threshold_float >= 0.0:
return threshold_float
else:
return None # standard error must be >= 0
except ValueError:
return None


def _build_passed_threshold_env(success: Optional[float], se: Optional[float]) -> Optional[str]:
"""Build the EP_PASSED_THRESHOLD environment variable value from the two separate thresholds."""
if success is None and se is None:
return None

if se is None:
return str(success)
else:
success_val = success if success is not None else 0.0
threshold_dict = {"success": success_val, "standard_error": se}
return json.dumps(threshold_dict)


def pytest_configure(config) -> None:
# Quiet LiteLLM INFO spam early in pytest session unless user set a level
try:
Expand Down Expand Up @@ -161,11 +221,16 @@ def pytest_configure(config) -> None:
if fail_on_max_retry is not None:
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry

success_threshold_val = config.getoption("--ep-success-threshold")
se_threshold_val = config.getoption("--ep-se-threshold")
norm_success = _normalize_success_threshold(success_threshold_val)
norm_se = _normalize_se_threshold(se_threshold_val)
threshold_env = _build_passed_threshold_env(norm_success, norm_se)
if threshold_env is not None:
os.environ["EP_PASSED_THRESHOLD"] = threshold_env

# Allow ad-hoc overrides of input params via CLI flags
try:
import json as _json
import pathlib as _pathlib

merged: dict = {}
input_params_opts = config.getoption("--ep-input-param")
if input_params_opts:
Expand All @@ -174,17 +239,17 @@ def pytest_configure(config) -> None:
continue
opt = str(opt)
if opt.startswith("@"): # load JSON file
p = _pathlib.Path(opt[1:])
p = pathlib.Path(opt[1:])
if p.is_file():
with open(p, "r", encoding="utf-8") as f:
obj = _json.load(f)
obj = json.load(f)
if isinstance(obj, dict):
merged.update(obj)
elif "=" in opt:
k, v = opt.split("=", 1)
# Try parse JSON values, fallback to string
try:
merged[k] = _json.loads(v)
merged[k] = json.loads(v)
except Exception:
merged[k] = v
reasoning_effort = config.getoption("--ep-reasoning-effort")
Expand All @@ -194,7 +259,7 @@ def pytest_configure(config) -> None:
# Convert "none" string to None value for API compatibility
eb["reasoning_effort"] = None if reasoning_effort.lower() == "none" else str(reasoning_effort)
if merged:
os.environ["EP_INPUT_PARAMS_JSON"] = _json.dumps(merged)
os.environ["EP_INPUT_PARAMS_JSON"] = json.dumps(merged)
except Exception:
# best effort, do not crash pytest session
pass
26 changes: 23 additions & 3 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config

import logging
import json


def execute_function(func: Callable, **kwargs) -> Any:
Expand Down Expand Up @@ -176,11 +177,9 @@ def parse_ep_completion_params(completion_params: List[CompletionParams]) -> Lis
Reads the environment variable set by plugin.py and applies deep merge to each completion param.
"""
try:
import json as _json

_env_override = os.getenv("EP_INPUT_PARAMS_JSON")
if _env_override:
override_obj = _json.loads(_env_override)
override_obj = json.loads(_env_override)
if isinstance(override_obj, dict):
# Apply override to each completion_params item
return [deep_update_dict(dict(cp), override_obj) for cp in completion_params]
Expand All @@ -189,6 +188,27 @@ def parse_ep_completion_params(completion_params: List[CompletionParams]) -> Lis
return completion_params


def parse_ep_passed_threshold(default_value: Optional[Union[float, dict]]) -> Optional[Union[float, dict]]:
"""Read EP_PASSED_THRESHOLD env override as float or dict.

Assumes the environment variable was already validated by plugin.py.
Supports both float values (e.g., "0.8") and JSON dict format (e.g., '{"success":0.8}').
"""
raw = os.getenv("EP_PASSED_THRESHOLD")
if raw is None:
return default_value

try:
return float(raw)
except ValueError:
pass

try:
return json.loads(raw)
except (json.JSONDecodeError, TypeError, ValueError) as e:
raise ValueError(f"EP_PASSED_THRESHOLD env var exists but can't be parsed: {raw}") from e


def deep_update_dict(base: dict, override: dict) -> dict:
"""Recursively update nested dictionaries in-place and return base."""
for key, value in override.items():
Expand Down
Loading