Skip to content

Commit b813ce7

Browse files
authored
Fix Plugin (#113)
* Add all options to plugin * first pass * almost finished * remove comment * fix
1 parent 3d07d2e commit b813ce7

File tree

5 files changed

+92
-35
lines changed

5 files changed

+92
-35
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,6 @@ class SingleTurnRolloutProcessor(RolloutProcessor):
2020

2121
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
2222
"""Generate single turn rollout tasks and return them for external handling."""
23-
24-
# Quiet LiteLLM logs in test runs unless user overrode
25-
try:
26-
if os.environ.get("LITELLM_LOG") is None:
27-
os.environ["LITELLM_LOG"] = "ERROR"
28-
_llog = logging.getLogger("LiteLLM")
29-
_llog.setLevel(logging.CRITICAL)
30-
_llog.propagate = False
31-
for _h in list(_llog.handlers):
32-
_llog.removeHandler(_h)
33-
except Exception:
34-
pass
35-
3623
# Do not modify global LiteLLM cache. Disable caching per-request instead.
3724

3825
async def process_row(row: EvaluationRow) -> EvaluationRow:
@@ -48,11 +35,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
4835
# Single-level reasoning effort: expect `reasoning_effort` only
4936
effort_val = None
5037

51-
if "reasoning_effort" in config.completion_params:
38+
if (
39+
"reasoning_effort" in config.completion_params
40+
and config.completion_params["reasoning_effort"] is not None
41+
):
5242
effort_val = str(config.completion_params["reasoning_effort"]) # flat shape
5343
elif (
5444
isinstance(config.completion_params.get("extra_body"), dict)
5545
and "reasoning_effort" in config.completion_params["extra_body"]
46+
and config.completion_params["extra_body"]["reasoning_effort"] is not None
5647
):
5748
# Accept if user passed it directly inside extra_body
5849
effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@
4747
aggregate,
4848
create_dynamically_parameterized_wrapper,
4949
deep_update_dict,
50-
execute_function,
5150
extract_effort_tag,
5251
generate_parameter_combinations,
5352
log_eval_status_and_rows,
5453
parse_ep_max_rows,
54+
parse_ep_max_concurrent_rollouts,
55+
parse_ep_num_runs,
5556
rollout_processor_with_retry,
5657
sanitize_filename,
5758
)
@@ -331,6 +332,11 @@ def evaluation_test( # noqa: C901
331332

332333
active_logger: DatasetLogger = logger if logger else default_logger
333334

335+
# Apply override from pytest flags if present
336+
num_runs = parse_ep_num_runs(num_runs)
337+
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
338+
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
339+
334340
def decorator(
335341
test_func: TestFunction,
336342
):
@@ -478,6 +484,7 @@ def create_wrapper_with_signature() -> Callable:
478484

479485
async def wrapper_body(**kwargs):
480486
eval_metadata = None
487+
481488
all_results: List[List[EvaluationRow]] = [[] for _ in range(num_runs)]
482489

483490
experiment_id = generate_id()
@@ -502,10 +509,9 @@ def _log_eval_error(
502509
data_jsonl.extend(load_jsonl(p))
503510
else:
504511
data_jsonl = load_jsonl(ds_arg)
505-
# Apply env override for max rows if present
506-
effective_max_rows = parse_ep_max_rows(max_dataset_rows)
507-
if effective_max_rows is not None:
508-
data_jsonl = data_jsonl[:effective_max_rows]
512+
# Apply override for max rows if present
513+
if max_dataset_rows is not None:
514+
data_jsonl = data_jsonl[:max_dataset_rows]
509515
data = dataset_adapter(data_jsonl)
510516
elif "input_messages" in kwargs and kwargs["input_messages"] is not None:
511517
# Support either a single row (List[Message]) or many rows (List[List[Message]])

eval_protocol/pytest/exception_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ def __post_init__(self):
109109
# Override backoff settings from environment variables
110110
if "EP_MAX_RETRY" in os.environ:
111111
max_retry = int(os.environ["EP_MAX_RETRY"])
112-
if max_retry > 0:
113-
self.backoff_config.max_tries = max_retry
112+
self.backoff_config.max_tries = max_retry
114113

115114
if "EP_FAIL_ON_MAX_RETRY" in os.environ:
116115
fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower()

eval_protocol/pytest/plugin.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@ def pytest_addoption(parser) -> None:
2828
"Pass an integer (e.g., 2, 50) or 'all' for no limit."
2929
),
3030
)
31+
group.addoption(
32+
"--ep-num-runs",
33+
action="store",
34+
default=None,
35+
help=("Override the number of runs for evaluation_test. Pass an integer (e.g., 1, 5, 10)."),
36+
)
37+
group.addoption(
38+
"--ep-max-concurrent-rollouts",
39+
action="store",
40+
default=None,
41+
help=("Override the maximum number of concurrent rollouts. Pass an integer (e.g., 8, 50, 100)."),
42+
)
3143
group.addoption(
3244
"--ep-print-summary",
3345
action="store_true",
@@ -56,14 +68,13 @@ def pytest_addoption(parser) -> None:
5668
default=None,
5769
help=(
5870
"Set reasoning.effort for providers that support it (e.g., Fireworks) via LiteLLM extra_body. "
59-
"Values: low|medium|high"
71+
"Values: low|medium|high|none"
6072
),
6173
)
6274
group.addoption(
6375
"--ep-max-retry",
6476
action="store",
65-
type=int,
66-
default=0,
77+
default=None,
6778
help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."),
6879
)
6980
group.addoption(
@@ -92,6 +103,20 @@ def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
92103
return None
93104

94105

106+
def _normalize_number(val: Optional[str]) -> Optional[str]:
107+
if val is None:
108+
return None
109+
s = val.strip()
110+
# Validate int; if invalid, ignore and return None (no override)
111+
try:
112+
num = int(s)
113+
if num <= 0:
114+
return None # num_runs must be positive
115+
return str(num)
116+
except ValueError:
117+
return None
118+
119+
95120
def pytest_configure(config) -> None:
96121
# Quiet LiteLLM INFO spam early in pytest session unless user set a level
97122
try:
@@ -110,6 +135,16 @@ def pytest_configure(config) -> None:
110135
if norm is not None:
111136
os.environ["EP_MAX_DATASET_ROWS"] = norm
112137

138+
num_runs_val = config.getoption("--ep-num-runs")
139+
norm_runs = _normalize_number(num_runs_val)
140+
if norm_runs is not None:
141+
os.environ["EP_NUM_RUNS"] = norm_runs
142+
143+
max_concurrent_val = config.getoption("--ep-max-concurrent-rollouts")
144+
norm_concurrent = _normalize_number(max_concurrent_val)
145+
if norm_concurrent is not None:
146+
os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent
147+
113148
if config.getoption("--ep-print-summary"):
114149
os.environ["EP_PRINT_SUMMARY"] = "1"
115150

@@ -118,10 +153,13 @@ def pytest_configure(config) -> None:
118153
os.environ["EP_SUMMARY_JSON"] = summary_json_path
119154

120155
max_retry = config.getoption("--ep-max-retry")
121-
os.environ["EP_MAX_RETRY"] = str(max_retry)
156+
norm_max_retry = _normalize_number(max_retry)
157+
if norm_max_retry is not None:
158+
os.environ["EP_MAX_RETRY"] = norm_max_retry
122159

123160
fail_on_max_retry = config.getoption("--ep-fail-on-max-retry")
124-
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
161+
if fail_on_max_retry is not None:
162+
os.environ["EP_FAIL_ON_MAX_RETRY"] = fail_on_max_retry
125163

126164
# Allow ad-hoc overrides of input params via CLI flags
127165
try:
@@ -153,7 +191,8 @@ def pytest_configure(config) -> None:
153191
if reasoning_effort:
154192
# Always place under extra_body to avoid LiteLLM rejecting top-level params
155193
eb = merged.setdefault("extra_body", {})
156-
eb["reasoning_effort"] = str(reasoning_effort)
194+
# Convert "none" string to None value for API compatibility
195+
eb["reasoning_effort"] = None if reasoning_effort.lower() == "none" else str(reasoning_effort)
157196
if merged:
158197
os.environ["EP_INPUT_PARAMS_JSON"] = _json.dumps(merged)
159198
except Exception:

eval_protocol/pytest/utils.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
)
1919
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
2020

21+
import logging
22+
2123

2224
def execute_function(func: Callable, **kwargs) -> Any:
2325
"""
@@ -140,17 +142,33 @@ def log_eval_status_and_rows(
140142

141143

142144
def parse_ep_max_rows(default_value: Optional[int]) -> Optional[int]:
143-
"""Read EP_MAX_DATASET_ROWS env override as int or None."""
145+
"""Read EP_MAX_DATASET_ROWS env override as int or None.
146+
147+
Assumes the environment variable was already validated by plugin.py.
148+
"""
144149
raw = os.getenv("EP_MAX_DATASET_ROWS")
145150
if raw is None:
146151
return default_value
147-
s = raw.strip().lower()
148-
if s == "none":
149-
return None
150-
try:
151-
return int(s)
152-
except ValueError:
153-
return default_value
152+
# plugin.py stores "None" as string for the "all" case
153+
return None if raw.lower() == "none" else int(raw)
154+
155+
156+
def parse_ep_num_runs(default_value: int) -> int:
157+
"""Read EP_NUM_RUNS env override as int.
158+
159+
Assumes the environment variable was already validated by plugin.py.
160+
"""
161+
raw = os.getenv("EP_NUM_RUNS")
162+
return int(raw) if raw is not None else default_value
163+
164+
165+
def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
166+
"""Read EP_MAX_CONCURRENT_ROLLOUTS env override as int.
167+
168+
Assumes the environment variable was already validated by plugin.py.
169+
"""
170+
raw = os.getenv("EP_MAX_CONCURRENT_ROLLOUTS")
171+
return int(raw) if raw is not None else default_value
154172

155173

156174
def deep_update_dict(base: dict, override: dict) -> dict:
@@ -315,10 +333,14 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
315333
return result
316334
except Exception as retry_error:
317335
# Backoff gave up
336+
logging.error(
337+
f"❌ Rollout failed, (retried {exception_config.backoff_config.max_tries} times): {repr(retry_error)}"
338+
)
318339
row.rollout_status = Status.rollout_error(str(retry_error))
319340
return row
320341
else:
321342
# Non-retryable exception - fail immediately
343+
logging.error(f"❌ Rollout failed (non-retryable error encountered): {repr(e)}")
322344
row.rollout_status = Status.rollout_error(str(e))
323345
return row
324346

0 commit comments

Comments (0)