diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py
index bf43b7da..f65f2b0e 100644
--- a/eval_protocol/pytest/default_single_turn_rollout_process.py
+++ b/eval_protocol/pytest/default_single_turn_rollout_process.py
@@ -20,19 +20,6 @@ class SingleTurnRolloutProcessor(RolloutProcessor):
     def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
         """Generate single turn rollout tasks and return them for external handling."""
-
-        # Quiet LiteLLM logs in test runs unless user overrode
-        try:
-            if os.environ.get("LITELLM_LOG") is None:
-                os.environ["LITELLM_LOG"] = "ERROR"
-            _llog = logging.getLogger("LiteLLM")
-            _llog.setLevel(logging.CRITICAL)
-            _llog.propagate = False
-            for _h in list(_llog.handlers):
-                _llog.removeHandler(_h)
-        except Exception:
-            pass
-
         # Do not modify global LiteLLM cache. Disable caching per-request instead.

         async def process_row(row: EvaluationRow) -> EvaluationRow:
@@ -48,11 +35,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             # Single-level reasoning effort: expect `reasoning_effort` only
             effort_val = None
-            if "reasoning_effort" in config.completion_params:
+            if (
+                "reasoning_effort" in config.completion_params
+                and config.completion_params["reasoning_effort"] is not None
+            ):
                 effort_val = str(config.completion_params["reasoning_effort"])  # flat shape
             elif (
                 isinstance(config.completion_params.get("extra_body"), dict)
                 and "reasoning_effort" in config.completion_params["extra_body"]
+                and config.completion_params["extra_body"]["reasoning_effort"] is not None
             ):
                 # Accept if user passed it directly inside extra_body
                 effort_val = str(config.completion_params["extra_body"]["reasoning_effort"])  # already in extra_body
diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index ce5f8817..95abbab9 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -47,11 +47,12 @@
     aggregate,
     create_dynamically_parameterized_wrapper,
     deep_update_dict,
-    execute_function,
     extract_effort_tag,
     generate_parameter_combinations,
     log_eval_status_and_rows,
     parse_ep_max_rows,
+    parse_ep_max_concurrent_rollouts,
+    parse_ep_num_runs,
     rollout_processor_with_retry,
     sanitize_filename,
 )
@@ -331,6 +332,11 @@ def evaluation_test(  # noqa: C901

     active_logger: DatasetLogger = logger if logger else default_logger

+    # Apply override from pytest flags if present
+    num_runs = parse_ep_num_runs(num_runs)
+    max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
+    max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
+
     def decorator(
         test_func: TestFunction,
     ):
@@ -478,6 +484,7 @@ def create_wrapper_with_signature() -> Callable:
             async def wrapper_body(**kwargs):
                 eval_metadata = None
+                all_results: List[List[EvaluationRow]] = [[] for _ in range(num_runs)]

                 experiment_id = generate_id()
@@ -502,10 +509,9 @@ def _log_eval_error(
                             data_jsonl.extend(load_jsonl(p))
                     else:
                         data_jsonl = load_jsonl(ds_arg)
-                    # Apply env override for max rows if present
-                    effective_max_rows = parse_ep_max_rows(max_dataset_rows)
-                    if effective_max_rows is not None:
-                        data_jsonl = data_jsonl[:effective_max_rows]
+                    # Apply override for max rows if present
+                    if max_dataset_rows is not None:
+                        data_jsonl = data_jsonl[:max_dataset_rows]
                     data = dataset_adapter(data_jsonl)
                 elif "input_messages" in kwargs and kwargs["input_messages"] is not None:
                     # Support either a single row (List[Message]) or many rows (List[List[Message]])
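For reference, the parse_ep_* calls above now run once at decoration time, so the effective precedence is: pytest flag (exported as an env var by plugin.py) over the decorator argument. A minimal, self-contained sketch of that precedence; EP_NUM_RUNS and --ep-num-runs are names from this patch, while resolve_num_runs and the values are hypothetical:

    import os

    def resolve_num_runs(decorator_default: int) -> int:
        # Mirrors parse_ep_num_runs: plugin.py has already validated EP_NUM_RUNS,
        # so any present value is a positive-integer string.
        raw = os.getenv("EP_NUM_RUNS")
        return int(raw) if raw is not None else decorator_default

    os.environ["EP_NUM_RUNS"] = "5"  # as if `pytest --ep-num-runs 5` were in effect
    assert resolve_num_runs(1) == 5  # flag/env override wins
    del os.environ["EP_NUM_RUNS"]
    assert resolve_num_runs(1) == 1  # falls back to the decorator argument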
diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py
index 5c195f4e..2584aa50 100644
--- a/eval_protocol/pytest/exception_config.py
+++ b/eval_protocol/pytest/exception_config.py
@@ -109,8 +109,7 @@ def __post_init__(self):
         # Override backoff settings from environment variables
         if "EP_MAX_RETRY" in os.environ:
             max_retry = int(os.environ["EP_MAX_RETRY"])
-            if max_retry > 0:
-                self.backoff_config.max_tries = max_retry
+            self.backoff_config.max_tries = max_retry

         if "EP_FAIL_ON_MAX_RETRY" in os.environ:
             fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower()
diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py
index 460eeb14..3f1fb200 100644
--- a/eval_protocol/pytest/plugin.py
+++ b/eval_protocol/pytest/plugin.py
@@ -28,6 +28,18 @@ def pytest_addoption(parser) -> None:
             "Pass an integer (e.g., 2, 50) or 'all' for no limit."
         ),
     )
+    group.addoption(
+        "--ep-num-runs",
+        action="store",
+        default=None,
+        help=("Override the number of runs for evaluation_test. Pass an integer (e.g., 1, 5, 10)."),
+    )
+    group.addoption(
+        "--ep-max-concurrent-rollouts",
+        action="store",
+        default=None,
+        help=("Override the maximum number of concurrent rollouts. Pass an integer (e.g., 8, 50, 100)."),
+    )
     group.addoption(
         "--ep-print-summary",
         action="store_true",
@@ -56,14 +68,13 @@ def pytest_addoption(parser) -> None:
         default=None,
         help=(
             "Set reasoning.effort for providers that support it (e.g., Fireworks) via LiteLLM extra_body. "
-            "Values: low|medium|high"
+            "Values: low|medium|high|none"
         ),
     )
     group.addoption(
         "--ep-max-retry",
         action="store",
-        type=int,
-        default=0,
+        default=None,
         help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."),
     )
     group.addoption(
@@ -92,6 +103,20 @@ def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
     return None


+def _normalize_number(val: Optional[str]) -> Optional[str]:
+    if val is None:
+        return None
+    s = val.strip()
+    # Validate int; if invalid, ignore and return None (no override)
+    try:
+        num = int(s)
+        if num <= 0:
+            return None  # overrides must be positive
+        return str(num)
+    except ValueError:
+        return None
+
+
 def pytest_configure(config) -> None:
     # Quiet LiteLLM INFO spam early in pytest session unless user set a level
     try:
@@ -110,6 +135,16 @@ def pytest_configure(config) -> None:
         if norm is not None:
             os.environ["EP_MAX_DATASET_ROWS"] = norm

+    num_runs_val = config.getoption("--ep-num-runs")
+    norm_runs = _normalize_number(num_runs_val)
+    if norm_runs is not None:
+        os.environ["EP_NUM_RUNS"] = norm_runs
+
+    max_concurrent_val = config.getoption("--ep-max-concurrent-rollouts")
+    norm_concurrent = _normalize_number(max_concurrent_val)
+    if norm_concurrent is not None:
+        os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent
+
     if config.getoption("--ep-print-summary"):
         os.environ["EP_PRINT_SUMMARY"] = "1"

@@ -118,10 +153,13 @@ def pytest_configure(config) -> None:
         os.environ["EP_SUMMARY_JSON"] = summary_json_path

     max_retry = config.getoption("--ep-max-retry")
-    os.environ["EP_MAX_RETRY"] = str(max_retry)
+    norm_max_retry = _normalize_number(max_retry)
+    if norm_max_retry is not None:
+        os.environ["EP_MAX_RETRY"] = norm_max_retry

     fail_on_max_retry = config.getoption("--ep-fail-on-max-retry")
-    os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
+    if fail_on_max_retry is not None:
+        os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry

     # Allow ad-hoc overrides of input params via CLI flags
     try:
@@ -153,7 +191,8 @@ def pytest_configure(config) -> None:
         if reasoning_effort:
             # Always place under extra_body to avoid LiteLLM rejecting top-level params
             eb = merged.setdefault("extra_body", {})
-            eb["reasoning_effort"] = str(reasoning_effort)
+            # Convert "none" string to None value for API compatibility
+            eb["reasoning_effort"] = None if reasoning_effort.lower() == "none" else str(reasoning_effort)
         if merged:
             os.environ["EP_INPUT_PARAMS_JSON"] = _json.dumps(merged)
     except Exception:
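The new flags flow end to end as: CLI flag -> _normalize_number -> env var -> parse_ep_* helper. A runnable sketch of the validation contract, with the function condensed from this patch (the invocation in the comment is hypothetical):

    # e.g. pytest --ep-num-runs 5 --ep-max-concurrent-rollouts 8
    def _normalize_number(val):
        if val is None:
            return None
        try:
            num = int(val.strip())
        except ValueError:
            return None  # non-numeric input: silently ignored
        return str(num) if num > 0 else None  # positive integers only

    assert _normalize_number("5") == "5"       # exported, e.g. EP_NUM_RUNS=5
    assert _normalize_number("0") is None      # non-positive: decorator default stays
    assert _normalize_number("eight") is None  # invalid: decorator default stays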
diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py
index 57af60b9..498f6c22 100644
--- a/eval_protocol/pytest/utils.py
+++ b/eval_protocol/pytest/utils.py
@@ -18,6 +18,8 @@
 )
 from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config

+import logging
+

 def execute_function(func: Callable, **kwargs) -> Any:
     """
@@ -140,17 +142,33 @@ def log_eval_status_and_rows(

 def parse_ep_max_rows(default_value: Optional[int]) -> Optional[int]:
-    """Read EP_MAX_DATASET_ROWS env override as int or None."""
+    """Read EP_MAX_DATASET_ROWS env override as int or None.
+
+    Assumes the environment variable was already validated by plugin.py.
+    """
     raw = os.getenv("EP_MAX_DATASET_ROWS")
     if raw is None:
         return default_value
-    s = raw.strip().lower()
-    if s == "none":
-        return None
-    try:
-        return int(s)
-    except ValueError:
-        return default_value
+    # plugin.py stores "None" as string for the "all" case
+    return None if raw.lower() == "none" else int(raw)
+
+
+def parse_ep_num_runs(default_value: int) -> int:
+    """Read EP_NUM_RUNS env override as int.
+
+    Assumes the environment variable was already validated by plugin.py.
+    """
+    raw = os.getenv("EP_NUM_RUNS")
+    return int(raw) if raw is not None else default_value
+
+
+def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
+    """Read EP_MAX_CONCURRENT_ROLLOUTS env override as int.
+
+    Assumes the environment variable was already validated by plugin.py.
+    """
+    raw = os.getenv("EP_MAX_CONCURRENT_ROLLOUTS")
+    return int(raw) if raw is not None else default_value


 def deep_update_dict(base: dict, override: dict) -> dict:
@@ -315,10 +333,14 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
                 return result
             except Exception as retry_error:
                 # Backoff gave up
+                logging.error(
+                    f"❌ Rollout failed (retried {exception_config.backoff_config.max_tries} times): {repr(retry_error)}"
+                )
                 row.rollout_status = Status.rollout_error(str(retry_error))
                 return row
         else:
             # Non-retryable exception - fail immediately
+            logging.error(f"❌ Rollout failed (non-retryable error encountered): {repr(e)}")
             row.rollout_status = Status.rollout_error(str(e))
             return row
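A quick sketch of the parse_ep_max_rows contract after this change (function body as in the patch; the env values below are made up). plugin.py stores the literal string "None" for the "all"/no-limit case, and any other stored value is assumed to be a pre-validated integer:

    import os

    def parse_ep_max_rows(default_value):
        raw = os.getenv("EP_MAX_DATASET_ROWS")
        if raw is None:
            return default_value
        return None if raw.lower() == "none" else int(raw)

    os.environ["EP_MAX_DATASET_ROWS"] = "None"
    assert parse_ep_max_rows(10) is None  # "all": no row limit
    os.environ["EP_MAX_DATASET_ROWS"] = "25"
    assert parse_ep_max_rows(10) == 25    # explicit cap wins over the default
    del os.environ["EP_MAX_DATASET_ROWS"]
    assert parse_ep_max_rows(10) == 10    # no override: decorator default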