Skip to content

Commit 795072e

Browse files
authored
default 96 96 (#393)
1 parent cf4fc4e commit 795072e

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 23 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -93,8 +93,8 @@ def evaluation_test(
9393
filtered_row_ids: Sequence[str] | None = None,
9494
max_dataset_rows: int | None = None,
9595
mcp_config_path: str | None = None,
96-
max_concurrent_rollouts: int = 8,
97-
max_concurrent_evaluations: int = 64,
96+
max_concurrent_rollouts: int = 96,
97+
max_concurrent_evaluations: int = 96,
9898
server_script_path: str | None = None,
9999
steps: int = 30,
100100
mode: EvaluationTestMode = "pointwise",
@@ -409,21 +409,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
409409

410410
rollout_processor.setup()
411411

412-
use_priority_scheduler = (
413-
(
414-
os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1"
415-
and not isinstance(rollout_processor, MCPGymRolloutProcessor)
416-
)
417-
)
412+
use_priority_scheduler = os.environ.get(
413+
"EP_USE_PRIORITY_SCHEDULER", "0"
414+
) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor)
418415

419416
if use_priority_scheduler:
420417
microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None)
421418
output_dir = os.environ.get("EP_OUTPUT_DIR", None)
422419
if microbatch_output_size and output_dir:
423-
output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"))
420+
output_buffer = MicroBatchDataBuffer(
421+
num_runs=num_runs,
422+
batch_size=int(microbatch_output_size),
423+
output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"),
424+
)
424425
else:
425426
output_buffer = None
426-
427+
427428
try:
428429
priority_results = await execute_priority_rollouts(
429430
dataset=data,
@@ -441,12 +442,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
441442
finally:
442443
if output_buffer:
443444
await output_buffer.close()
444-
445+
445446
for res in priority_results:
446447
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
447448
if run_idx < len(all_results):
448449
all_results[run_idx].append(res)
449-
450+
450451
processed_rows_in_run.append(res)
451452

452453
postprocess(
@@ -462,6 +463,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
462463
)
463464

464465
else:
466+
465467
async def execute_run(run_idx: int, config: RolloutProcessorConfig):
466468
nonlocal all_results
467469

@@ -506,9 +508,7 @@ async def _execute_pointwise_eval_with_semaphore(
506508
raise ValueError(
507509
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
508510
)
509-
result.execution_metadata.eval_duration_seconds = (
510-
time.perf_counter() - start_time
511-
)
511+
result.execution_metadata.eval_duration_seconds = time.perf_counter() - start_time
512512
return result
513513

514514
async def _execute_groupwise_eval_with_semaphore(
@@ -519,7 +519,9 @@ async def _execute_groupwise_eval_with_semaphore(
519519
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
520520
primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
521521
group_rollout_ids = [
522-
r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id
522+
r.execution_metadata.rollout_id
523+
for r in rows
524+
if r.execution_metadata.rollout_id
523525
]
524526
async with rollout_logging_context(
525527
primary_rollout_id or "",
@@ -596,7 +598,9 @@ async def _collect_result(config, lst):
596598
row_groups[row.input_metadata.row_id].append(row)
597599
tasks = []
598600
for _, rows in row_groups.items():
599-
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))
601+
tasks.append(
602+
asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))
603+
)
600604
results = []
601605
for task in tasks:
602606
res = await task
@@ -692,9 +696,9 @@ async def _collect_result(config, lst):
692696
# For other processors, create all tasks at once and run in parallel
693697
# Concurrency is now controlled by the shared semaphore in each rollout processor
694698
await run_tasks_with_run_progress(execute_run, num_runs, config)
695-
699+
696700
experiment_duration_seconds = time.perf_counter() - experiment_start_time
697-
701+
698702
# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
699703
# rollout_id is used to differentiate the result from different completion_params
700704
if mode == "groupwise":
@@ -730,15 +734,12 @@ async def _collect_result(config, lst):
730734
experiment_duration_seconds,
731735
)
732736

733-
734-
735737
if not all(r.evaluation_result is not None for run_results in all_results for r in run_results):
736738
raise AssertionError(
737739
"Some EvaluationRow instances are missing evaluation_result. "
738740
"Your @evaluation_test function must set `row.evaluation_result`"
739741
)
740742

741-
742743
except AssertionError:
743744
_log_eval_error(
744745
Status.eval_finished(),

0 commit comments

Comments
 (0)