diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 561fb14d..a2b3882c 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -56,6 +56,7 @@ parse_ep_max_concurrent_rollouts, parse_ep_num_runs, parse_ep_completion_params, + parse_ep_passed_threshold, rollout_processor_with_retry, sanitize_filename, ) @@ -344,6 +345,7 @@ def evaluation_test( # noqa: C901 max_dataset_rows = parse_ep_max_rows(max_dataset_rows) completion_params = parse_ep_completion_params(completion_params) original_completion_params = completion_params + passed_threshold = parse_ep_passed_threshold(passed_threshold) def decorator( test_func: TestFunction, diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 3f1fb200..3c430363 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -15,6 +15,8 @@ import logging import os from typing import Optional +import json +import pathlib def pytest_addoption(parser) -> None: @@ -87,6 +89,21 @@ def pytest_addoption(parser) -> None: "Default: true (fail on permanent failures). Set to 'false' to continue with remaining rollouts." ), ) + group.addoption( + "--ep-success-threshold", + action="store", + default=None, + help=("Override the success threshold for evaluation_test. Pass a float between 0.0 and 1.0 (e.g., 0.8)."), + ) + group.addoption( + "--ep-se-threshold", + action="store", + default=None, + help=( + "Override the standard error threshold for evaluation_test. " + "Pass a float >= 0.0 (e.g., 0.05). If only this is set, success threshold defaults to 0.0." 
def _parse_bounded_float(val, lower: float, upper: float) -> Optional[float]:
    """Parse *val* as a finite float inside ``[lower, upper]``.

    Args:
        val: Raw option value (normally a string from the command line), or None.
        lower: Smallest accepted value (inclusive).
        upper: Largest accepted value (inclusive); pass ``float("inf")`` for
            "no upper bound" - infinity itself is still rejected.

    Returns:
        The parsed float, or None when *val* is None, is not numeric, is
        NaN/infinite, or falls outside the range.
    """
    if val is None:
        return None
    try:
        # float() already ignores surrounding whitespace, so no strip() is
        # needed; catching TypeError keeps non-string values from crashing.
        number = float(val)
    except (TypeError, ValueError):
        return None
    # NaN fails every comparison, so the chained test rejects it. Infinity is
    # excluded explicitly: json.dumps would emit the non-standard "Infinity".
    if lower <= number <= upper and number != float("inf"):
        return number
    return None


def _normalize_success_threshold(val: Optional[str]) -> Optional[float]:
    """Normalize success threshold value as float between 0.0 and 1.0.

    Returns None for a missing, non-numeric, or out-of-range value.
    """
    return _parse_bounded_float(val, 0.0, 1.0)


def _normalize_se_threshold(val: Optional[str]) -> Optional[float]:
    """Normalize standard error threshold value as a finite float >= 0.0.

    Returns None for a missing, non-numeric, negative, NaN, or infinite value.
    """
    return _parse_bounded_float(val, 0.0, float("inf"))


def _build_passed_threshold_env(success: Optional[float], se: Optional[float]) -> Optional[str]:
    """Build the EP_PASSED_THRESHOLD environment variable value from the two separate thresholds.

    Args:
        success: Normalized success threshold, or None when not supplied.
        se: Normalized standard error threshold, or None when not supplied.

    Returns:
        None when neither threshold is set; the plain string form of
        *success* when only it is set; otherwise a JSON object with the keys
        ``success`` and ``standard_error``.
    """
    if success is None and se is None:
        return None

    if se is None:
        return str(success)

    # A standard error threshold on its own implies a success threshold of 0.0.
    success_val = 0.0 if success is None else success
    return json.dumps({"success": success_val, "standard_error": se})
os.environ["EP_PASSED_THRESHOLD"] = threshold_env + # Allow ad-hoc overrides of input params via CLI flags try: - import json as _json - import pathlib as _pathlib - merged: dict = {} input_params_opts = config.getoption("--ep-input-param") if input_params_opts: @@ -174,17 +239,17 @@ def pytest_configure(config) -> None: continue opt = str(opt) if opt.startswith("@"): # load JSON file - p = _pathlib.Path(opt[1:]) + p = pathlib.Path(opt[1:]) if p.is_file(): with open(p, "r", encoding="utf-8") as f: - obj = _json.load(f) + obj = json.load(f) if isinstance(obj, dict): merged.update(obj) elif "=" in opt: k, v = opt.split("=", 1) # Try parse JSON values, fallback to string try: - merged[k] = _json.loads(v) + merged[k] = json.loads(v) except Exception: merged[k] = v reasoning_effort = config.getoption("--ep-reasoning-effort") @@ -194,7 +259,7 @@ def pytest_configure(config) -> None: # Convert "none" string to None value for API compatibility eb["reasoning_effort"] = None if reasoning_effort.lower() == "none" else str(reasoning_effort) if merged: - os.environ["EP_INPUT_PARAMS_JSON"] = _json.dumps(merged) + os.environ["EP_INPUT_PARAMS_JSON"] = json.dumps(merged) except Exception: # best effort, do not crash pytest session pass diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index bad097a7..f4bbaebe 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -18,6 +18,7 @@ from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config import logging +import json def execute_function(func: Callable, **kwargs) -> Any: @@ -176,11 +177,9 @@ def parse_ep_completion_params(completion_params: List[CompletionParams]) -> Lis Reads the environment variable set by plugin.py and applies deep merge to each completion param. 
def parse_ep_passed_threshold(default_value: Optional[Union[float, dict]]) -> Optional[Union[float, dict]]:
    """Read EP_PASSED_THRESHOLD env override as float or dict.

    Supports both float values (e.g., "0.8") and JSON dict format
    (e.g., '{"success":0.8}').

    Args:
        default_value: Threshold to use when the environment variable is unset.

    Returns:
        *default_value* when EP_PASSED_THRESHOLD is not set; otherwise the
        parsed float, or the parsed dict for a JSON object.

    Raises:
        ValueError: If the variable is set but is neither a float nor a JSON
            object (this includes JSON ``null``, booleans, strings and arrays).
    """
    raw = os.getenv("EP_PASSED_THRESHOLD")
    if raw is None:
        return default_value

    try:
        return float(raw)
    except ValueError:
        pass  # not a bare number; fall through to the JSON form

    try:
        parsed = json.loads(raw)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        raise ValueError(f"EP_PASSED_THRESHOLD env var exists but can't be parsed: {raw}") from e

    # Valid JSON that is not an object (null, true, "x", [..]) would otherwise
    # be returned as the threshold despite not being a float or dict.
    if not isinstance(parsed, dict):
        raise ValueError(f"EP_PASSED_THRESHOLD must be a float or a JSON object, got: {raw}")
    return parsed