2020 EvaluationRow ,
2121 EvaluationThreshold ,
2222 EvaluationThresholdDict ,
23- EvaluateResult ,
2423 Status ,
2524 EPParameters ,
2625)
2726from eval_protocol .pytest .dual_mode_wrapper import create_dual_mode_wrapper
2827from eval_protocol .pytest .evaluation_test_postprocess import postprocess
29- from eval_protocol .pytest .execution import execute_pytest , execute_pytest_with_exception_handling
28+ from eval_protocol .pytest .execution import execute_pytest_with_exception_handling
3029from eval_protocol .pytest .priority_scheduler import execute_priority_rollouts
3130from eval_protocol .pytest .generate_parameter_combinations import (
3231 ParameterizedTestKwargs ,
5655 AggregationMethod ,
5756 add_cost_metrics ,
5857 log_eval_status_and_rows ,
58+ normalize_fireworks_model ,
5959 parse_ep_completion_params ,
6060 parse_ep_completion_params_overwrite ,
6161 parse_ep_max_concurrent_rollouts ,
@@ -205,6 +205,7 @@ def evaluation_test(
205205 max_dataset_rows = parse_ep_max_rows (max_dataset_rows )
206206 completion_params = parse_ep_completion_params (completion_params )
207207 completion_params = parse_ep_completion_params_overwrite (completion_params )
208+ completion_params = [normalize_fireworks_model (cp ) for cp in completion_params ]
208209 original_completion_params = completion_params
209210 passed_threshold = parse_ep_passed_threshold (passed_threshold )
210211 data_loaders = parse_ep_dataloaders (data_loaders )
@@ -409,21 +410,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
409410
410411 rollout_processor .setup ()
411412
412- use_priority_scheduler = (
413- (
414- os .environ .get ("EP_USE_PRIORITY_SCHEDULER" , "0" ) == "1"
415- and not isinstance (rollout_processor , MCPGymRolloutProcessor )
416- )
417- )
413+ use_priority_scheduler = os .environ .get (
414+ "EP_USE_PRIORITY_SCHEDULER" , "0"
415+ ) == "1" and not isinstance (rollout_processor , MCPGymRolloutProcessor )
418416
419417 if use_priority_scheduler :
420418 microbatch_output_size = os .environ .get ("EP_MICRO_BATCH_OUTPUT_SIZE" , None )
421419 output_dir = os .environ .get ("EP_OUTPUT_DIR" , None )
422420 if microbatch_output_size and output_dir :
423- output_buffer = MicroBatchDataBuffer (num_runs = num_runs , batch_size = int (microbatch_output_size ), output_path_template = os .path .join (output_dir , "buffer_{index}.jsonl" ))
421+ output_buffer = MicroBatchDataBuffer (
422+ num_runs = num_runs ,
423+ batch_size = int (microbatch_output_size ),
424+ output_path_template = os .path .join (output_dir , "buffer_{index}.jsonl" ),
425+ )
424426 else :
425427 output_buffer = None
426-
428+
427429 try :
428430 priority_results = await execute_priority_rollouts (
429431 dataset = data ,
@@ -441,12 +443,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
441443 finally :
442444 if output_buffer :
443445 await output_buffer .close ()
444-
446+
445447 for res in priority_results :
446448 run_idx = (res .execution_metadata .extra or {}).get ("run_index" , 0 )
447449 if run_idx < len (all_results ):
448450 all_results [run_idx ].append (res )
449-
451+
450452 processed_rows_in_run .append (res )
451453
452454 postprocess (
@@ -462,6 +464,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
462464 )
463465
464466 else :
467+
465468 async def execute_run (run_idx : int , config : RolloutProcessorConfig ):
466469 nonlocal all_results
467470
@@ -506,9 +509,7 @@ async def _execute_pointwise_eval_with_semaphore(
506509 raise ValueError (
507510 f"Test function { test_func .__name__ } did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
508511 )
509- result .execution_metadata .eval_duration_seconds = (
510- time .perf_counter () - start_time
511- )
512+ result .execution_metadata .eval_duration_seconds = time .perf_counter () - start_time
512513 return result
513514
514515 async def _execute_groupwise_eval_with_semaphore (
@@ -519,7 +520,9 @@ async def _execute_groupwise_eval_with_semaphore(
519520 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {}
520521 primary_rollout_id = rows [0 ].execution_metadata .rollout_id if rows else None
521522 group_rollout_ids = [
522- r .execution_metadata .rollout_id for r in rows if r .execution_metadata .rollout_id
523+ r .execution_metadata .rollout_id
524+ for r in rows
525+ if r .execution_metadata .rollout_id
523526 ]
524527 async with rollout_logging_context (
525528 primary_rollout_id or "" ,
@@ -596,7 +599,9 @@ async def _collect_result(config, lst):
596599 row_groups [row .input_metadata .row_id ].append (row )
597600 tasks = []
598601 for _ , rows in row_groups .items ():
599- tasks .append (asyncio .create_task (_execute_groupwise_eval_with_semaphore (rows = rows )))
602+ tasks .append (
603+ asyncio .create_task (_execute_groupwise_eval_with_semaphore (rows = rows ))
604+ )
600605 results = []
601606 for task in tasks :
602607 res = await task
@@ -692,9 +697,9 @@ async def _collect_result(config, lst):
692697 # For other processors, create all tasks at once and run in parallel
693698 # Concurrency is now controlled by the shared semaphore in each rollout processor
694699 await run_tasks_with_run_progress (execute_run , num_runs , config )
695-
700+
696701 experiment_duration_seconds = time .perf_counter () - experiment_start_time
697-
702+
698703 # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
699704 # rollout_id is used to differentiate the result from different completion_params
700705 if mode == "groupwise" :
@@ -730,15 +735,12 @@ async def _collect_result(config, lst):
730735 experiment_duration_seconds ,
731736 )
732737
733-
734-
735738 if not all (r .evaluation_result is not None for run_results in all_results for r in run_results ):
736739 raise AssertionError (
737740 "Some EvaluationRow instances are missing evaluation_result. "
738741 "Your @evaluation_test function must set `row.evaluation_result`"
739742 )
740743
741-
742744 except AssertionError :
743745 _log_eval_error (
744746 Status .eval_finished (),
0 commit comments