44import re
55import sys
66from dataclasses import replace
7- from typing import Any , Literal
7+ from typing import Any , Literal , Callable , AsyncGenerator
88
99from litellm .cost_calculator import cost_per_token
1010from tqdm import tqdm
3333AggregationMethod = Literal ["mean" , "max" , "min" , "bootstrap" ]
3434
3535
36- async def run_tasks_with_eval_progress (pointwise_tasks : list , run_idx : int ):
36+ async def run_tasks_with_eval_progress (
37+ pointwise_tasks : list [asyncio .Task [EvaluationRow ]], run_idx : int
38+ ) -> list [EvaluationRow ]:
3739 """
3840 Run evaluation tasks with a progress bar and proper cancellation handling.
3941
@@ -58,7 +60,7 @@ async def run_tasks_with_eval_progress(pointwise_tasks: list, run_idx: int):
5860 bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]" ,
5961 ) as eval_pbar :
6062
61- async def task_with_progress (task ) :
63+ async def task_with_progress (task : asyncio . Task [ EvaluationRow ]) -> EvaluationRow :
6264 try :
6365 result = await task
6466 return result
@@ -77,7 +79,9 @@ async def task_with_progress(task):
7779 raise
7880
7981
80- async def run_tasks_with_run_progress (execute_run_func , num_runs , config ):
82+ async def run_tasks_with_run_progress (
83+ execute_run_func : Callable [[int , RolloutProcessorConfig ], Any ], num_runs : int , config : RolloutProcessorConfig
84+ ) -> None :
8185 """
8286 Run tasks with a parallel runs progress bar, preserving original logic.
8387
@@ -98,12 +102,12 @@ async def run_tasks_with_run_progress(execute_run_func, num_runs, config):
98102 bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]" ,
99103 ) as run_pbar :
100104
101- async def execute_run_with_progress (run_idx : int , config ) :
105+ async def execute_run_with_progress (run_idx : int , config : RolloutProcessorConfig ) -> Any :
102106 result = await execute_run_func (run_idx , config )
103107 run_pbar .update (1 )
104108 return result
105109
106- tasks = []
110+ tasks : list [ asyncio . Task [ Any ]] = []
107111 for run_idx in range (num_runs ):
108112 tasks .append (asyncio .create_task (execute_run_with_progress (run_idx , config )))
109113 try :
@@ -274,7 +278,7 @@ async def rollout_processor_with_retry(
274278 fresh_dataset : list [EvaluationRow ],
275279 config : RolloutProcessorConfig ,
276280 run_idx : int = 0 ,
277- ):
281+ ) -> AsyncGenerator [ EvaluationRow , None ] :
278282 """
279283 Wrapper around rollout_processor that handles retry logic using the Python backoff library.
280284
@@ -304,13 +308,13 @@ async def rollout_processor_with_retry(
304308
305309 # Create a single backoff-decorated retry function that can be reused
306310 @exception_config .get_backoff_decorator () # pyright: ignore[reportUntypedFunctionDecorator]
307- async def execute_row_with_backoff_retry (row : EvaluationRow ):
311+ async def execute_row_with_backoff_retry (row : EvaluationRow ) -> EvaluationRow :
308312 """Execute rollout for a single row with backoff retry."""
309313 retry_config = replace (config , kwargs = {** (config .kwargs or {}), "start_server" : False })
310314 retry_tasks = rollout_processor ([row ], retry_config )
311315 return await retry_tasks [0 ]
312316
313- async def execute_row_with_backoff (task : asyncio .Task , row : EvaluationRow ) -> EvaluationRow : # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
317+ async def execute_row_with_backoff (task : asyncio .Task [ EvaluationRow ] , row : EvaluationRow ) -> EvaluationRow :
314318 """Execute a single row task with backoff retry."""
315319
316320 try :
@@ -344,7 +348,9 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
344348 row .rollout_status = Status .rollout_error (repr (e ))
345349 return row
346350
347- async def execute_row_with_backoff_and_log (task : asyncio .Task , row : EvaluationRow ) -> EvaluationRow : # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
351+ async def execute_row_with_backoff_and_log (
352+ task : asyncio .Task [EvaluationRow ], row : EvaluationRow
353+ ) -> EvaluationRow :
348354 """Execute a single row task with backoff retry and logging."""
349355 result = await execute_row_with_backoff (task , row )
350356 # Log the row after execution completes (success or failure)
@@ -386,7 +392,7 @@ def sanitize_filename(text: str) -> str:
386392 return safe [:120 ]
387393
388394
389- def extract_effort_tag (params : dict ) -> str | None : # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
395+ def extract_effort_tag (params : dict [ str , Any ] ) -> str | None :
390396 """
391397 Extract effort tag from completion parameters for use in file naming.
392398
0 commit comments