Skip to content

Commit e4ebf71

Browse files
committed
ep max evals flag
1 parent 8349922 commit e4ebf71

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
log_eval_status_and_rows,
5858
parse_ep_completion_params,
5959
parse_ep_completion_params_overwrite,
60+
parse_ep_max_concurrent_evaluations,
6061
parse_ep_max_concurrent_rollouts,
6162
parse_ep_max_rows,
6263
parse_ep_num_runs,
@@ -201,6 +202,7 @@ def evaluation_test(
201202
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
202203
num_runs = parse_ep_num_runs(num_runs)
203204
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
205+
max_concurrent_evaluations = parse_ep_max_concurrent_evaluations(max_concurrent_evaluations)
204206
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
205207
completion_params = parse_ep_completion_params(completion_params)
206208
completion_params = parse_ep_completion_params_overwrite(completion_params)

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,15 @@ def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
226226
return int(raw) if raw is not None else default_value
227227

228228

229+
def parse_ep_max_concurrent_evaluations(default_value: int) -> int:
    """Return the EP_MAX_CONCURRENT_EVALUATIONS override, or ``default_value``.

    plugin.py normalizes the value it exports for the
    ``--ep-max-concurrent-evaluations`` flag, but the variable can also be set
    directly in the environment, so the raw string is re-checked here instead
    of being trusted blindly.

    Args:
        default_value: Value to use when no override is present.

    Returns:
        The override parsed as an int, or ``default_value`` when the variable
        is unset or blank.

    Raises:
        ValueError: If the variable is set to something that is not an integer.
    """
    raw = os.getenv("EP_MAX_CONCURRENT_EVALUATIONS")
    # A blank value (e.g. an empty CI variable expansion) means "no override".
    if raw is None or not raw.strip():
        return default_value
    try:
        return int(raw)
    except ValueError:
        # Same exception type as before, but the message names the variable
        # so a misconfigured environment is easy to diagnose.
        raise ValueError(
            f"EP_MAX_CONCURRENT_EVALUATIONS must be an integer, got {raw!r}"
        ) from None
236+
237+
229238
def parse_ep_completion_params(
230239
completion_params: Sequence[CompletionParams | None] | None,
231240
) -> Sequence[CompletionParams | None]:

eval_protocol/pytest/plugin.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ def pytest_addoption(parser) -> None:
4545
default=None,
4646
help=("Override the maximum number of concurrent rollouts. Pass an integer (e.g., 8, 50, 100)."),
4747
)
48+
group.addoption(
49+
"--ep-max-concurrent-evaluations",
50+
action="store",
51+
default=None,
52+
help=("Override the maximum number of concurrent evaluations. Pass an integer (e.g., 8, 50, 100)."),
53+
)
4854
group.addoption(
4955
"--ep-print-summary",
5056
action="store_true",
@@ -242,10 +248,15 @@ def pytest_configure(config) -> None:
242248
if norm_runs is not None:
243249
os.environ["EP_NUM_RUNS"] = norm_runs
244250

245-
max_concurrent_val = config.getoption("--ep-max-concurrent-rollouts")
246-
norm_concurrent = _normalize_number(max_concurrent_val)
247-
if norm_concurrent is not None:
248-
os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent
251+
max_concurrent_rollouts_val = config.getoption("--ep-max-concurrent-rollouts")
252+
norm_concurrent_rollouts = _normalize_number(max_concurrent_rollouts_val)
253+
if norm_concurrent_rollouts is not None:
254+
os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent_rollouts
255+
256+
max_concurrent_evals_val = config.getoption("--ep-max-concurrent-evaluations")
257+
norm_concurrent_evals = _normalize_number(max_concurrent_evals_val)
258+
if norm_concurrent_evals is not None:
259+
os.environ["EP_MAX_CONCURRENT_EVALUATIONS"] = norm_concurrent_evals
249260

250261
if config.getoption("--ep-print-summary"):
251262
os.environ["EP_PRINT_SUMMARY"] = "1"

0 commit comments

Comments
 (0)