Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import asyncio
import logging
import types
from typing import List

from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
Expand All @@ -23,7 +23,6 @@
UserPromptPart,
)
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.providers.fireworks import FireworksProvider

logger = logging.getLogger(__name__)

Expand All @@ -45,20 +44,40 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
# validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
if "agent" not in config.kwargs:
raise ValueError("kwargs must contain an 'agent' field with a valid Pydantic AI Agent instance")
if not isinstance(config.kwargs["agent"], Agent):
raise ValueError("kwargs['agent'] must be a valid Pydantic AI Agent instance")

agent: Agent = config.kwargs["agent"]
if not isinstance(config.kwargs["agent"], Agent) and not isinstance(
config.kwargs["agent"], types.FunctionType
):
raise ValueError(
"kwargs['agent'] must be a valid Pydantic AI Agent instance or a function that returns an Agent"
)

model = OpenAIModel(
config.completion_params["model"],
provider=config.completion_params["provider"],
)
if isinstance(config.kwargs["agent"], types.FunctionType):
setup_agent = config.kwargs["agent"]
if not isinstance(config.completion_params["model"], dict):
raise ValueError(
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
)
kwargs = {}
for model_name, model_config in config.completion_params["model"].items():
kwargs[model_name] = OpenAIModel(
model_config["model"],
provider=model_config["provider"],
)
agent = setup_agent(**kwargs)
model = None
else:
agent = config.kwargs["agent"]
model = OpenAIModel(
config.completion_params["model"],
provider=config.completion_params["provider"],
)

async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row with agent rollout."""
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
response = await agent.run(message_history=model_messages, model=model)
response = await agent.run(
message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
)
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
return row

Expand Down
74 changes: 74 additions & 0 deletions tests/pytest/test_pydantic_multi_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Copied and modified for eval-protocol from https://ai.pydantic.dev/multi-agent-applications/#agent-delegation

To test your Pydantic AI multi-agent application, you can pass a function that
sets up the agents and their tools. The function should accept parameters that
map a model to each agent. In completion_params, you can provide mappings of
model to agent based on key.
"""

import pytest

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from pydantic_ai import Agent

from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
from pydantic_ai import RunContext
from pydantic_ai.models import Model
from pydantic_ai.usage import UsageLimits


def setup_agent(joke_generation_model: Model, joke_selection_model: Model) -> Agent:
    """Build the delegating joke-selection agent wired to a generation sub-agent.

    Accepting the models as parameters (instead of hard-coding them) is the
    extra step multi-agent applications need so the rollout processor can
    inject a different model for each agent from ``completion_params``.
    """
    generator = Agent(joke_generation_model, output_type=list[str])
    selector = Agent(
        model=joke_selection_model,
        system_prompt=(
            "Use the `joke_factory` to generate some jokes, then choose the best. You must return just a single joke."
        ),
    )

    @selector.tool
    async def joke_factory(ctx: RunContext[None], count: int) -> list[str]:
        # Delegate to the generation agent; sharing ctx.usage accumulates
        # usage across both agents so any limits apply to the whole run.
        generated = await generator.run(
            f"Please generate {count} jokes.",
            usage=ctx.usage,
        )
        return generated.output

    return selector


@pytest.mark.asyncio
@evaluation_test(
    input_messages=[Message(role="user", content="Tell me a joke.")],
    completion_params=[
        {
            # Because "agent" below is a function (not an Agent instance),
            # "model" is a dict mapping each setup_agent parameter name to a
            # model config; the rollout processor builds an OpenAIModel from
            # each entry and calls setup_agent(**models).
            "model": {
                "joke_generation_model": {
                    "model": "accounts/fireworks/models/kimi-k2-instruct",
                    "provider": "fireworks",
                },
                "joke_selection_model": {"model": "accounts/fireworks/models/deepseek-v3p1", "provider": "fireworks"},
            }
        },
    ],
    rollout_processor=PydanticAgentRolloutProcessor(),
    rollout_processor_kwargs={
        "agent": setup_agent,
        # PydanticAgentRolloutProcessor will pass usage_limits into the "run" call
        "usage_limits": UsageLimits(request_limit=5, total_tokens_limit=1000),
    },
    mode="pointwise",
)
async def test_pydantic_multi_agent(row: EvaluationRow) -> EvaluationRow:
    """Smoke test for the multi-agent (agent-delegation) rollout path.

    The evaluation itself is a no-op: the row is returned unchanged, so the
    test only verifies that the rollout completes end-to-end with per-agent
    model injection and usage limits.
    """
    return row
136 changes: 136 additions & 0 deletions vite-app/dist/assets/index-Bw6MHHaR.js

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions vite-app/dist/assets/index-Bw6MHHaR.js.map

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions vite-app/dist/assets/index-BxZNbf6w.css

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion vite-app/dist/assets/index-Bxmt9iUR.css

This file was deleted.

131 changes: 0 additions & 131 deletions vite-app/dist/assets/index-DbgWqpuZ.js

This file was deleted.

1 change: 0 additions & 1 deletion vite-app/dist/assets/index-DbgWqpuZ.js.map

This file was deleted.

4 changes: 2 additions & 2 deletions vite-app/dist/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>EP | Log Viewer</title>
<link rel="icon" href="/assets/favicon-BkAAWQga.png" />
<script type="module" crossorigin src="/assets/index-DbgWqpuZ.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-Bxmt9iUR.css">
<script type="module" crossorigin src="/assets/index-Bw6MHHaR.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-BxZNbf6w.css">
</head>
<body>
<div id="root"></div>
Expand Down
1 change: 1 addition & 0 deletions vite-app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"react-chartjs-2": "^5.3.0",
"react-dom": "^19.1.0",
"react-router-dom": "^7.7.1",
"react-tooltip": "^5.29.1",
"zod": "^4.0.14"
},
"devDependencies": {
Expand Down
41 changes: 41 additions & 0 deletions vite-app/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 24 additions & 3 deletions vite-app/src/components/EvaluationRow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { TableCell, TableRowInteractive } from "./TableContainer";
import { useState } from "react";
import type { FilterGroup, FilterConfig } from "../types/filters";
import { Tooltip } from "./Tooltip";
import { JSONTooltip } from "./JSONTooltip";

// Add filter button component
const AddFilterButton = observer(
Expand Down Expand Up @@ -190,9 +191,29 @@ const InvocationId = observer(({ invocationId }: { invocationId?: string }) => {
);
});

const RowModel = observer(({ model }: { model: string | undefined }) => (
<span className="text-gray-900 truncate block">{model || "N/A"}</span>
));
// Renders the "model" cell of an evaluation row. The model may now be either
// a plain string (single-agent runs) or an object mapping agent parameter
// names to model configs (multi-agent runs) — objects are truncated and get a
// hover tooltip showing the full JSON.
const RowModel = observer(
  ({ model }: { model: string | object | undefined }) => {
    // Objects are stringified for the truncated inline display; a missing
    // model falls back to "N/A".
    const displayValue = model
      ? typeof model === "string"
        ? model
        : JSON.stringify(model)
      : "N/A";

    // For strings, show full value without tooltip
    if (typeof model === "string" || !model) {
      return <span className="text-gray-900 block">{displayValue}</span>;
    }

    // For objects, use JSONTooltip with truncation
    return (
      <JSONTooltip data={model}>
        <span className="text-gray-900 truncate block max-w-[200px] cursor-help">
          {displayValue}
        </span>
      </JSONTooltip>
    );
  }
);

const RowScore = observer(({ score }: { score: number | undefined }) => {
const scoreClass = score
Expand Down
63 changes: 63 additions & 0 deletions vite-app/src/components/JSONTooltip.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import React from "react";
import { Tooltip } from "react-tooltip";

interface JSONTooltipProps {
children: React.ReactNode;
data: any;
position?: "top" | "bottom" | "left" | "right";
className?: string;
}

// Wraps children with a hover tooltip that shows `data` pretty-printed as
// JSON. The tooltip content is selectable/clickable so users can copy values.
export const JSONTooltip: React.FC<JSONTooltipProps> = ({
  children,
  data,
  position = "top",
  className = "",
}) => {
  // Use React.useId() for a stable, SSR-safe id for the lifetime of this
  // component instance. The previous `Math.random().toString(36).substr(...)`
  // produced a NEW id on every render, re-binding the tooltip target each
  // render (breaking hover/clickable interaction mid-show), and relied on the
  // deprecated String#substr.
  const reactId = React.useId();
  const tooltipId = `json-tooltip-${reactId}`;
  const formattedJSON = JSON.stringify(data, null, 2);

  return (
    <>
      <div data-tooltip-id={tooltipId} className={`cursor-help ${className}`}>
        {children}
      </div>
      <Tooltip
        id={tooltipId}
        place={position}
        className="px-2 py-1 text-xs text-white bg-gray-800 rounded z-10 max-w-md"
        style={{
          fontSize: "0.75rem",
          lineHeight: "1rem",
          backgroundColor: "#1f2937",
          color: "white",
          borderRadius: "0.25rem",
          padding: "0.5rem",
          zIndex: 10,
          userSelect: "text",
          pointerEvents: "auto",
        }}
        delayShow={200}
        delayHide={300}
        clickable={true}
        noArrow={true}
        render={() => (
          <pre
            className="whitespace-pre-wrap text-left text-xs"
            style={{
              userSelect: "text",
              pointerEvents: "auto",
              cursor: "text",
            }}
            // Stop propagation so text selection inside the tooltip does not
            // trigger click/selection handlers on the underlying table row.
            onMouseDown={(e) => e.stopPropagation()}
            onClick={(e) => e.stopPropagation()}
            onDoubleClick={(e) => e.stopPropagation()}
            onContextMenu={(e) => e.stopPropagation()}
          >
            {formattedJSON}
          </pre>
        )}
      />
    </>
  );
};
Loading
Loading