2020 EvaluationRow ,
2121 EvaluationThreshold ,
2222 EvaluationThresholdDict ,
23- EvaluateResult ,
2423 Status ,
2524 EPParameters ,
2625)
2726from eval_protocol .pytest .dual_mode_wrapper import create_dual_mode_wrapper
2827from eval_protocol .pytest .evaluation_test_postprocess import postprocess
29- from eval_protocol .pytest .execution import execute_pytest , execute_pytest_with_exception_handling
28+ from eval_protocol .pytest .execution import execute_pytest_with_exception_handling
3029from eval_protocol .pytest .priority_scheduler import execute_priority_rollouts
3130from eval_protocol .pytest .generate_parameter_combinations import (
3231 ParameterizedTestKwargs ,
5655 AggregationMethod ,
5756 add_cost_metrics ,
5857 log_eval_status_and_rows ,
58+ normalize_fireworks_model ,
5959 parse_ep_completion_params ,
6060 parse_ep_completion_params_overwrite ,
6161 parse_ep_max_concurrent_rollouts ,
@@ -93,8 +93,8 @@ def evaluation_test(
9393 filtered_row_ids : Sequence [str ] | None = None ,
9494 max_dataset_rows : int | None = None ,
9595 mcp_config_path : str | None = None ,
96- max_concurrent_rollouts : int = 8 ,
97- max_concurrent_evaluations : int = 64 ,
96+ max_concurrent_rollouts : int = 96 ,
97+ max_concurrent_evaluations : int = 96 ,
9898 server_script_path : str | None = None ,
9999 steps : int = 30 ,
100100 mode : EvaluationTestMode = "pointwise" ,
@@ -205,6 +205,7 @@ def evaluation_test(
205205 max_dataset_rows = parse_ep_max_rows (max_dataset_rows )
206206 completion_params = parse_ep_completion_params (completion_params )
207207 completion_params = parse_ep_completion_params_overwrite (completion_params )
208+ completion_params = [normalize_fireworks_model (cp ) for cp in completion_params ]
208209 original_completion_params = completion_params
209210 passed_threshold = parse_ep_passed_threshold (passed_threshold )
210211 data_loaders = parse_ep_dataloaders (data_loaders )
@@ -365,6 +366,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
365366 row .input_metadata .row_id = generate_id (seed = 0 , index = index )
366367
367368 completion_params = kwargs ["completion_params" ] if "completion_params" in kwargs else None
369+ completion_params = normalize_fireworks_model (completion_params )
368370 # Create eval metadata with test function info and current commit hash
369371 eval_metadata = EvalMetadata (
370372 name = test_func .__name__ ,
@@ -409,21 +411,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
409411
410412 rollout_processor .setup ()
411413
412- use_priority_scheduler = (
413- (
414- os .environ .get ("EP_USE_PRIORITY_SCHEDULER" , "0" ) == "1"
415- and not isinstance (rollout_processor , MCPGymRolloutProcessor )
416- )
417- )
414+ use_priority_scheduler = os .environ .get (
415+ "EP_USE_PRIORITY_SCHEDULER" , "0"
416+ ) == "1" and not isinstance (rollout_processor , MCPGymRolloutProcessor )
418417
419418 if use_priority_scheduler :
420419 microbatch_output_size = os .environ .get ("EP_MICRO_BATCH_OUTPUT_SIZE" , None )
421420 output_dir = os .environ .get ("EP_OUTPUT_DIR" , None )
422421 if microbatch_output_size and output_dir :
423- output_buffer = MicroBatchDataBuffer (num_runs = num_runs , batch_size = int (microbatch_output_size ), output_path_template = os .path .join (output_dir , "buffer_{index}.jsonl" ))
422+ output_buffer = MicroBatchDataBuffer (
423+ num_runs = num_runs ,
424+ batch_size = int (microbatch_output_size ),
425+ output_path_template = os .path .join (output_dir , "buffer_{index}.jsonl" ),
426+ )
424427 else :
425428 output_buffer = None
426-
429+
427430 try :
428431 priority_results = await execute_priority_rollouts (
429432 dataset = data ,
@@ -441,12 +444,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
441444 finally :
442445 if output_buffer :
443446 await output_buffer .close ()
444-
447+
445448 for res in priority_results :
446449 run_idx = (res .execution_metadata .extra or {}).get ("run_index" , 0 )
447450 if run_idx < len (all_results ):
448451 all_results [run_idx ].append (res )
449-
452+
450453 processed_rows_in_run .append (res )
451454
452455 postprocess (
@@ -462,6 +465,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
462465 )
463466
464467 else :
468+
465469 async def execute_run (run_idx : int , config : RolloutProcessorConfig ):
466470 nonlocal all_results
467471
@@ -506,9 +510,7 @@ async def _execute_pointwise_eval_with_semaphore(
506510 raise ValueError (
507511 f"Test function { test_func .__name__ } did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
508512 )
509- result .execution_metadata .eval_duration_seconds = (
510- time .perf_counter () - start_time
511- )
513+ result .execution_metadata .eval_duration_seconds = time .perf_counter () - start_time
512514 return result
513515
514516 async def _execute_groupwise_eval_with_semaphore (
@@ -519,7 +521,9 @@ async def _execute_groupwise_eval_with_semaphore(
519521 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {}
520522 primary_rollout_id = rows [0 ].execution_metadata .rollout_id if rows else None
521523 group_rollout_ids = [
522- r .execution_metadata .rollout_id for r in rows if r .execution_metadata .rollout_id
524+ r .execution_metadata .rollout_id
525+ for r in rows
526+ if r .execution_metadata .rollout_id
523527 ]
524528 async with rollout_logging_context (
525529 primary_rollout_id or "" ,
@@ -596,7 +600,9 @@ async def _collect_result(config, lst):
596600 row_groups [row .input_metadata .row_id ].append (row )
597601 tasks = []
598602 for _ , rows in row_groups .items ():
599- tasks .append (asyncio .create_task (_execute_groupwise_eval_with_semaphore (rows = rows )))
603+ tasks .append (
604+ asyncio .create_task (_execute_groupwise_eval_with_semaphore (rows = rows ))
605+ )
600606 results = []
601607 for task in tasks :
602608 res = await task
@@ -692,9 +698,9 @@ async def _collect_result(config, lst):
692698 # For other processors, create all tasks at once and run in parallel
693699 # Concurrency is now controlled by the shared semaphore in each rollout processor
694700 await run_tasks_with_run_progress (execute_run , num_runs , config )
695-
701+
696702 experiment_duration_seconds = time .perf_counter () - experiment_start_time
697-
703+
698704 # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
699705 # rollout_id is used to differentiate the result from different completion_params
700706 if mode == "groupwise" :
@@ -730,15 +736,12 @@ async def _collect_result(config, lst):
730736 experiment_duration_seconds ,
731737 )
732738
733-
734-
735739 if not all (r .evaluation_result is not None for run_results in all_results for r in run_results ):
736740 raise AssertionError (
737741 "Some EvaluationRow instances are missing evaluation_result. "
738742 "Your @evaluation_test function must set `row.evaluation_result`"
739743 )
740744
741-
742745 except AssertionError :
743746 _log_eval_error (
744747 Status .eval_finished (),
0 commit comments