1- from typing import List , Set
21import asyncio
2+ from typing import Any
3+ from typing_extensions import override
34
45from eval_protocol .dataset_logger .dataset_logger import DatasetLogger
56from eval_protocol .models import EvaluationRow
1112class TrackingRolloutProcessor (RolloutProcessor ):
1213 """Custom rollout processor that tracks which rollout IDs are generated during rollout phase."""
1314
14- def __init__ (self , shared_rollout_ids : Set [str ]):
15- self .shared_rollout_ids = shared_rollout_ids
15+ def __init__ (self , shared_rollout_ids : set [str ]):
16+ self .shared_rollout_ids : set [ str ] = shared_rollout_ids
1617
17- def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
18+ @override
19+ def __call__ (self , rows : list [EvaluationRow ], config : RolloutProcessorConfig ) -> list [asyncio .Task [EvaluationRow ]]:
1820 """Process rows and track rollout IDs generated during rollout phase."""
1921
2022 async def process_row (row : EvaluationRow ) -> EvaluationRow :
2123 # Track this rollout ID as being generated during rollout phase
24+ if row .execution_metadata .rollout_id is None :
25+ raise ValueError ("Rollout ID is None" )
2226 self .shared_rollout_ids .add (row .execution_metadata .rollout_id )
2327 return row
2428
@@ -30,13 +34,17 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
class TrackingLogger(DatasetLogger):
    """Custom logger that tracks all rollout IDs that are logged.

    Each row passed to ``log`` has its rollout ID added to the set supplied
    at construction time. The set is stored by reference (no copy is made),
    so whoever created it sees the additions through their own handle.
    """

    def __init__(self, shared_rollout_ids: set[str]):
        # Keep the caller's set object itself so additions stay visible to it.
        self.shared_rollout_ids: set[str] = shared_rollout_ids

    @override
    def log(self, row: EvaluationRow) -> None:
        """Record the rollout ID of ``row`` in the shared set.

        Raises:
            ValueError: If ``row.execution_metadata.rollout_id`` is ``None``.
        """
        # Read the attribute once so the None check and the insert use the
        # same value.
        rollout_id = row.execution_metadata.rollout_id
        if rollout_id is None:
            raise ValueError("Rollout ID is None")
        self.shared_rollout_ids.add(rollout_id)

    @override
    def read(self, row_id: str | None = None) -> list[EvaluationRow]:
        """Return an empty list; this logger only tracks IDs, not rows.

        ``row_id`` is accepted but ignored.
        """
        return []
4149
4250
@@ -48,7 +56,7 @@ async def test_assertion_error_no_new_rollouts():
4856 from eval_protocol .pytest .evaluation_test import evaluation_test
4957
5058 # Create shared set to track rollout IDs generated during rollout phase
51- shared_rollout_ids : Set [str ] = set ()
59+ shared_rollout_ids : set [str ] = set ()
5260
5361 # Create custom processor and logger for tracking with shared set
5462 rollout_processor = TrackingRolloutProcessor (shared_rollout_ids )
@@ -57,7 +65,7 @@ async def test_assertion_error_no_new_rollouts():
5765 input_dataset : list [str ] = [
5866 "tests/pytest/data/markdown_dataset.jsonl" ,
5967 ]
60- completion_params : list [dict ] = [{"temperature" : 0.0 , "model" : "dummy/local-model" }]
68+ completion_params : list [dict [ str , Any ]] = [{"temperature" : 0.0 , "model" : "dummy/local-model" }] # pyright: ignore[reportExplicitAny ]
6169
6270 @evaluation_test (
6371 input_dataset = input_dataset ,
@@ -81,7 +89,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
8189 # This should fail due to threshold not being met
8290 for ds_path in input_dataset :
8391 for completion_param in completion_params :
84- await eval_fn (dataset_path = ds_path , completion_params = completion_param )
92+ await eval_fn (dataset_path = [ ds_path ] , completion_params = completion_param ) # pyright: ignore[reportCallIssue]
8593 except AssertionError :
8694 # Expected - the threshold check should fail
8795 pass
0 commit comments