1+ # pyright: reportPrivateUsage=false
2+
13import asyncio
24import logging
35import types
4- from typing import List
5-
6- from attr import dataclass
7- from openai .types .chat .chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
8-
6+ from pydantic_ai .models import Model
7+ from typing_extensions import override
98from eval_protocol .models import EvaluationRow , Message
109from eval_protocol .pytest .rollout_processor import RolloutProcessor
1110from eval_protocol .pytest .types import RolloutProcessorConfig
2524 UserPromptPart ,
2625)
2726from pydantic_ai .providers .openai import OpenAIProvider
28- from typing_extensions import TypedDict
2927
3028logger = logging .getLogger (__name__ )
3129
@@ -38,7 +36,8 @@ def __init__(self):
3836 # dummy model used for its helper functions for processing messages
3937 self .util : OpenAIModel = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
4038
41- def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
39+ @override
40+ def __call__ (self , rows : list [EvaluationRow ], config : RolloutProcessorConfig ) -> list [asyncio .Task [EvaluationRow ]]:
4241 """Create agent rollout tasks and return them for external handling."""
4342
4443 max_concurrent = getattr (config , "max_concurrent_rollouts" , 8 ) or 8
@@ -60,28 +59,28 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
6059 raise ValueError (
6160 "completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
6261 )
63- kwargs : dict = {}
64- for k , v in config .completion_params ["model" ].items ():
65- if v ["model" ] and v ["model" ].startswith ("anthropic:" ):
62+ kwargs : dict [ str , Model ] = {}
63+ for k , v in config .completion_params ["model" ].items (): # pyright: ignore[reportUnknownVariableType]
64+ if v ["model" ] and v ["model" ].startswith ("anthropic:" ): # pyright: ignore[reportUnknownMemberType]
6665 kwargs [k ] = AnthropicModel (
67- v ["model" ].removeprefix ("anthropic:" ),
66+ v ["model" ].removeprefix ("anthropic:" ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
6867 )
69- elif v ["model" ] and v ["model" ].startswith ("google:" ):
68+ elif v ["model" ] and v ["model" ].startswith ("google:" ): # pyright: ignore[reportUnknownMemberType]
7069 kwargs [k ] = GoogleModel (
71- v ["model" ].removeprefix ("google:" ),
70+ v ["model" ].removeprefix ("google:" ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
7271 )
7372 else :
7473 kwargs [k ] = OpenAIModel (
75- v ["model" ],
76- provider = v ["provider" ],
74+ v ["model" ], # pyright: ignore[reportUnknownArgumentType]
75+ provider = v ["provider" ], # pyright: ignore[reportUnknownArgumentType]
7776 )
78- agent_instance : Agent = setup_agent (** kwargs )
77+ agent_instance : Agent = setup_agent (** kwargs ) # pyright: ignore[reportAny]
7978 model = None
8079 else :
81- agent_instance = config .kwargs ["agent" ]
80+ agent_instance = config .kwargs ["agent" ] # pyright: ignore[reportAssignmentType]
8281 model = OpenAIModel (
83- config .completion_params ["model" ],
84- provider = config .completion_params ["provider" ],
82+ config .completion_params ["model" ], # pyright: ignore[reportAny]
83+ provider = config .completion_params ["provider" ], # pyright: ignore[reportAny]
8584 )
8685
8786 async def process_row (row : EvaluationRow ) -> EvaluationRow :
@@ -104,7 +103,7 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
104103
async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
    """Convert pydantic-ai messages into eval-protocol ``Message`` objects.

    The pydantic-ai messages are first mapped to OpenAI chat-completion
    message params by the dummy ``OpenAIModel`` stored on ``self.util``;
    each resulting mapping is then unpacked into a ``Message``.

    Args:
        messages: The pydantic-ai ``ModelMessage`` items to convert.

    Returns:
        One ``Message`` per mapped chat-completion message, in the order
        produced by ``self.util._map_messages``.
    """
    oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
    # Every mapped message already carries its own "role" entry, so it is
    # unpacked as-is. Supplying role=m["role"] in addition to **m repeats the
    # keyword, and Python rejects the call with
    # "TypeError: ... got multiple values for keyword argument 'role'".
    return [Message(**m) for m in oai_messages]  # pyright: ignore[reportArgumentType]
108107
109108 def convert_ep_message_to_pyd_message (self , message : Message , row : EvaluationRow ) -> ModelMessage :
110109 if message .role == "assistant" :
0 commit comments