Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,15 +397,11 @@ def __iter__(self):

CompletionParams = Dict[str, Any]
"""
Common set of completion parameters that most model providers support in their
API. Set total=False to allow extra fields since LiteLLM + providers have their
own set of parameters. The following parameters are common fields that are
populated.

model: str
temperature: Optional[float]
max_tokens: Optional[int]
top_p: Optional[float]
The completion parameters for the respective LLM SDK or agent framework.
Depending on the rollout processor, these might be the parameters passed to
the LiteLLM completion call, or the parameters for the "run" method of the
"Agent" class in Pydantic AI. You can also customize this dictionary to
whatever you need if you implement your own custom rollout processor.
"""


Expand Down
13 changes: 13 additions & 0 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig

# Conditional import for optional dependency
try:
from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor

PYDANTIC_AI_AVAILABLE = True
except ImportError:
PYDANTIC_AI_AVAILABLE = False
PydanticAgentRolloutProcessor = None

__all__ = [
"AgentRolloutProcessor",
"MCPGymRolloutProcessor",
Expand All @@ -21,3 +30,7 @@
"BackoffConfig",
"get_default_exception_handler_config",
]

# Only add to __all__ if available
if PYDANTIC_AI_AVAILABLE:
__all__.append("PydanticAgentRolloutProcessor")
117 changes: 117 additions & 0 deletions eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os
import asyncio
import logging
from typing import List

from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice

from pydantic_ai.models.openai import OpenAIModel
from pydantic import TypeAdapter
from pydantic_ai.messages import ModelMessage
from pydantic_ai._utils import generate_tool_call_id
from pydantic_ai import Agent
from pydantic_ai.messages import (
ModelRequest,
SystemPromptPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.providers.fireworks import FireworksProvider

logger = logging.getLogger(__name__)


class PydanticAgentRolloutProcessor(RolloutProcessor):
    """Rollout processor for Pydantic AI agents.

    Bridges eval-protocol's ``EvaluationRow.messages`` (OpenAI-style chat
    messages) and Pydantic AI's ``ModelMessage`` format: each row's history is
    converted to ``ModelMessage`` objects, run through the configured
    ``Agent``, and the resulting transcript is written back onto the row.
    """

    def __init__(self):
        # Dummy model used only for its message-conversion helper functions
        # (_map_messages / _process_response); it never issues API requests.
        # NOTE(review): these are private pydantic-ai helpers and may change
        # between releases — confirm on upgrade.
        self.util = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Create agent rollout tasks and return them for external handling.

        Args:
            rows: Evaluation rows whose ``messages`` form the chat history.
            config: Rollout configuration. ``config.kwargs`` must contain an
                ``"agent"`` entry holding a Pydantic AI ``Agent``;
                ``config.completion_params`` must provide ``"model"`` and
                ``"provider"``.

        Returns:
            One ``asyncio.Task`` per row; concurrency is bounded by a shared
            semaphore (``max_concurrent_rollouts``, defaulting to 8).

        Raises:
            ValueError: If ``config.kwargs`` does not contain a valid
                Pydantic AI ``Agent`` instance under ``"agent"``.
        """
        max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
        semaphore = asyncio.Semaphore(max_concurrent)

        # Validate that the "agent" field is present in config.kwargs with a
        # valid Pydantic AI Agent instance.
        if "agent" not in config.kwargs:
            raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
        if not isinstance(config.kwargs["agent"], Agent):
            raise ValueError("kwargs['agent'] must be a valid Pydantic AI Agent instance")

        agent: Agent = config.kwargs["agent"]

        model = OpenAIModel(
            config.completion_params["model"],
            provider=config.completion_params["provider"],
        )

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Run the agent on one row and write the transcript back onto it."""
            model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
            response = await agent.run(message_history=model_messages, model=model)
            row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
            return row

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling.
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
        """Convert Pydantic AI messages back to eval-protocol ``Message``s.

        Uses the dummy model's private ``_map_messages`` helper to obtain
        OpenAI-style dicts, then rehydrates them as ``Message`` objects.
        """
        oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
        return [Message(**m) for m in oai_messages]

    def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
        """Convert one eval-protocol ``Message`` into a ``ModelMessage``.

        Args:
            message: The message to convert; ``role`` selects the mapping.
            row: The owning row; ``row.created_at`` supplies the timestamp
                for assistant messages.

        Raises:
            ValueError: For an unknown role, or for user/system content of an
                unsupported type (previously these fell through and returned
                ``None``, which would fail later inside pydantic-ai).
        """
        if message.role == "assistant":
            type_adapter = TypeAdapter(ChatCompletionAssistantMessageParam)
            oai_message = type_adapter.validate_python(message)
            # _process_response requires finish_reason and index on the
            # choice, and created must be an int unix timestamp.
            return self.util._process_response(
                ChatCompletion(
                    choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
                    object="chat.completion",
                    model="",
                    id="",
                    created=(
                        int(row.created_at.timestamp())
                        if hasattr(row.created_at, "timestamp")
                        else int(row.created_at)
                    ),
                )
            )
        elif message.role == "user":
            if isinstance(message.content, str):
                return ModelRequest(parts=[UserPromptPart(content=message.content)])
            elif isinstance(message.content, list):
                # Assumes a single text part — TODO confirm multi-part
                # content is not expected here.
                return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
            raise ValueError(f"Unsupported user message content type: {type(message.content).__name__}")
        elif message.role == "system":
            if isinstance(message.content, str):
                return ModelRequest(parts=[SystemPromptPart(content=message.content)])
            elif isinstance(message.content, list):
                return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
            raise ValueError(f"Unsupported system message content type: {type(message.content).__name__}")
        elif message.role == "tool":
            return ModelRequest(
                parts=[
                    ToolReturnPart(
                        content=message.content,
                        # tool_name is not tracked on eval-protocol tool
                        # messages, so it is left empty.
                        tool_name="",
                        tool_call_id=message.tool_call_id or generate_tool_call_id(),
                    )
                ]
            )
        else:
            raise ValueError(f"Unknown role: {message.role}")
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ bigquery = [
svgbench = [
"selenium>=4.0.0",
]
pydantic = [
"pydantic-ai",
]

[tool.pytest.ini_options]
addopts = "-q"
Expand Down Expand Up @@ -170,7 +173,6 @@ dev = [
"haikus==0.3.8",
"pytest>=8.4.1",
]

[tool.ruff]
line-length = 119
target-version = "py310"
Expand Down
27 changes: 27 additions & 0 deletions tests/pytest/test_pydantic_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import pytest

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from pydantic_ai import Agent

from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor

agent = Agent()


@pytest.mark.asyncio
@evaluation_test(
    input_messages=[Message(role="user", content="Hello, how are you?")],
    completion_params=[
        {"model": "accounts/fireworks/models/gpt-oss-120b", "provider": "fireworks"},
    ],
    rollout_processor=PydanticAgentRolloutProcessor(),
    rollout_processor_kwargs={"agent": agent},
    mode="pointwise",
)
async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:
    """
    Super simple hello world test for Pydantic AI.

    The rollout processor runs the module-level Agent (passed via
    rollout_processor_kwargs) against the single user message; this
    evaluator returns the row unchanged — no scoring — so the test only
    verifies that the Pydantic AI rollout completes end to end.
    """
    return row
Loading
Loading