Skip to content

Commit c6e0a5a

Browse files
author
Dylan Huang
authored
Refactor Pydantic AI Rollout processor to use factory pattern (#164)
* save * TODO: refactor rolloutprocessor to not use __call__ * save * save * factory pattern works * refactor test_pydantic_multi_agent to work with factory setup * fix test_pydantic_agent.py
1 parent d563336 commit c6e0a5a

File tree

5 files changed

+64
-97
lines changed

5 files changed

+64
-97
lines changed
Lines changed: 17 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
# pyright: reportPrivateUsage=false
22

33
import asyncio
4+
from collections.abc import Callable
45
import logging
56
import time
6-
import types
7-
from pydantic_ai.models import Model
7+
from pydantic_ai.usage import UsageLimits
88
from typing_extensions import override
99
from eval_protocol.models import EvaluationRow, Message
10-
from openai.types import CompletionUsage
1110
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1211
from eval_protocol.pytest.types import RolloutProcessorConfig
1312
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
1413
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
15-
from pydantic_ai.models.anthropic import AnthropicModel
16-
from pydantic_ai.models.openai import OpenAIModel
17-
from pydantic_ai.models.google import GoogleModel
1814
from pydantic import TypeAdapter
19-
from pydantic_ai.messages import ModelMessage
20-
from pydantic_ai._utils import generate_tool_call_id
2115
from pydantic_ai import Agent
16+
from pydantic_ai._utils import generate_tool_call_id
17+
from pydantic_ai.messages import ModelMessage
2218
from pydantic_ai.messages import (
2319
ModelRequest,
2420
SystemPromptPart,
2521
ToolReturnPart,
2622
UserPromptPart,
2723
)
24+
from pydantic_ai.models.openai import OpenAIModel
2825
from pydantic_ai.providers.openai import OpenAIProvider
2926

3027
logger = logging.getLogger(__name__)
@@ -34,64 +31,29 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3431
"""Rollout processor for Pydantic AI agents. Mainly converts
3532
EvaluationRow.messages to and from Pydantic AI ModelMessage format."""
3633

37-
def __init__(self):
34+
def __init__(
35+
self,
36+
agent_factory: Callable[[RolloutProcessorConfig], Agent],
37+
usage_limits: UsageLimits | None = None,
38+
):
3839
# dummy model used for its helper functions for processing messages
39-
self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
40+
self._util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
41+
self._setup_agent = agent_factory
4042

4143
@override
4244
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
4345
"""Create agent rollout tasks and return them for external handling."""
4446

4547
semaphore = config.semaphore
4648

47-
# validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
48-
if "agent" not in config.kwargs:
49-
raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
50-
if not isinstance(config.kwargs["agent"], Agent) and not isinstance(
51-
config.kwargs["agent"], types.FunctionType
52-
):
53-
raise ValueError(
54-
"kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
55-
)
56-
57-
if isinstance(config.kwargs["agent"], types.FunctionType):
58-
setup_agent = config.kwargs["agent"]
59-
if not isinstance(config.completion_params["model"], dict):
60-
raise ValueError(
61-
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
62-
)
63-
kwargs: dict[str, Model] = {}
64-
for k, v in config.completion_params["model"].items(): # pyright: ignore[reportUnknownVariableType]
65-
if v["model"] and v["model"].startswith("anthropic:"): # pyright: ignore[reportUnknownMemberType]
66-
kwargs[k] = AnthropicModel(
67-
v["model"].removeprefix("anthropic:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
68-
)
69-
elif v["model"] and v["model"].startswith("google:"): # pyright: ignore[reportUnknownMemberType]
70-
kwargs[k] = GoogleModel(
71-
v["model"].removeprefix("google:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
72-
)
73-
else:
74-
kwargs[k] = OpenAIModel(
75-
v["model"], # pyright: ignore[reportUnknownArgumentType]
76-
provider=v["provider"], # pyright: ignore[reportUnknownArgumentType]
77-
)
78-
agent_instance: Agent = setup_agent(**kwargs) # pyright: ignore[reportAny]
79-
model = None
80-
else:
81-
agent_instance = config.kwargs["agent"] # pyright: ignore[reportAssignmentType]
82-
model = OpenAIModel(
83-
config.completion_params["model"], # pyright: ignore[reportAny]
84-
provider=config.completion_params["provider"], # pyright: ignore[reportAny]
85-
)
49+
agent = self._setup_agent(config)
8650

8751
async def process_row(row: EvaluationRow) -> EvaluationRow:
8852
"""Process a single row with agent rollout."""
8953
start_time = time.perf_counter()
9054

9155
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
92-
response = await agent_instance.run(
93-
message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
94-
)
56+
response = await agent.run(message_history=model_messages, usage_limits=config.kwargs.get("usage_limits"))
9557
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
9658

9759
# TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
@@ -116,15 +78,15 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
11678
return tasks
11779

11880
async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
119-
oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
81+
oai_messages: list[ChatCompletionMessageParam] = await self._util._map_messages(messages)
12082
return [Message(**m) for m in oai_messages] # pyright: ignore[reportArgumentType]
12183

12284
def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
12385
if message.role == "assistant":
12486
type_adapter = TypeAdapter(ChatCompletionMessage)
12587
oai_message = type_adapter.validate_python(message)
12688
# Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
127-
return self.util._process_response(
89+
return self._util._process_response(
12890
ChatCompletion(
12991
choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
13092
object="chat.completion",
@@ -157,5 +119,4 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
157119
)
158120
]
159121
)
160-
else:
161-
raise ValueError(f"Unknown role: {message.role}")
122+
raise ValueError(f"Unknown role: {message.role}")

eval_protocol/pytest/evaluation_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,6 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
255255
row.input_metadata.row_id = generate_id(seed=0, index=index)
256256

257257
completion_params = kwargs["completion_params"]
258-
if completion_params and ("model" not in completion_params or not completion_params["model"]):
259-
raise ValueError(
260-
"No model provided. Please provide a model in the completion parameters object."
261-
)
262-
263258
# Create eval metadata with test function info and current commit hash
264259
eval_metadata = EvalMetadata(
265260
name=test_func.__name__,

tests/chinook/pydantic/test_pydantic_chinook.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from eval_protocol.pytest import evaluation_test
77

88
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
9+
from eval_protocol.pytest.types import RolloutProcessorConfig
910
from tests.chinook.pydantic.agent import setup_agent
1011
import os
1112
from pydantic_ai.models.openai import OpenAIModel
@@ -20,21 +21,23 @@
2021
)
2122

2223

24+
def agent_factory(config: RolloutProcessorConfig) -> Agent:
25+
model_name = config.completion_params["model"]
26+
provider = config.completion_params["provider"]
27+
model = OpenAIModel(model_name, provider=provider)
28+
return setup_agent(model)
29+
30+
2331
@pytest.mark.asyncio
2432
@evaluation_test(
2533
input_messages=[[[Message(role="user", content="What is the total number of tracks in the database?")]]],
2634
completion_params=[
2735
{
28-
"model": {
29-
"orchestrator_agent_model": {
30-
"model": "accounts/fireworks/models/kimi-k2-instruct",
31-
"provider": "fireworks",
32-
}
33-
}
36+
"model": "accounts/fireworks/models/kimi-k2-instruct",
37+
"provider": "fireworks",
3438
},
3539
],
36-
rollout_processor=PydanticAgentRolloutProcessor(),
37-
rollout_processor_kwargs={"agent": setup_agent},
40+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
3841
mode="pointwise",
3942
)
4043
async def test_simple_query(row: EvaluationRow) -> EvaluationRow:
@@ -92,16 +95,11 @@ class Response(BaseModel):
9295
input_rows=[collect_dataset()],
9396
completion_params=[
9497
{
95-
"model": {
96-
"orchestrator_agent_model": {
97-
"model": "accounts/fireworks/models/kimi-k2-instruct",
98-
"provider": "fireworks",
99-
}
100-
}
98+
"model": "accounts/fireworks/models/kimi-k2-instruct",
99+
"provider": "fireworks",
101100
},
102101
],
103-
rollout_processor=PydanticAgentRolloutProcessor(),
104-
rollout_processor_kwargs={"agent": setup_agent},
102+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
105103
mode="pointwise",
106104
)
107105
async def test_complex_queries(row: EvaluationRow) -> EvaluationRow:

tests/pytest/test_pydantic_agent.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
1+
from pydantic_ai.agent import Agent
2+
from pydantic_ai.models.openai import OpenAIModel
13
import pytest
24

35
from eval_protocol.models import EvaluationRow, Message
46
from eval_protocol.pytest import evaluation_test
5-
from pydantic_ai import Agent
67

78
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
9+
from eval_protocol.pytest.types import RolloutProcessorConfig
810

9-
agent = Agent()
11+
12+
def agent_factory(config: RolloutProcessorConfig) -> Agent:
13+
model = OpenAIModel(config.completion_params["model"], provider="fireworks")
14+
return Agent(model=model)
1015

1116

1217
@pytest.mark.asyncio
1318
@evaluation_test(
1419
input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
1520
completion_params=[
16-
{"model": "accounts/fireworks/models/gpt-oss-120b", "provider": "fireworks"},
21+
{"model": "accounts/fireworks/models/gpt-oss-120b"},
1722
],
18-
rollout_processor=PydanticAgentRolloutProcessor(),
19-
rollout_processor_kwargs={"agent": agent},
23+
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
2024
mode="pointwise",
2125
)
2226
async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:

tests/pytest/test_pydantic_multi_agent.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""
22
Copied and modified for eval-protocol from https://ai.pydantic.dev/multi-agent-applications/#agent-delegation
33
4-
To test your Pydantic AI multi-agent application, you can pass a function that
5-
sets up the agents and their tools. The function should accept parameters that
6-
map a model to each agent. In completion_params, you can provide mappings of
7-
model to agent based on key.
4+
To test your Pydantic AI multi-agent application, you can pass a factory that
5+
sets up the agent based on the completion_params. The function should accept a
6+
RolloutProcessorConfig. In completion_params, you can provide mappings of model
7+
to agent based on key.
88
"""
99

10+
from pydantic_ai.models.openai import OpenAIModel
1011
import pytest
1112

1213
from eval_protocol.models import EvaluationRow, Message
@@ -18,6 +19,8 @@
1819
from pydantic_ai.models import Model
1920
from pydantic_ai.usage import UsageLimits
2021

22+
from eval_protocol.pytest.types import RolloutProcessorConfig
23+
2124

2225
def setup_agent(joke_generation_model: Model, joke_selection_model: Model) -> Agent:
2326
"""
@@ -45,26 +48,32 @@ async def joke_factory(ctx: RunContext[None], count: int) -> list[str]: # pyrig
4548
return joke_selection_agent
4649

4750

51+
def agent_factory(config: RolloutProcessorConfig) -> Agent:
52+
joke_generation_model = OpenAIModel(
53+
config.completion_params["model"]["joke_generation_model"], provider="fireworks"
54+
)
55+
joke_selection_model = OpenAIModel(config.completion_params["model"]["joke_selection_model"], provider="fireworks")
56+
return setup_agent(
57+
joke_generation_model,
58+
joke_selection_model,
59+
)
60+
61+
4862
@pytest.mark.asyncio
4963
@evaluation_test(
5064
input_messages=[[[Message(role="user", content="Tell me a joke.")]]],
5165
completion_params=[
66+
# multi-agent
5267
{
5368
"model": {
54-
"joke_generation_model": {
55-
"model": "accounts/fireworks/models/kimi-k2-instruct",
56-
"provider": "fireworks",
57-
},
58-
"joke_selection_model": {"model": "accounts/fireworks/models/deepseek-v3p1", "provider": "fireworks"},
69+
"joke_generation_model": "accounts/fireworks/models/kimi-k2-instruct",
70+
"joke_selection_model": "accounts/fireworks/models/deepseek-v3p1",
5971
}
6072
},
6173
],
62-
rollout_processor=PydanticAgentRolloutProcessor(),
63-
rollout_processor_kwargs={
64-
"agent": setup_agent,
65-
# PydanticAgentRolloutProcessor will pass usage_limits into the "run" call
66-
"usage_limits": UsageLimits(request_limit=5, total_tokens_limit=1000),
67-
},
74+
rollout_processor=PydanticAgentRolloutProcessor(
75+
agent_factory, UsageLimits(request_limit=5, total_tokens_limit=1000)
76+
),
6877
mode="pointwise",
6978
)
7079
async def test_pydantic_multi_agent(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)