diff --git a/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py b/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
index 5bd6dfab..4c0edfc3 100644
--- a/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
+++ b/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
@@ -1,30 +1,27 @@
 # pyright: reportPrivateUsage=false
 import asyncio
+from collections.abc import Callable
 import logging
 import time
-import types
 
-from pydantic_ai.models import Model
+from pydantic_ai.usage import UsageLimits
 from typing_extensions import override
 from eval_protocol.models import EvaluationRow, Message
-from openai.types import CompletionUsage
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
 from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
 from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
-from pydantic_ai.models.anthropic import AnthropicModel
-from pydantic_ai.models.openai import OpenAIModel
-from pydantic_ai.models.google import GoogleModel
 from pydantic import TypeAdapter
-from pydantic_ai.messages import ModelMessage
-from pydantic_ai._utils import generate_tool_call_id
 from pydantic_ai import Agent
+from pydantic_ai._utils import generate_tool_call_id
+from pydantic_ai.messages import ModelMessage
 from pydantic_ai.messages import (
     ModelRequest,
     SystemPromptPart,
     ToolReturnPart,
     UserPromptPart,
 )
+from pydantic_ai.models.openai import OpenAIModel
 from pydantic_ai.providers.openai import OpenAIProvider
 
 logger = logging.getLogger(__name__)
@@ -34,9 +31,15 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
     """Rollout processor for Pydantic AI agents. Mainly converts EvaluationRow.messages to and from Pydantic AI ModelMessage format."""
 
-    def __init__(self):
+    def __init__(
+        self,
+        agent_factory: Callable[[RolloutProcessorConfig], Agent],
+        usage_limits: UsageLimits | None = None,
+    ):
         # dummy model used for its helper functions for processing messages
-        self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
+        self._util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
+        self._setup_agent = agent_factory
+        self._usage_limits = usage_limits
 
     @override
     def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
@@ -44,54 +47,14 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
         semaphore = config.semaphore
 
-        # validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
-        if "agent" not in config.kwargs:
-            raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
-        if not isinstance(config.kwargs["agent"], Agent) and not isinstance(
-            config.kwargs["agent"], types.FunctionType
-        ):
-            raise ValueError(
-                "kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
-            )
-
-        if isinstance(config.kwargs["agent"], types.FunctionType):
-            setup_agent = config.kwargs["agent"]
-            if not isinstance(config.completion_params["model"], dict):
-                raise ValueError(
-                    "completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
-                )
-            kwargs: dict[str, Model] = {}
-            for k, v in config.completion_params["model"].items():  # pyright: ignore[reportUnknownVariableType]
-                if v["model"] and v["model"].startswith("anthropic:"):  # pyright: ignore[reportUnknownMemberType]
-                    kwargs[k] = AnthropicModel(
-                        v["model"].removeprefix("anthropic:"),  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
-                    )
-                elif v["model"] and v["model"].startswith("google:"):  # pyright: ignore[reportUnknownMemberType]
-                    kwargs[k] = GoogleModel(
-                        v["model"].removeprefix("google:"),  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
-                    )
-                else:
-                    kwargs[k] = OpenAIModel(
-                        v["model"],  # pyright: ignore[reportUnknownArgumentType]
-                        provider=v["provider"],  # pyright: ignore[reportUnknownArgumentType]
-                    )
-            agent_instance: Agent = setup_agent(**kwargs)  # pyright: ignore[reportAny]
-            model = None
-        else:
-            agent_instance = config.kwargs["agent"]  # pyright: ignore[reportAssignmentType]
-            model = OpenAIModel(
-                config.completion_params["model"],  # pyright: ignore[reportAny]
-                provider=config.completion_params["provider"],  # pyright: ignore[reportAny]
-            )
+        agent = self._setup_agent(config)
 
         async def process_row(row: EvaluationRow) -> EvaluationRow:
             """Process a single row with agent rollout."""
             start_time = time.perf_counter()
 
             model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
-            response = await agent_instance.run(
-                message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
-            )
+            response = await agent.run(message_history=model_messages, usage_limits=self._usage_limits)
             row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
 
             # TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
@@ -116,7 +79,7 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
         return tasks
 
     async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
-        oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
+        oai_messages: list[ChatCompletionMessageParam] = await self._util._map_messages(messages)
         return [Message(**m) for m in oai_messages]  # pyright: ignore[reportArgumentType]
 
     def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
@@ -124,7 +87,7 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
             type_adapter = TypeAdapter(ChatCompletionMessage)
             oai_message = type_adapter.validate_python(message)
             # Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
-            return self.util._process_response(
+            return self._util._process_response(
                 ChatCompletion(
                     choices=[ChatCompletionChoice(message=oai_message, finish_reason="stop", index=0)],
                     object="chat.completion",
@@ -157,5 +120,4 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
                     )
                 ]
             )
-        else:
-            raise ValueError(f"Unknown role: {message.role}")
+        raise ValueError(f"Unknown role: {message.role}")
diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index cf2c4a77..b58a062a 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -255,11 +255,6 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
             row.input_metadata.row_id = generate_id(seed=0, index=index)
 
         completion_params = kwargs["completion_params"]
-        if completion_params and ("model" not in completion_params or not completion_params["model"]):
-            raise ValueError(
-                "No model provided. Please provide a model in the completion parameters object."
-            )
-
         # Create eval metadata with test function info and current commit hash
         eval_metadata = EvalMetadata(
             name=test_func.__name__,
diff --git a/tests/chinook/pydantic/test_pydantic_chinook.py b/tests/chinook/pydantic/test_pydantic_chinook.py
index 0fc33277..0233a227 100644
--- a/tests/chinook/pydantic/test_pydantic_chinook.py
+++ b/tests/chinook/pydantic/test_pydantic_chinook.py
@@ -6,6 +6,7 @@
 
 from eval_protocol.pytest import evaluation_test
 from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
+from eval_protocol.pytest.types import RolloutProcessorConfig
 from tests.chinook.pydantic.agent import setup_agent
 import os
 from pydantic_ai.models.openai import OpenAIModel
@@ -20,21 +21,23 @@
 )
 
 
+def agent_factory(config: RolloutProcessorConfig) -> Agent:
+    model_name = config.completion_params["model"]
+    provider = config.completion_params["provider"]
+    model = OpenAIModel(model_name, provider=provider)
+    return setup_agent(model)
+
+
 @pytest.mark.asyncio
 @evaluation_test(
     input_messages=[[[Message(role="user", content="What is the total number of tracks in the database?")]]],
     completion_params=[
         {
-            "model": {
-                "orchestrator_agent_model": {
-                    "model": "accounts/fireworks/models/kimi-k2-instruct",
-                    "provider": "fireworks",
-                }
-            }
+            "model": "accounts/fireworks/models/kimi-k2-instruct",
+            "provider": "fireworks",
         },
     ],
-    rollout_processor=PydanticAgentRolloutProcessor(),
-    rollout_processor_kwargs={"agent": setup_agent},
+    rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
     mode="pointwise",
 )
 async def test_simple_query(row: EvaluationRow) -> EvaluationRow:
@@ -92,16 +95,11 @@ class Response(BaseModel):
     input_rows=[collect_dataset()],
     completion_params=[
         {
-            "model": {
-                "orchestrator_agent_model": {
-                    "model": "accounts/fireworks/models/kimi-k2-instruct",
-                    "provider": "fireworks",
-                }
-            }
+            "model": "accounts/fireworks/models/kimi-k2-instruct",
+            "provider": "fireworks",
         },
     ],
-    rollout_processor=PydanticAgentRolloutProcessor(),
-    rollout_processor_kwargs={"agent": setup_agent},
+    rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
     mode="pointwise",
 )
 async def test_complex_queries(row: EvaluationRow) -> EvaluationRow:
diff --git a/tests/pytest/test_pydantic_agent.py b/tests/pytest/test_pydantic_agent.py
index 1c9079c1..d08f74c9 100644
--- a/tests/pytest/test_pydantic_agent.py
+++ b/tests/pytest/test_pydantic_agent.py
@@ -1,22 +1,26 @@
+from pydantic_ai.agent import Agent
+from pydantic_ai.models.openai import OpenAIModel
 import pytest
 
 from eval_protocol.models import EvaluationRow, Message
 from eval_protocol.pytest import evaluation_test
-from pydantic_ai import Agent
 from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
+from eval_protocol.pytest.types import RolloutProcessorConfig
 
-agent = Agent()
+
+def agent_factory(config: RolloutProcessorConfig) -> Agent:
+    model = OpenAIModel(config.completion_params["model"], provider="fireworks")
+    return Agent(model=model)
 
 
 @pytest.mark.asyncio
 @evaluation_test(
     input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
     completion_params=[
-        {"model": "accounts/fireworks/models/gpt-oss-120b", "provider": "fireworks"},
+        {"model": "accounts/fireworks/models/gpt-oss-120b"},
     ],
-    rollout_processor=PydanticAgentRolloutProcessor(),
-    rollout_processor_kwargs={"agent": agent},
+    rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
     mode="pointwise",
 )
 async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:
diff --git a/tests/pytest/test_pydantic_multi_agent.py b/tests/pytest/test_pydantic_multi_agent.py
index 9c0665ad..24be554c 100644
--- a/tests/pytest/test_pydantic_multi_agent.py
+++ b/tests/pytest/test_pydantic_multi_agent.py
@@ -1,12 +1,13 @@
 """
 Copied and modified for eval-protocol from
 https://ai.pydantic.dev/multi-agent-applications/#agent-delegation
 
-To test your Pydantic AI multi-agent application, you can pass a function that
-sets up the agents and their tools. The function should accept parameters that
-map a model to each agent. In completion_params, you can provide mappings of
-model to agent based on key.
+To test your Pydantic AI multi-agent application, you can pass a factory that
+sets up the agent from the completion_params. The factory receives a
+RolloutProcessorConfig; in completion_params, you can map a model to each
+agent by key.
 """
+from pydantic_ai.models.openai import OpenAIModel
 import pytest
 
 from eval_protocol.models import EvaluationRow, Message
@@ -18,6 +19,8 @@
 from pydantic_ai.models import Model
 from pydantic_ai.usage import UsageLimits
 
+from eval_protocol.pytest.types import RolloutProcessorConfig
+
 
 def setup_agent(joke_generation_model: Model, joke_selection_model: Model) -> Agent:
     """
@@ -45,26 +48,32 @@ async def joke_factory(ctx: RunContext[None], count: int) -> list[str]:  # pyrig
     return joke_selection_agent
 
 
+def agent_factory(config: RolloutProcessorConfig) -> Agent:
+    joke_generation_model = OpenAIModel(
+        config.completion_params["model"]["joke_generation_model"], provider="fireworks"
+    )
+    joke_selection_model = OpenAIModel(config.completion_params["model"]["joke_selection_model"], provider="fireworks")
+    return setup_agent(
+        joke_generation_model,
+        joke_selection_model,
+    )
+
+
 @pytest.mark.asyncio
 @evaluation_test(
     input_messages=[[[Message(role="user", content="Tell me a joke.")]]],
     completion_params=[
+        # multi-agent
        {
             "model": {
-                "joke_generation_model": {
-                    "model": "accounts/fireworks/models/kimi-k2-instruct",
-                    "provider": "fireworks",
-                },
-                "joke_selection_model": {"model": "accounts/fireworks/models/deepseek-v3p1", "provider": "fireworks"},
+                "joke_generation_model": "accounts/fireworks/models/kimi-k2-instruct",
+                "joke_selection_model": "accounts/fireworks/models/deepseek-v3p1",
             }
         },
     ],
-    rollout_processor=PydanticAgentRolloutProcessor(),
-    rollout_processor_kwargs={
-        "agent": setup_agent,
-        # PydanticAgentRolloutProcessor will pass usage_limits into the "run" call
-        "usage_limits": UsageLimits(request_limit=5, total_tokens_limit=1000),
-    },
+    rollout_processor=PydanticAgentRolloutProcessor(
+        agent_factory, UsageLimits(request_limit=5, total_tokens_limit=1000)
+    ),
     mode="pointwise",
 )
 async def test_pydantic_multi_agent(row: EvaluationRow) -> EvaluationRow:
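
For reference, a minimal usage sketch of the factory-based API these changes introduce (it mirrors the single-agent test above; the model string and the "fireworks" provider shorthand are taken from this diff, everything else is illustrative, not canonical):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.usage import UsageLimits

    from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
    from eval_protocol.pytest.types import RolloutProcessorConfig


    def agent_factory(config: RolloutProcessorConfig) -> Agent:
        # The factory receives the full RolloutProcessorConfig and builds the
        # agent from completion_params: a plain model string for single-agent
        # runs, or a dict keyed by agent name for multi-agent setups.
        model = OpenAIModel(config.completion_params["model"], provider="fireworks")
        return Agent(model=model)


    # usage_limits is now a constructor argument; the processor applies it to
    # every agent.run() call instead of reading it from rollout_processor_kwargs.
    processor = PydanticAgentRolloutProcessor(
        agent_factory,
        usage_limits=UsageLimits(request_limit=5, total_tokens_limit=1000),
    )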