Skip to content

Commit 4d63471

Browse files
author
Dylan Huang
committed
fix types
1 parent 2e14852 commit 4d63471

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
# pyright: reportPrivateUsage=false
2+
13
import asyncio
24
import logging
35
import types
4-
from typing import List
5-
6-
from attr import dataclass
7-
from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
8-
6+
from pydantic_ai.models import Model
7+
from typing_extensions import override
98
from eval_protocol.models import EvaluationRow, Message
109
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1110
from eval_protocol.pytest.types import RolloutProcessorConfig
@@ -25,7 +24,6 @@
2524
UserPromptPart,
2625
)
2726
from pydantic_ai.providers.openai import OpenAIProvider
28-
from typing_extensions import TypedDict
2927

3028
logger = logging.getLogger(__name__)
3129

@@ -38,7 +36,8 @@ def __init__(self):
3836
# dummy model used for its helper functions for processing messages
3937
self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
4038

41-
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
39+
@override
40+
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
4241
"""Create agent rollout tasks and return them for external handling."""
4342

4443
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
@@ -60,28 +59,28 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
6059
raise ValueError(
6160
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
6261
)
63-
kwargs: dict = {}
64-
for k, v in config.completion_params["model"].items():
65-
if v["model"] and v["model"].startswith("anthropic:"):
62+
kwargs: dict[str, Model] = {}
63+
for k, v in config.completion_params["model"].items(): # pyright: ignore[reportUnknownVariableType]
64+
if v["model"] and v["model"].startswith("anthropic:"): # pyright: ignore[reportUnknownMemberType]
6665
kwargs[k] = AnthropicModel(
67-
v["model"].removeprefix("anthropic:"),
66+
v["model"].removeprefix("anthropic:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
6867
)
69-
elif v["model"] and v["model"].startswith("google:"):
68+
elif v["model"] and v["model"].startswith("google:"): # pyright: ignore[reportUnknownMemberType]
7069
kwargs[k] = GoogleModel(
71-
v["model"].removeprefix("google:"),
70+
v["model"].removeprefix("google:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
7271
)
7372
else:
7473
kwargs[k] = OpenAIModel(
75-
v["model"],
76-
provider=v["provider"],
74+
v["model"], # pyright: ignore[reportUnknownArgumentType]
75+
provider=v["provider"], # pyright: ignore[reportUnknownArgumentType]
7776
)
78-
agent_instance: Agent = setup_agent(**kwargs)
77+
agent_instance: Agent = setup_agent(**kwargs) # pyright: ignore[reportAny]
7978
model = None
8079
else:
81-
agent_instance = config.kwargs["agent"]
80+
agent_instance = config.kwargs["agent"] # pyright: ignore[reportAssignmentType]
8281
model = OpenAIModel(
83-
config.completion_params["model"],
84-
provider=config.completion_params["provider"],
82+
config.completion_params["model"], # pyright: ignore[reportAny]
83+
provider=config.completion_params["provider"], # pyright: ignore[reportAny]
8584
)
8685

8786
async def process_row(row: EvaluationRow) -> EvaluationRow:
@@ -104,7 +103,7 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
104103

105104
async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
106105
oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
107-
return [Message(role=m["role"], **m) for m in oai_messages]
106+
return [Message(role=m["role"], **m) for m in oai_messages] # pyright: ignore[reportArgumentType]
108107

109108
def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
110109
if message.role == "assistant":

0 commit comments

Comments
 (0)