11# pyright: reportPrivateUsage=false
22
33import asyncio
4+ from collections .abc import Callable
45import logging
56import time
6- import types
7- from pydantic_ai .models import Model
7+ from pydantic_ai .usage import UsageLimits
88from typing_extensions import override
99from eval_protocol .models import EvaluationRow , Message
10- from openai .types import CompletionUsage
1110from eval_protocol .pytest .rollout_processor import RolloutProcessor
1211from eval_protocol .pytest .types import RolloutProcessorConfig
1312from openai .types .chat import ChatCompletion , ChatCompletionMessage , ChatCompletionMessageParam
1413from openai .types .chat .chat_completion import Choice as ChatCompletionChoice
15- from pydantic_ai .models .anthropic import AnthropicModel
16- from pydantic_ai .models .openai import OpenAIModel
17- from pydantic_ai .models .google import GoogleModel
1814from pydantic import TypeAdapter
19- from pydantic_ai .messages import ModelMessage
20- from pydantic_ai ._utils import generate_tool_call_id
2115from pydantic_ai import Agent
16+ from pydantic_ai ._utils import generate_tool_call_id
17+ from pydantic_ai .messages import ModelMessage
2218from pydantic_ai .messages import (
2319 ModelRequest ,
2420 SystemPromptPart ,
2521 ToolReturnPart ,
2622 UserPromptPart ,
2723)
24+ from pydantic_ai .models .openai import OpenAIModel
2825from pydantic_ai .providers .openai import OpenAIProvider
2926
3027logger = logging .getLogger (__name__ )
@@ -34,64 +31,29 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3431 """Rollout processor for Pydantic AI agents. Mainly converts
3532 EvaluationRow.messages to and from Pydantic AI ModelMessage format."""
3633
37- def __init__ (self ):
34+ def __init__ (
35+ self ,
36+ agent_factory : Callable [[RolloutProcessorConfig ], Agent ],
37+ usage_limits : UsageLimits | None = None ,
38+ ):
3839 # dummy model used for its helper functions for processing messages
39- self .util : OpenAIModel = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
40+ self ._util : OpenAIModel = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
41+ self ._setup_agent = agent_factory
4042
4143 @override
4244 def __call__ (self , rows : list [EvaluationRow ], config : RolloutProcessorConfig ) -> list [asyncio .Task [EvaluationRow ]]:
4345 """Create agent rollout tasks and return them for external handling."""
4446
4547 semaphore = config .semaphore
4648
47- # validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
48- if "agent" not in config .kwargs :
49- raise ValueError ("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance" )
50- if not isinstance (config .kwargs ["agent" ], Agent ) and not isinstance (
51- config .kwargs ["agent" ], types .FunctionType
52- ):
53- raise ValueError (
54- "kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
55- )
56-
57- if isinstance (config .kwargs ["agent" ], types .FunctionType ):
58- setup_agent = config .kwargs ["agent" ]
59- if not isinstance (config .completion_params ["model" ], dict ):
60- raise ValueError (
61- "completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
62- )
63- kwargs : dict [str , Model ] = {}
64- for k , v in config .completion_params ["model" ].items (): # pyright: ignore[reportUnknownVariableType]
65- if v ["model" ] and v ["model" ].startswith ("anthropic:" ): # pyright: ignore[reportUnknownMemberType]
66- kwargs [k ] = AnthropicModel (
67- v ["model" ].removeprefix ("anthropic:" ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
68- )
69- elif v ["model" ] and v ["model" ].startswith ("google:" ): # pyright: ignore[reportUnknownMemberType]
70- kwargs [k ] = GoogleModel (
71- v ["model" ].removeprefix ("google:" ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
72- )
73- else :
74- kwargs [k ] = OpenAIModel (
75- v ["model" ], # pyright: ignore[reportUnknownArgumentType]
76- provider = v ["provider" ], # pyright: ignore[reportUnknownArgumentType]
77- )
78- agent_instance : Agent = setup_agent (** kwargs ) # pyright: ignore[reportAny]
79- model = None
80- else :
81- agent_instance = config .kwargs ["agent" ] # pyright: ignore[reportAssignmentType]
82- model = OpenAIModel (
83- config .completion_params ["model" ], # pyright: ignore[reportAny]
84- provider = config .completion_params ["provider" ], # pyright: ignore[reportAny]
85- )
49+ agent = self ._setup_agent (config )
8650
8751 async def process_row (row : EvaluationRow ) -> EvaluationRow :
8852 """Process a single row with agent rollout."""
8953 start_time = time .perf_counter ()
9054
9155 model_messages = [self .convert_ep_message_to_pyd_message (m , row ) for m in row .messages ]
92- response = await agent_instance .run (
93- message_history = model_messages , model = model , usage_limits = config .kwargs .get ("usage_limits" )
94- )
56+ response = await agent .run (message_history = model_messages , usage_limits = config .kwargs .get ("usage_limits" ))
9557 row .messages = await self .convert_pyd_message_to_ep_message (response .all_messages ())
9658
9759 # TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
@@ -116,15 +78,15 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
11678 return tasks
11779
11880 async def convert_pyd_message_to_ep_message (self , messages : list [ModelMessage ]) -> list [Message ]:
119- oai_messages : list [ChatCompletionMessageParam ] = await self .util ._map_messages (messages )
81+ oai_messages : list [ChatCompletionMessageParam ] = await self ._util ._map_messages (messages )
12082 return [Message (** m ) for m in oai_messages ] # pyright: ignore[reportArgumentType]
12183
12284 def convert_ep_message_to_pyd_message (self , message : Message , row : EvaluationRow ) -> ModelMessage :
12385 if message .role == "assistant" :
12486 type_adapter = TypeAdapter (ChatCompletionMessage )
12587 oai_message = type_adapter .validate_python (message )
12688 # Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
127- return self .util ._process_response (
89+ return self ._util ._process_response (
12890 ChatCompletion (
12991 choices = [ChatCompletionChoice (message = oai_message , finish_reason = "stop" , index = 0 )],
13092 object = "chat.completion" ,
@@ -157,5 +119,4 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
157119 )
158120 ]
159121 )
160- else :
161- raise ValueError (f"Unknown role: { message .role } " )
122+ raise ValueError (f"Unknown role: { message .role } " )
0 commit comments