Skip to content

Commit 7b62811

Browse files
author
Dylan Huang
committed
Merge branch 'main' into remove-catch-all-except-in-agent-rollout-processor
# Conflicts: # eval_protocol/pytest/utils.py
2 parents af3e2a0 + b813ce7 commit 7b62811

File tree

5 files changed

+92
-35
lines changed

5 files changed

+92
-35
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,6 @@ class SingleTurnRolloutProcessor(RolloutProcessor):
2020

2121
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
2222
"""Generate single turn rollout tasks and return them for external handling."""
23-
24-
# Quiet LiteLLM logs in test runs unless user overrode
25-
try:
26-
if os.environ.get("LITELLM_LOG") is None:
27-
os.environ["LITELLM_LOG"] = "ERROR"
28-
_llog = logging.getLogger("LiteLLM")
29-
_llog.setLevel(logging.CRITICAL)
30-
_llog.propagate = False
31-
for _h in list(_llog.handlers):
32-
_llog.removeHandler(_h)
33-
except Exception:
34-
pass
35-
3623
# Do not modify global LiteLLM cache. Disable caching per-request instead.
3724

3825
async def process_row(row: EvaluationRow) -> EvaluationRow:
@@ -48,11 +35,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
4835
# Single-level reasoning effort: expect `reasoning_effort` only
4936
effort_val = None
5037

51-
if "reasoning_effort" in config.completion_params:
38+
if (
39+
"reasoning_effort" in config.completion_params
40+
and config.completion_params["reasoning_effort"] is not None
41+
):
5242
effort_val = str(config.completion_params["reasoning_effort"]) # flat shape
5343
elif (
5444
isinstance(config.completion_params.get("extra_body"), dict)
5545
and "reasoning_effort" in config.completion_params["extra_body"]
46+
and config.completion_params["extra_body"]["reasoning_effort"] is not None
5647
):
5748
# Accept if user passed it directly inside extra_body
5849
effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,12 @@
4848
aggregate,
4949
create_dynamically_parameterized_wrapper,
5050
deep_update_dict,
51-
execute_function,
5251
extract_effort_tag,
5352
generate_parameter_combinations,
5453
log_eval_status_and_rows,
5554
parse_ep_max_rows,
55+
parse_ep_max_concurrent_rollouts,
56+
parse_ep_num_runs,
5657
rollout_processor_with_retry,
5758
sanitize_filename,
5859
)
@@ -333,6 +334,11 @@ def evaluation_test( # noqa: C901
333334

334335
active_logger: DatasetLogger = logger if logger else default_logger
335336

337+
# Apply override from pytest flags if present
338+
num_runs = parse_ep_num_runs(num_runs)
339+
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
340+
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
341+
336342
def decorator(
337343
test_func: TestFunction,
338344
):
@@ -480,6 +486,7 @@ def create_wrapper_with_signature() -> Callable:
480486

481487
async def wrapper_body(**kwargs):
482488
eval_metadata = None
489+
483490
all_results: List[List[EvaluationRow]] = [[] for _ in range(num_runs)]
484491

485492
experiment_id = generate_id()
@@ -502,10 +509,9 @@ def _log_eval_error(status: Status, rows: Optional[List[EvaluationRow]] | None,
502509
data_jsonl.extend(load_jsonl(p))
503510
else:
504511
data_jsonl = load_jsonl(ds_arg)
505-
# Apply env override for max rows if present
506-
effective_max_rows = parse_ep_max_rows(max_dataset_rows)
507-
if effective_max_rows is not None:
508-
data_jsonl = data_jsonl[:effective_max_rows]
512+
# Apply override for max rows if present
513+
if max_dataset_rows is not None:
514+
data_jsonl = data_jsonl[:max_dataset_rows]
509515
data = dataset_adapter(data_jsonl)
510516
elif "input_messages" in kwargs and kwargs["input_messages"] is not None:
511517
# Support either a single row (List[Message]) or many rows (List[List[Message]])

eval_protocol/pytest/exception_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ def __post_init__(self):
109109
# Override backoff settings from environment variables
110110
if "EP_MAX_RETRY" in os.environ:
111111
max_retry = int(os.environ["EP_MAX_RETRY"])
112-
if max_retry > 0:
113-
self.backoff_config.max_tries = max_retry
112+
self.backoff_config.max_tries = max_retry
114113

115114
if "EP_FAIL_ON_MAX_RETRY" in os.environ:
116115
fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower()

eval_protocol/pytest/plugin.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@ def pytest_addoption(parser) -> None:
2828
"Pass an integer (e.g., 2, 50) or 'all' for no limit."
2929
),
3030
)
31+
group.addoption(
32+
"--ep-num-runs",
33+
action="store",
34+
default=None,
35+
help=("Override the number of runs for evaluation_test. Pass an integer (e.g., 1, 5, 10)."),
36+
)
37+
group.addoption(
38+
"--ep-max-concurrent-rollouts",
39+
action="store",
40+
default=None,
41+
help=("Override the maximum number of concurrent rollouts. Pass an integer (e.g., 8, 50, 100)."),
42+
)
3143
group.addoption(
3244
"--ep-print-summary",
3345
action="store_true",
@@ -56,14 +68,13 @@ def pytest_addoption(parser) -> None:
5668
default=None,
5769
help=(
5870
"Set reasoning.effort for providers that support it (e.g., Fireworks) via LiteLLM extra_body. "
59-
"Values: low|medium|high"
71+
"Values: low|medium|high|none"
6072
),
6173
)
6274
group.addoption(
6375
"--ep-max-retry",
6476
action="store",
65-
type=int,
66-
default=0,
77+
default=None,
6778
help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."),
6879
)
6980
group.addoption(
@@ -92,6 +103,20 @@ def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
92103
return None
93104

94105

106+
def _normalize_number(val: Optional[str]) -> Optional[str]:
107+
if val is None:
108+
return None
109+
s = val.strip()
110+
# Validate int; if invalid, ignore and return None (no override)
111+
try:
112+
num = int(s)
113+
if num <= 0:
114+
return None # num_runs must be positive
115+
return str(num)
116+
except ValueError:
117+
return None
118+
119+
95120
def pytest_configure(config) -> None:
96121
# Quiet LiteLLM INFO spam early in pytest session unless user set a level
97122
try:
@@ -110,6 +135,16 @@ def pytest_configure(config) -> None:
110135
if norm is not None:
111136
os.environ["EP_MAX_DATASET_ROWS"] = norm
112137

138+
num_runs_val = config.getoption("--ep-num-runs")
139+
norm_runs = _normalize_number(num_runs_val)
140+
if norm_runs is not None:
141+
os.environ["EP_NUM_RUNS"] = norm_runs
142+
143+
max_concurrent_val = config.getoption("--ep-max-concurrent-rollouts")
144+
norm_concurrent = _normalize_number(max_concurrent_val)
145+
if norm_concurrent is not None:
146+
os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent
147+
113148
if config.getoption("--ep-print-summary"):
114149
os.environ["EP_PRINT_SUMMARY"] = "1"
115150

@@ -118,10 +153,13 @@ def pytest_configure(config) -> None:
118153
os.environ["EP_SUMMARY_JSON"] = summary_json_path
119154

120155
max_retry = config.getoption("--ep-max-retry")
121-
os.environ["EP_MAX_RETRY"] = str(max_retry)
156+
norm_max_retry = _normalize_number(max_retry)
157+
if norm_max_retry is not None:
158+
os.environ["EP_MAX_RETRY"] = norm_max_retry
122159

123160
fail_on_max_retry = config.getoption("--ep-fail-on-max-retry")
124-
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
161+
if fail_on_max_retry is not None:
162+
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
125163

126164
# Allow ad-hoc overrides of input params via CLI flags
127165
try:
@@ -153,7 +191,8 @@ def pytest_configure(config) -> None:
153191
if reasoning_effort:
154192
# Always place under extra_body to avoid LiteLLM rejecting top-level params
155193
eb = merged.setdefault("extra_body", {})
156-
eb["reasoning_effort"] = str(reasoning_effort)
194+
# Convert "none" string to None value for API compatibility
195+
eb["reasoning_effort"] = None if reasoning_effort.lower() == "none" else str(reasoning_effort)
157196
if merged:
158197
os.environ["EP_INPUT_PARAMS_JSON"] = _json.dumps(merged)
159198
except Exception:

eval_protocol/pytest/utils.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
)
1818
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
1919

20+
import logging
21+
2022

2123
def execute_function(func: Callable, **kwargs) -> Any:
2224
"""
@@ -139,17 +141,33 @@ def log_eval_status_and_rows(
139141

140142

141143
def parse_ep_max_rows(default_value: Optional[int]) -> Optional[int]:
142-
"""Read EP_MAX_DATASET_ROWS env override as int or None."""
144+
"""Read EP_MAX_DATASET_ROWS env override as int or None.
145+
146+
Assumes the environment variable was already validated by plugin.py.
147+
"""
143148
raw = os.getenv("EP_MAX_DATASET_ROWS")
144149
if raw is None:
145150
return default_value
146-
s = raw.strip().lower()
147-
if s == "none":
148-
return None
149-
try:
150-
return int(s)
151-
except ValueError:
152-
return default_value
151+
# plugin.py stores "None" as string for the "all" case
152+
return None if raw.lower() == "none" else int(raw)
153+
154+
155+
def parse_ep_num_runs(default_value: int) -> int:
156+
"""Read EP_NUM_RUNS env override as int.
157+
158+
Assumes the environment variable was already validated by plugin.py.
159+
"""
160+
raw = os.getenv("EP_NUM_RUNS")
161+
return int(raw) if raw is not None else default_value
162+
163+
164+
def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
165+
"""Read EP_MAX_CONCURRENT_ROLLOUTS env override as int.
166+
167+
Assumes the environment variable was already validated by plugin.py.
168+
"""
169+
raw = os.getenv("EP_MAX_CONCURRENT_ROLLOUTS")
170+
return int(raw) if raw is not None else default_value
153171

154172

155173
def deep_update_dict(base: dict, override: dict) -> dict:
@@ -314,10 +332,14 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
314332
return result
315333
except Exception as retry_error:
316334
# Backoff gave up
335+
logging.error(
336+
f"❌ Rollout failed, (retried {exception_config.backoff_config.max_tries} times): {repr(retry_error)}"
337+
)
317338
row.rollout_status = Status.rollout_error(str(retry_error))
318339
return row
319340
else:
320341
# Non-retryable exception - fail immediately
342+
logging.error(f"❌ Rollout failed (non-retryable error encountered): {repr(e)}")
321343
row.rollout_status = Status.rollout_error(repr(e))
322344
return row
323345

0 commit comments

Comments
 (0)