Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,6 @@ class SingleTurnRolloutProcessor(RolloutProcessor):

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
"""Generate single turn rollout tasks and return them for external handling."""

# Quiet LiteLLM logs in test runs unless user overrode
try:
if os.environ.get("LITELLM_LOG") is None:
os.environ["LITELLM_LOG"] = "ERROR"
_llog = logging.getLogger("LiteLLM")
_llog.setLevel(logging.CRITICAL)
_llog.propagate = False
for _h in list(_llog.handlers):
_llog.removeHandler(_h)
except Exception:
pass

# Do not modify global LiteLLM cache. Disable caching per-request instead.

async def process_row(row: EvaluationRow) -> EvaluationRow:
Expand All @@ -48,11 +35,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
# Single-level reasoning effort: expect `reasoning_effort` only
effort_val = None

if "reasoning_effort" in config.completion_params:
if (
"reasoning_effort" in config.completion_params
and config.completion_params["reasoning_effort"] is not None
):
effort_val = str(config.completion_params["reasoning_effort"]) # flat shape
elif (
isinstance(config.completion_params.get("extra_body"), dict)
and "reasoning_effort" in config.completion_params["extra_body"]
and config.completion_params["extra_body"]["reasoning_effort"] is not None
):
# Accept if user passed it directly inside extra_body
effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body
Expand Down
16 changes: 11 additions & 5 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@
aggregate,
create_dynamically_parameterized_wrapper,
deep_update_dict,
execute_function,
extract_effort_tag,
generate_parameter_combinations,
log_eval_status_and_rows,
parse_ep_max_rows,
parse_ep_max_concurrent_rollouts,
parse_ep_num_runs,
rollout_processor_with_retry,
sanitize_filename,
)
Expand Down Expand Up @@ -331,6 +332,11 @@ def evaluation_test( # noqa: C901

active_logger: DatasetLogger = logger if logger else default_logger

# Apply override from pytest flags if present
num_runs = parse_ep_num_runs(num_runs)
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)

def decorator(
test_func: TestFunction,
):
Expand Down Expand Up @@ -478,6 +484,7 @@ def create_wrapper_with_signature() -> Callable:

async def wrapper_body(**kwargs):
eval_metadata = None

all_results: List[List[EvaluationRow]] = [[] for _ in range(num_runs)]

experiment_id = generate_id()
Expand All @@ -502,10 +509,9 @@ def _log_eval_error(
data_jsonl.extend(load_jsonl(p))
else:
data_jsonl = load_jsonl(ds_arg)
# Apply env override for max rows if present
effective_max_rows = parse_ep_max_rows(max_dataset_rows)
if effective_max_rows is not None:
data_jsonl = data_jsonl[:effective_max_rows]
# Apply override for max rows if present
if max_dataset_rows is not None:
data_jsonl = data_jsonl[:max_dataset_rows]
data = dataset_adapter(data_jsonl)
elif "input_messages" in kwargs and kwargs["input_messages"] is not None:
# Support either a single row (List[Message]) or many rows (List[List[Message]])
Expand Down
3 changes: 1 addition & 2 deletions eval_protocol/pytest/exception_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ def __post_init__(self):
# Override backoff settings from environment variables
if "EP_MAX_RETRY" in os.environ:
max_retry = int(os.environ["EP_MAX_RETRY"])
if max_retry > 0:
self.backoff_config.max_tries = max_retry
self.backoff_config.max_tries = max_retry

if "EP_FAIL_ON_MAX_RETRY" in os.environ:
fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower()
Expand Down
51 changes: 45 additions & 6 deletions eval_protocol/pytest/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ def pytest_addoption(parser) -> None:
"Pass an integer (e.g., 2, 50) or 'all' for no limit."
),
)
group.addoption(
"--ep-num-runs",
action="store",
default=None,
help=("Override the number of runs for evaluation_test. Pass an integer (e.g., 1, 5, 10)."),
)
group.addoption(
"--ep-max-concurrent-rollouts",
action="store",
default=None,
help=("Override the maximum number of concurrent rollouts. Pass an integer (e.g., 8, 50, 100)."),
)
group.addoption(
"--ep-print-summary",
action="store_true",
Expand Down Expand Up @@ -56,14 +68,13 @@ def pytest_addoption(parser) -> None:
default=None,
help=(
"Set reasoning.effort for providers that support it (e.g., Fireworks) via LiteLLM extra_body. "
"Values: low|medium|high"
"Values: low|medium|high|none"
),
)
group.addoption(
"--ep-max-retry",
action="store",
type=int,
default=0,
default=None,
help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."),
)
group.addoption(
Expand Down Expand Up @@ -92,6 +103,20 @@ def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
return None


def _normalize_number(val: Optional[str]) -> Optional[str]:
if val is None:
return None
s = val.strip()
# Validate int; if invalid, ignore and return None (no override)
try:
num = int(s)
if num <= 0:
return None # num_runs must be positive
return str(num)
except ValueError:
return None


def pytest_configure(config) -> None:
# Quiet LiteLLM INFO spam early in pytest session unless user set a level
try:
Expand All @@ -110,6 +135,16 @@ def pytest_configure(config) -> None:
if norm is not None:
os.environ["EP_MAX_DATASET_ROWS"] = norm

num_runs_val = config.getoption("--ep-num-runs")
norm_runs = _normalize_number(num_runs_val)
if norm_runs is not None:
os.environ["EP_NUM_RUNS"] = norm_runs

max_concurrent_val = config.getoption("--ep-max-concurrent-rollouts")
norm_concurrent = _normalize_number(max_concurrent_val)
if norm_concurrent is not None:
os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent

if config.getoption("--ep-print-summary"):
os.environ["EP_PRINT_SUMMARY"] = "1"

Expand All @@ -118,10 +153,13 @@ def pytest_configure(config) -> None:
os.environ["EP_SUMMARY_JSON"] = summary_json_path

max_retry = config.getoption("--ep-max-retry")
os.environ["EP_MAX_RETRY"] = str(max_retry)
norm_max_retry = _normalize_number(max_retry)
if norm_max_retry is not None:
os.environ["EP_MAX_RETRY"] = norm_max_retry

fail_on_max_retry = config.getoption("--ep-fail-on-max-retry")
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
if fail_on_max_retry is not None:
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry

# Allow ad-hoc overrides of input params via CLI flags
try:
Expand Down Expand Up @@ -153,7 +191,8 @@ def pytest_configure(config) -> None:
if reasoning_effort:
# Always place under extra_body to avoid LiteLLM rejecting top-level params
eb = merged.setdefault("extra_body", {})
eb["reasoning_effort"] = str(reasoning_effort)
# Convert "none" string to None value for API compatibility
eb["reasoning_effort"] = None if reasoning_effort.lower() == "none" else str(reasoning_effort)
if merged:
os.environ["EP_INPUT_PARAMS_JSON"] = _json.dumps(merged)
except Exception:
Expand Down
38 changes: 30 additions & 8 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
)
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config

import logging


def execute_function(func: Callable, **kwargs) -> Any:
"""
Expand Down Expand Up @@ -140,17 +142,33 @@ def log_eval_status_and_rows(


def parse_ep_max_rows(default_value: Optional[int]) -> Optional[int]:
"""Read EP_MAX_DATASET_ROWS env override as int or None."""
"""Read EP_MAX_DATASET_ROWS env override as int or None.

Assumes the environment variable was already validated by plugin.py.
"""
raw = os.getenv("EP_MAX_DATASET_ROWS")
if raw is None:
return default_value
s = raw.strip().lower()
if s == "none":
return None
try:
return int(s)
except ValueError:
return default_value
# plugin.py stores "None" as string for the "all" case
return None if raw.lower() == "none" else int(raw)


def parse_ep_num_runs(default_value: int) -> int:
    """Return the EP_NUM_RUNS environment override, or *default_value*.

    Assumes the environment variable was already validated by plugin.py.
    """
    override = os.getenv("EP_NUM_RUNS")
    if override is None:
        return default_value
    return int(override)


def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
    """Return the EP_MAX_CONCURRENT_ROLLOUTS environment override, or *default_value*.

    Assumes the environment variable was already validated by plugin.py.
    """
    override = os.getenv("EP_MAX_CONCURRENT_ROLLOUTS")
    if override is None:
        return default_value
    return int(override)


def deep_update_dict(base: dict, override: dict) -> dict:
Expand Down Expand Up @@ -315,10 +333,14 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
return result
except Exception as retry_error:
# Backoff gave up
logging.error(
f"❌ Rollout failed, (retried {exception_config.backoff_config.max_tries} times): {repr(retry_error)}"
)
row.rollout_status = Status.rollout_error(str(retry_error))
return row
else:
# Non-retryable exception - fail immediately
logging.error(f"❌ Rollout failed (non-retryable error encountered): {repr(e)}")
row.rollout_status = Status.rollout_error(str(e))
return row

Expand Down
Loading