@@ -93,8 +93,8 @@ def evaluation_test(
9393 filtered_row_ids : Sequence [str ] | None = None ,
9494 max_dataset_rows : int | None = None ,
9595 mcp_config_path : str | None = None ,
96- max_concurrent_rollouts : int = 8 ,
97- max_concurrent_evaluations : int = 64 ,
96+ max_concurrent_rollouts : int = 96 ,
97+ max_concurrent_evaluations : int = 96 ,
9898 server_script_path : str | None = None ,
9999 steps : int = 30 ,
100100 mode : EvaluationTestMode = "pointwise" ,
@@ -409,21 +409,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
409409
410410 rollout_processor .setup ()
411411
412- use_priority_scheduler = (
413- (
414- os .environ .get ("EP_USE_PRIORITY_SCHEDULER" , "0" ) == "1"
415- and not isinstance (rollout_processor , MCPGymRolloutProcessor )
416- )
417- )
412+ use_priority_scheduler = os .environ .get (
413+ "EP_USE_PRIORITY_SCHEDULER" , "0"
414+ ) == "1" and not isinstance (rollout_processor , MCPGymRolloutProcessor )
418415
419416 if use_priority_scheduler :
420417 microbatch_output_size = os .environ .get ("EP_MICRO_BATCH_OUTPUT_SIZE" , None )
421418 output_dir = os .environ .get ("EP_OUTPUT_DIR" , None )
422419 if microbatch_output_size and output_dir :
423- output_buffer = MicroBatchDataBuffer (num_runs = num_runs , batch_size = int (microbatch_output_size ), output_path_template = os .path .join (output_dir , "buffer_{index}.jsonl" ))
420+ output_buffer = MicroBatchDataBuffer (
421+ num_runs = num_runs ,
422+ batch_size = int (microbatch_output_size ),
423+ output_path_template = os .path .join (output_dir , "buffer_{index}.jsonl" ),
424+ )
424425 else :
425426 output_buffer = None
426-
427+
427428 try :
428429 priority_results = await execute_priority_rollouts (
429430 dataset = data ,
@@ -441,12 +442,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
441442 finally :
442443 if output_buffer :
443444 await output_buffer .close ()
444-
445+
445446 for res in priority_results :
446447 run_idx = (res .execution_metadata .extra or {}).get ("run_index" , 0 )
447448 if run_idx < len (all_results ):
448449 all_results [run_idx ].append (res )
449-
450+
450451 processed_rows_in_run .append (res )
451452
452453 postprocess (
@@ -462,6 +463,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
462463 )
463464
464465 else :
466+
465467 async def execute_run (run_idx : int , config : RolloutProcessorConfig ):
466468 nonlocal all_results
467469
@@ -506,9 +508,7 @@ async def _execute_pointwise_eval_with_semaphore(
506508 raise ValueError (
507509 f"Test function { test_func .__name__ } did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
508510 )
509- result .execution_metadata .eval_duration_seconds = (
510- time .perf_counter () - start_time
511- )
511+ result .execution_metadata .eval_duration_seconds = time .perf_counter () - start_time
512512 return result
513513
514514 async def _execute_groupwise_eval_with_semaphore (
@@ -519,7 +519,9 @@ async def _execute_groupwise_eval_with_semaphore(
519519 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {}
520520 primary_rollout_id = rows [0 ].execution_metadata .rollout_id if rows else None
521521 group_rollout_ids = [
522- r .execution_metadata .rollout_id for r in rows if r .execution_metadata .rollout_id
522+ r .execution_metadata .rollout_id
523+ for r in rows
524+ if r .execution_metadata .rollout_id
523525 ]
524526 async with rollout_logging_context (
525527 primary_rollout_id or "" ,
@@ -596,7 +598,9 @@ async def _collect_result(config, lst):
596598 row_groups [row .input_metadata .row_id ].append (row )
597599 tasks = []
598600 for _ , rows in row_groups .items ():
599- tasks .append (asyncio .create_task (_execute_groupwise_eval_with_semaphore (rows = rows )))
601+ tasks .append (
602+ asyncio .create_task (_execute_groupwise_eval_with_semaphore (rows = rows ))
603+ )
600604 results = []
601605 for task in tasks :
602606 res = await task
@@ -692,9 +696,9 @@ async def _collect_result(config, lst):
692696 # For other processors, create all tasks at once and run in parallel
693697 # Concurrency is now controlled by the shared semaphore in each rollout processor
694698 await run_tasks_with_run_progress (execute_run , num_runs , config )
695-
699+
696700 experiment_duration_seconds = time .perf_counter () - experiment_start_time
697-
701+
698702 # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
699703 # rollout_id is used to differentiate the result from different completion_params
700704 if mode == "groupwise" :
@@ -730,15 +734,12 @@ async def _collect_result(config, lst):
730734 experiment_duration_seconds ,
731735 )
732736
733-
734-
735737 if not all (r .evaluation_result is not None for run_results in all_results for r in run_results ):
736738 raise AssertionError (
737739 "Some EvaluationRow instances are missing evaluation_result. "
738740 "Your @evaluation_test function must set `row.evaluation_result`"
739741 )
740742
741-
742743 except AssertionError :
743744 _log_eval_error (
744745 Status .eval_finished (),
0 commit comments