73 changes: 17 additions & 56 deletions eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
@@ -1,30 +1,27 @@
# pyright: reportPrivateUsage=false

import asyncio
from collections.abc import Callable
import logging
import time
import types
from pydantic_ai.models import Model
from pydantic_ai.usage import UsageLimits
from typing_extensions import override
from eval_protocol.models import EvaluationRow, Message
from openai.types import CompletionUsage
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.google import GoogleModel
from pydantic import TypeAdapter
from pydantic_ai.messages import ModelMessage
from pydantic_ai._utils import generate_tool_call_id
from pydantic_ai import Agent
from pydantic_ai._utils import generate_tool_call_id
from pydantic_ai.messages import ModelMessage
from pydantic_ai.messages import (
ModelRequest,
SystemPromptPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

logger = logging.getLogger(__name__)
@@ -34,64 +31,29 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
"""Rollout processor for Pydantic AI agents. Mainly converts
EvaluationRow.messages to and from Pydantic AI ModelMessage format."""

def __init__(self):
def __init__(
self,
agent_factory: Callable[[RolloutProcessorConfig], Agent],
usage_limits: UsageLimits | None = None,
):
# dummy model used for its helper functions for processing messages
self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
self._util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
self._setup_agent = agent_factory
self._usage_limits = usage_limits

@override
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
"""Create agent rollout tasks and return them for external handling."""

semaphore = config.semaphore

# validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
if "agent" not in config.kwargs:
raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
if not isinstance(config.kwargs["agent"], Agent) and not isinstance(
config.kwargs["agent"], types.FunctionType
):
raise ValueError(
"kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
)

if isinstance(config.kwargs["agent"], types.FunctionType):
setup_agent = config.kwargs["agent"]
if not isinstance(config.completion_params["model"], dict):
raise ValueError(
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
)
kwargs: dict[str, Model] = {}
for k, v in config.completion_params["model"].items(): # pyright: ignore[reportUnknownVariableType]
if v["model"] and v["model"].startswith("anthropic:"): # pyright: ignore[reportUnknownMemberType]
kwargs[k] = AnthropicModel(
v["model"].removeprefix("anthropic:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
)
elif v["model"] and v["model"].startswith("google:"): # pyright: ignore[reportUnknownMemberType]
kwargs[k] = GoogleModel(
v["model"].removeprefix("google:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
)
else:
kwargs[k] = OpenAIModel(
v["model"], # pyright: ignore[reportUnknownArgumentType]
provider=v["provider"], # pyright: ignore[reportUnknownArgumentType]
)
agent_instance: Agent = setup_agent(**kwargs) # pyright: ignore[reportAny]
model = None
else:
agent_instance = config.kwargs["agent"] # pyright: ignore[reportAssignmentType]
model = OpenAIModel(
config.completion_params["model"], # pyright: ignore[reportAny]
provider=config.completion_params["provider"], # pyright: ignore[reportAny]
)
agent = self._setup_agent(config)

async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row with agent rollout."""
start_time = time.perf_counter()

model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
response = await agent_instance.run(
message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
)
response = await agent.run(message_history=model_messages, usage_limits=self._usage_limits)
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())

# TODO: Pydantic AI accumulates usage info across all models in a multi-agent setup, so this simple tracking doesn't work for cost. To discuss with @dphuang2 when he's back.
@@ -116,15 +78,15 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
return tasks

async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
oai_messages: list[ChatCompletionMessageParam] = await self._util._map_messages(messages)
return [Message(**m) for m in oai_messages] # pyright: ignore[reportArgumentType]

def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
if message.role == "assistant":
type_adapter = TypeAdapter(ChatCompletionMessage)
oai_message = type_adapter.validate_python(message)
# Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
return self.util._process_response(
return self._util._process_response(
ChatCompletion(
choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
object="chat.completion",
@@ -157,5 +119,4 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
)
]
)
else:
raise ValueError(f"Unknown role: {message.role}")
raise ValueError(f"Unknown role: {message.role}")
5 changes: 0 additions & 5 deletions eval_protocol/pytest/evaluation_test.py
@@ -255,11 +255,6 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
row.input_metadata.row_id = generate_id(seed=0, index=index)

completion_params = kwargs["completion_params"]
if completion_params and ("model" not in completion_params or not completion_params["model"]):
raise ValueError(
"No model provided. Please provide a model in the completion parameters object."
)

# Create eval metadata with test function info and current commit hash
eval_metadata = EvalMetadata(
name=test_func.__name__,
30 changes: 14 additions & 16 deletions tests/chinook/pydantic/test_pydantic_chinook.py
@@ -6,6 +6,7 @@
from eval_protocol.pytest import evaluation_test

from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from tests.chinook.pydantic.agent import setup_agent
import os
from pydantic_ai.models.openai import OpenAIModel
@@ -20,21 +21,23 @@
)


def agent_factory(config: RolloutProcessorConfig) -> Agent:
model_name = config.completion_params["model"]
provider = config.completion_params["provider"]
model = OpenAIModel(model_name, provider=provider)
return setup_agent(model)


@pytest.mark.asyncio
@evaluation_test(
input_messages=[[[Message(role="user", content="What is the total number of tracks in the database?")]]],
completion_params=[
{
"model": {
"orchestrator_agent_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
}
}
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": setup_agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_simple_query(row: EvaluationRow) -> EvaluationRow:
@@ -92,16 +95,11 @@ class Response(BaseModel):
input_rows=[collect_dataset()],
completion_params=[
{
"model": {
"orchestrator_agent_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
}
}
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": setup_agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_complex_queries(row: EvaluationRow) -> EvaluationRow:
14 changes: 9 additions & 5 deletions tests/pytest/test_pydantic_agent.py
@@ -1,22 +1,26 @@
from pydantic_ai.agent import Agent
from pydantic_ai.models.openai import OpenAIModel
import pytest

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from pydantic_ai import Agent

from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

agent = Agent()

def agent_factory(config: RolloutProcessorConfig) -> Agent:
model = OpenAIModel(config.completion_params["model"], provider="fireworks")
return Agent(model=model)


@pytest.mark.asyncio
@evaluation_test(
input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
completion_params=[
{"model": "accounts/fireworks/models/gpt-oss-120b", "provider": "fireworks"},
{"model": "accounts/fireworks/models/gpt-oss-120b"},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={"agent": agent},
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
mode="pointwise",
)
async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:
39 changes: 24 additions & 15 deletions tests/pytest/test_pydantic_multi_agent.py
@@ -1,12 +1,13 @@
"""
Copied and modified for eval-protocol from https://ai.pydantic.dev/multi-agent-applications/#agent-delegation

To test your Pydantic AI multi-agent application, you can pass a function that
sets up the agents and their tools. The function should accept parameters that
map a model to each agent. In completion_params, you can provide mappings of
model to agent based on key.
To test your Pydantic AI multi-agent application, you can pass a factory that
sets up the agent based on the completion_params. The factory accepts a
RolloutProcessorConfig. In completion_params, you can provide a mapping from
each agent's argument name to its model name.
"""

from pydantic_ai.models.openai import OpenAIModel
import pytest

from eval_protocol.models import EvaluationRow, Message
@@ -18,6 +19,8 @@
from pydantic_ai.models import Model
from pydantic_ai.usage import UsageLimits

from eval_protocol.pytest.types import RolloutProcessorConfig


def setup_agent(joke_generation_model: Model, joke_selection_model: Model) -> Agent:
"""
@@ -45,26 +48,32 @@ async def joke_factory(ctx: RunContext[None], count: int) -> list[str]: # pyrig
return joke_selection_agent


def agent_factory(config: RolloutProcessorConfig) -> Agent:
joke_generation_model = OpenAIModel(
config.completion_params["model"]["joke_generation_model"], provider="fireworks"
)
joke_selection_model = OpenAIModel(config.completion_params["model"]["joke_selection_model"], provider="fireworks")
return setup_agent(
joke_generation_model,
joke_selection_model,
)


@pytest.mark.asyncio
@evaluation_test(
input_messages=[[[Message(role="user", content="Tell me a joke.")]]],
completion_params=[
# multi-agent
{
"model": {
"joke_generation_model": {
"model": "accounts/fireworks/models/kimi-k2-instruct",
"provider": "fireworks",
},
"joke_selection_model": {"model": "accounts/fireworks/models/deepseek-v3p1", "provider": "fireworks"},
"joke_generation_model": "accounts/fireworks/models/kimi-k2-instruct",
"joke_selection_model": "accounts/fireworks/models/deepseek-v3p1",
}
},
],
rollout_processor=PydanticAgentRolloutProcessor(),
rollout_processor_kwargs={
"agent": setup_agent,
# PydanticAgentRolloutProcessor will pass usage_limits into the "run" call
"usage_limits": UsageLimits(request_limit=5, total_tokens_limit=1000),
},
rollout_processor=PydanticAgentRolloutProcessor(
agent_factory, UsageLimits(request_limit=5, total_tokens_limit=1000)
),
mode="pointwise",
)
async def test_pydantic_multi_agent(row: EvaluationRow) -> EvaluationRow:
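
A consequence of this refactor, visible in the first file of this diff, is that the processor no longer resolves "anthropic:" or "google:" model-name prefixes itself; that dispatch now belongs in the user's factory. A sketch of how this multi-agent test could mix providers under the same prefix convention — resolve_model and mixed_provider_factory are hypothetical helpers, not part of eval-protocol:

from pydantic_ai import Agent
from pydantic_ai.models import Model
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.models.openai import OpenAIModel

from eval_protocol.pytest.types import RolloutProcessorConfig
from tests.pytest.test_pydantic_multi_agent import setup_agent


def resolve_model(name: str) -> Model:
    # Mirrors the prefix convention the old processor handled internally.
    if name.startswith("anthropic:"):
        return AnthropicModel(name.removeprefix("anthropic:"))
    if name.startswith("google:"):
        return GoogleModel(name.removeprefix("google:"))
    return OpenAIModel(name, provider="fireworks")


def mixed_provider_factory(config: RolloutProcessorConfig) -> Agent:
    # completion_params["model"] maps each agent argument name to a model name.
    models = config.completion_params["model"]
    return setup_agent(
        resolve_model(models["joke_generation_model"]),
        resolve_model(models["joke_selection_model"]),
    )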