Skip to content

Commit a53f31c

Browse files
author
Dylan Huang
authored
Super simple hello world test for Pydantic AI (#119)
* Super simple hello world test for Pydantic AI * update uv.lock * support FIREWORKS_API_KEY in env * fix issue
1 parent d951083 commit a53f31c

File tree

6 files changed

+491
-14
lines changed

6 files changed

+491
-14
lines changed

eval_protocol/models.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,11 @@ def __iter__(self):
397397

398398
CompletionParams = Dict[str, Any]
399399
"""
400-
Common set of completion parameters that most model providers support in their
401-
API. Set total=False to allow extra fields since LiteLLM + providers have their
402-
own set of parameters. The following parameters are common fields that are
403-
populated.
404-
405-
model: str
406-
temperature: Optional[float]
407-
max_tokens: Optional[int]
408-
top_p: Optional[float]
400+
The completion parameters for the respective LLM SDK or agent framework.
401+
Depending on the rollout processor, this might be the parameters passed to
402+
LiteLLM completion call or parameters for the "run" method of the "Agent" class
403+
in Pydantic AI. You can also customize this dictionary to whatever you need if
404+
you implement your own custom rollout processor.
409405
"""
410406

411407

eval_protocol/pytest/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
from .rollout_processor import RolloutProcessor
99
from .types import RolloutProcessorConfig
1010

11+
# pydantic-ai is an optional dependency: expose the processor only when the
# module that needs it can actually be imported, and record the outcome.
try:
    from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
except ImportError:
    # Keep the name bound so `from eval_protocol.pytest import ...` still works.
    PydanticAgentRolloutProcessor = None
    PYDANTIC_AI_AVAILABLE = False
else:
    PYDANTIC_AI_AVAILABLE = True
19+
1120
__all__ = [
1221
"AgentRolloutProcessor",
1322
"MCPGymRolloutProcessor",
@@ -21,3 +30,7 @@
2130
"BackoffConfig",
2231
"get_default_exception_handler_config",
2332
]
33+
34+
# Advertise the processor as public API only when its import succeeded.
if PYDANTIC_AI_AVAILABLE:
    __all__ += ["PydanticAgentRolloutProcessor"]
eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import os
2+
import asyncio
3+
import logging
4+
from typing import List
5+
6+
from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
7+
8+
from eval_protocol.models import EvaluationRow, Message
9+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
10+
from eval_protocol.pytest.types import RolloutProcessorConfig
11+
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
12+
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
13+
14+
from pydantic_ai.models.openai import OpenAIModel
15+
from pydantic import TypeAdapter
16+
from pydantic_ai.messages import ModelMessage
17+
from pydantic_ai._utils import generate_tool_call_id
18+
from pydantic_ai import Agent
19+
from pydantic_ai.messages import (
20+
ModelRequest,
21+
SystemPromptPart,
22+
ToolReturnPart,
23+
UserPromptPart,
24+
)
25+
from pydantic_ai.providers.openai import OpenAIProvider
26+
from pydantic_ai.providers.fireworks import FireworksProvider
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class PydanticAgentRolloutProcessor(RolloutProcessor):
    """Rollout processor for Pydantic AI agents.

    For every row, ``EvaluationRow.messages`` is converted to Pydantic AI
    ``ModelMessage`` objects, the agent is run with that history, and the
    agent's complete message list is converted back onto the row.
    """

    def __init__(self):
        # Dummy model: never used for inference, only for the message-mapping
        # helpers (``_map_messages`` / ``_process_response``) it exposes.
        self.util = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Create one agent rollout task per row and return the tasks.

        Args:
            rows: Rows whose ``messages`` seed the agent's message history.
            config: Must carry ``kwargs["agent"]`` (a Pydantic AI ``Agent``)
                and ``completion_params`` entries ``"model"`` and
                ``"provider"``, which are forwarded to ``OpenAIModel``.

        Returns:
            One ``asyncio`` task per row, in input order, left for the caller
            to await. Tasks are created with ``asyncio.create_task``, so an
            event loop must already be running.

        Raises:
            ValueError: If ``config.kwargs`` has no ``"agent"`` entry, or the
                entry is not an ``Agent`` instance.
            KeyError: If ``"model"`` or ``"provider"`` is missing from
                ``config.completion_params`` (assuming it is a plain dict).
        """
        # The agent travels in config.kwargs (the rollout processor kwargs),
        # not in completion_params; validate it before scheduling anything.
        if "agent" not in config.kwargs:
            raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
        if not isinstance(config.kwargs["agent"], Agent):
            raise ValueError("kwargs['agent'] must be a valid Pydantic AI Agent instance")

        agent: Agent = config.kwargs["agent"]

        # Cap the number of rollouts in flight; a missing or falsy setting falls back to 8.
        max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
        semaphore = asyncio.Semaphore(max_concurrent)

        model = OpenAIModel(
            config.completion_params["model"],
            provider=config.completion_params["provider"],
        )

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Run the agent on one row and write the resulting messages back."""
            model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
            response = await agent.run(message_history=model_messages, model=model)
            row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
            return row

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            """Run ``process_row`` while holding a concurrency slot."""
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling.
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
        """Convert Pydantic AI messages to ``Message`` objects.

        The messages are first mapped to OpenAI chat-completion message params
        by the dummy model's ``_map_messages`` helper; each resulting mapping
        is then expanded into a ``Message``.
        """
        oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
        return [Message(**m) for m in oai_messages]

    @staticmethod
    def _content_as_text(message: Message) -> str:
        """Return the textual content of a user/system message.

        String content is returned as-is. List content is the concatenation of
        every part's ``text`` (previously only the first part was used, which
        dropped later parts and raised ``IndexError`` on an empty list).

        Raises:
            ValueError: If the content is neither a ``str`` nor a ``list``
                (for example ``None``); such messages used to be converted to
                ``None`` and silently placed in the message history.
        """
        content = message.content
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Assumes every part exposes ``.text``, exactly as the previous
            # first-part access did.
            return "".join(part.text for part in content)
        raise ValueError(f"Unsupported content type for {message.role} message: {type(content).__name__}")

    def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
        """Convert one ``Message`` into the equivalent Pydantic AI message.

        Args:
            message: The message to convert; dispatch is on ``message.role``.
            row: The owning row; its ``created_at`` supplies the timestamp of
                the synthetic ``ChatCompletion`` built for assistant messages.

        Returns:
            The result of the dummy model's ``_process_response`` for
            assistant messages, otherwise a ``ModelRequest`` holding a single
            user-prompt, system-prompt or tool-return part.

        Raises:
            ValueError: If the role is unknown, or a user/system message has
                content that is neither a ``str`` nor a ``list``.
        """
        role = message.role

        if role == "assistant":
            type_adapter = TypeAdapter(ChatCompletionAssistantMessageParam)
            oai_message = type_adapter.validate_python(message)
            # ChatCompletion requires an integer ``created``: accept either a
            # datetime-like value (has ``timestamp()``) or a plain number.
            if hasattr(row.created_at, "timestamp"):
                created = int(row.created_at.timestamp())
            else:
                created = int(row.created_at)
            # ``finish_reason`` and ``index`` are required by the Choice model.
            return self.util._process_response(
                ChatCompletion(
                    choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
                    object="chat.completion",
                    model="",
                    id="",
                    created=created,
                )
            )

        if role == "user":
            return ModelRequest(parts=[UserPromptPart(content=self._content_as_text(message))])

        if role == "system":
            return ModelRequest(parts=[SystemPromptPart(content=self._content_as_text(message))])

        if role == "tool":
            return ModelRequest(
                parts=[
                    ToolReturnPart(
                        content=message.content,
                        # The tool name is not taken from the message; it is left blank.
                        tool_name="",
                        tool_call_id=message.tool_call_id or generate_tool_call_id(),
                    )
                ]
            )

        raise ValueError(f"Unknown role: {message.role}")

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ bigquery = [
122122
svgbench = [
123123
"selenium>=4.0.0",
124124
]
125+
pydantic = [
126+
"pydantic-ai",
127+
]
125128

126129
[tool.pytest.ini_options]
127130
addopts = "-q"
@@ -170,7 +173,6 @@ dev = [
170173
"haikus==0.3.8",
171174
"pytest>=8.4.1",
172175
]
173-
174176
[tool.ruff]
175177
line-length = 119
176178
target-version = "py310"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
import pytest
3+
4+
from eval_protocol.models import EvaluationRow, Message
5+
from eval_protocol.pytest import evaluation_test
6+
from pydantic_ai import Agent
7+
8+
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
9+
10+
agent = Agent()
11+
12+
13+
@pytest.mark.asyncio
@evaluation_test(
    rollout_processor=PydanticAgentRolloutProcessor(),
    rollout_processor_kwargs={"agent": agent},
    completion_params=[
        {
            "model": "accounts/fireworks/models/gpt-oss-120b",
            "provider": "fireworks",
        },
    ],
    input_messages=[Message(role="user", content="Hello, how are you?")],
    mode="pointwise",
)
async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:
    """Hello-world smoke test for the Pydantic AI rollout processor.

    A single user greeting is rolled out through the module-level ``agent``;
    the resulting row is returned unchanged, with no extra scoring applied.
    """
    return row

0 commit comments

Comments
 (0)