@@ -72,7 +72,7 @@ def evaluation_test(
7272 input_dataset : Sequence [DatasetPathParam ] | None = None ,
7373 input_rows : Sequence [list [EvaluationRow ]] | None = None ,
7474 data_loaders : Sequence [EvaluationDataLoader ] | EvaluationDataLoader | None = None ,
75- dataset_adapter : Callable [[list [dict [str , Any ]]], Dataset ] = default_dataset_adapter , # pyright: ignore[reportExplicitAny]
75+ dataset_adapter : Callable [[list [dict [str , Any ]]], Dataset ] = default_dataset_adapter ,
7676 rollout_processor : RolloutProcessor | None = None ,
7777 evaluation_test_kwargs : Sequence [EvaluationInputParam | None ] | None = None ,
7878 rollout_processor_kwargs : RolloutProcessorInputParam | None = None ,
@@ -418,9 +418,7 @@ async def _execute_groupwise_eval_with_semaphore(
418418 all_results [run_idx ] = results
419419 elif mode == "groupwise" :
420420 # rollout all the completion_params for the same row at once, and then send the output to the test_func
421- row_groups = defaultdict ( # pyright: ignore[reportUnknownVariableType]
422- list
423- ) # key: row_id, value: list of rollout_result
421+ row_groups = defaultdict (list ) # key: row_id, value: list of rollout_result
424422 tasks : list [asyncio .Task [list [EvaluationRow ]]] = []
425423 # completion_groups = []
426424 for idx , cp in enumerate (original_completion_params ):
@@ -435,13 +433,13 @@ async def _execute_groupwise_eval_with_semaphore(
435433 )
436434 lst = []
437435
438- async def _collect_result (config , lst ): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
436+ async def _collect_result (config , lst ):
439437 result = []
440438 async for row in rollout_processor_with_retry (
441439 rollout_processor , lst , config , run_idx
442440 ): # pyright: ignore[reportUnknownArgumentType]
443- result .append (row ) # pyright: ignore[reportUnknownMemberType]
444- return result # pyright: ignore[reportUnknownVariableType]
441+ result .append (row )
442+ return result
445443
446444 for ori_row in fresh_dataset :
447445 copied_row = ori_row .model_copy (deep = True )
@@ -450,33 +448,32 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
450448 str (ori_row .execution_metadata .rollout_id ) + "_" + str (idx )
451449 )
452450 copied_row .input_metadata .completion_params = cp if cp is not None else {}
453- lst .append (copied_row ) # pyright: ignore[reportUnknownMemberType]
454- tasks .append (asyncio .create_task (_collect_result (config , lst ))) # pyright: ignore[reportUnknownArgumentType]
451+ lst .append (copied_row )
452+ tasks .append (asyncio .create_task (_collect_result (config , lst )))
455453 rollout_results = await asyncio .gather (* tasks )
456454 for result in rollout_results :
457455 for row in result :
458- row_groups [row .input_metadata .row_id ].append (row ) # pyright: ignore[reportUnknownMemberType]
456+ row_groups [row .input_metadata .row_id ].append (row )
459457 tasks = []
460- for _ , rows in row_groups .items (): # pyright: ignore[reportUnknownVariableType]
461- tasks .append (asyncio .create_task (_execute_groupwise_eval_with_semaphore (rows = rows ))) # pyright: ignore[reportUnknownArgumentType]
458+ for _ , rows in row_groups .items ():
459+ tasks .append (asyncio .create_task (_execute_groupwise_eval_with_semaphore (rows = rows )))
462460 results = []
463461 for task in tasks :
464462 res = await task
465- results .extend (res ) # pyright: ignore[reportUnknownMemberType]
463+ results .extend (res )
466464 all_results [run_idx ] = results
467465 else :
468466 # Batch mode: collect all results first, then evaluate (no pipelining)
469467 input_dataset = []
470468 async for row in rollout_processor_with_retry (
471469 rollout_processor , fresh_dataset , config , run_idx
472470 ):
473- input_dataset .append (row ) # pyright: ignore[reportUnknownMemberType]
474-
471+ input_dataset .append (row )
475472 # NOTE: we will still evaluate errored rows (give users control over this)
476473 # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
477474 results = await execute_pytest (
478475 test_func ,
479- processed_dataset = input_dataset , # pyright: ignore[reportUnknownArgumentType]
476+ processed_dataset = input_dataset ,
480477 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
481478 )
482479 if (
@@ -539,16 +536,16 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
539536 # for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them
540537 # rollout_id is used to differentiate the result from different completion_params
541538 if mode == "groupwise" :
542- results_by_group = [ # pyright: ignore[reportUnknownVariableType]
539+ results_by_group = [
543540 [[] for _ in range (num_runs )] for _ in range (len (original_completion_params ))
544541 ]
545542 for i_run , result in enumerate (all_results ):
546543 for r in result :
547544 completion_param_idx = int (r .execution_metadata .rollout_id .split ("_" )[1 ]) # pyright: ignore[reportOptionalMemberAccess]
548- results_by_group [completion_param_idx ][i_run ].append (r ) # pyright: ignore[reportUnknownMemberType]
549- for rollout_id , result in enumerate (results_by_group ): # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
545+ results_by_group [completion_param_idx ][i_run ].append (r )
546+ for rollout_id , result in enumerate (results_by_group ):
550547 postprocess (
551- result , # pyright: ignore[reportUnknownArgumentType]
548+ result ,
552549 aggregation_method ,
553550 passed_threshold ,
554551 active_logger ,
@@ -600,7 +597,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
600597 pytest_wrapper = pytest .mark .asyncio (pytest_wrapper )
601598
602599 # Create the dual mode wrapper
603- dual_mode_wrapper = create_dual_mode_wrapper ( # pyright: ignore[reportUnknownVariableType]
600+ dual_mode_wrapper = create_dual_mode_wrapper (
604601 test_func , mode , max_concurrent_rollouts , max_concurrent_evaluations , pytest_wrapper
605602 )
606603
0 commit comments