Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Callable
import logging
import time
from pydantic_ai.toolsets import FunctionToolset
from pydantic_ai.usage import UsageLimits
from typing_extensions import override
from eval_protocol.models import EvaluationRow, Message
Expand All @@ -21,7 +22,7 @@
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openai import OpenAIProvider

logger = logging.getLogger(__name__)
Expand All @@ -37,7 +38,7 @@ def __init__(
usage_limits: UsageLimits | None = None,
):
# dummy model used for its helper functions for processing messages
self._util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
self._util: OpenAIChatModel = OpenAIChatModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
self._setup_agent = agent_factory

@override
Expand All @@ -53,18 +54,19 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
start_time = time.perf_counter()

tools = []
for _, tool in agent._function_tools.items():
tool_dict = {
"type": "function",
"function": {
"name": tool.name,
"parameters": tool.function_schema.json_schema,
},
}
if tool.description:
tool_dict["function"]["description"] = tool.description

tools.append(tool_dict)
for toolset in agent.toolsets:
if isinstance(toolset, FunctionToolset):
for _, tool in toolset.tools.items():
tool_dict = {
"type": "function",
"function": {
"name": tool.name,
"parameters": tool.function_schema.json_schema,
},
}
if tool.description:
tool_dict["function"]["description"] = tool.description
tools.append(tool_dict)
row.tools = tools

model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
Expand Down
4 changes: 2 additions & 2 deletions tests/chinook/langfuse/generate_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
try:
from langfuse import get_client, observe # pyright: ignore[reportPrivateImportUsage]
from pydantic_ai.agent import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.openai import OpenAIChatModel

LANGFUSE_AVAILABLE = True
langfuse_client = get_client()
Expand Down Expand Up @@ -42,7 +42,7 @@ def decorator(func):
def agent_factory(config: RolloutProcessorConfig) -> Agent:
model_name = config.completion_params["model"]
provider = config.completion_params["provider"]
model = OpenAIModel(model_name, provider=provider)
model = OpenAIChatModel(model_name, provider=provider)
return setup_agent(model)


Expand Down
4 changes: 2 additions & 2 deletions tests/chinook/langfuse/test_langfuse_chinook.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.openai import OpenAIChatModel

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata
from eval_protocol.pytest import evaluation_test, NoOpRolloutProcessor
Expand Down Expand Up @@ -99,7 +99,7 @@ async def test_langfuse_evaluation(row: EvaluationRow) -> EvaluationRow:
reason="No assistant message found",
)
else:
model = OpenAIModel(
model = OpenAIChatModel(
"accounts/fireworks/models/kimi-k2-instruct",
provider="fireworks",
)
Expand Down
4 changes: 2 additions & 2 deletions tests/chinook/pydantic/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic_ai import Agent, RunContext
import asyncio
from pydantic_ai.models import Model
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.exceptions import ModelRetry
import sys
import os
Expand Down Expand Up @@ -68,7 +68,7 @@ def execute_sql(ctx: RunContext, query: str) -> str:


async def main():
model = OpenAIModel(
model = OpenAIChatModel(
"accounts/fireworks/models/kimi-k2-instruct",
provider="fireworks",
)
Expand Down
23 changes: 20 additions & 3 deletions tests/chinook/pydantic/test_pydantic_chinook.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from eval_protocol.pytest.types import RolloutProcessorConfig
from tests.chinook.pydantic.agent import setup_agent
import os
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.openai import OpenAIChatModel

from tests.chinook.dataset import collect_dataset

Expand All @@ -24,7 +24,7 @@
def agent_factory(config: RolloutProcessorConfig) -> Agent:
model_name = config.completion_params["model"]
provider = config.completion_params["provider"]
model = OpenAIModel(model_name, provider=provider)
model = OpenAIChatModel(model_name, provider=provider)
return setup_agent(model)


Expand All @@ -44,6 +44,23 @@ async def test_simple_query(row: EvaluationRow) -> EvaluationRow:
"""
Super simple query for the Chinook database
"""
expected_tools = [
{
"type": "function",
"function": {
"name": "execute_sql",
"parameters": {
"additionalProperties": False,
"properties": {"query": {"type": "string"}},
"required": ["query"],
"type": "object",
},
},
}
]
assert hasattr(row, "tools"), "Row missing 'tools' attribute"
assert row.tools == expected_tools, f"Tools validation failed. Expected: {expected_tools}, Got: {row.tools}"

last_assistant_message = row.last_assistant_message()
if last_assistant_message is None:
row.evaluation_result = EvaluateResult(
Expand All @@ -56,7 +73,7 @@ async def test_simple_query(row: EvaluationRow) -> EvaluationRow:
reason="No assistant message found",
)
else:
model = OpenAIModel(
model = OpenAIChatModel(
"accounts/fireworks/models/kimi-k2-instruct",
provider="fireworks",
)
Expand Down
6 changes: 3 additions & 3 deletions tests/chinook/pydantic/test_pydantic_complex_queries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.openai import OpenAIChatModel
import pytest

from eval_protocol.models import EvaluateResult, EvaluationRow
Expand All @@ -22,7 +22,7 @@
def agent_factory(config: RolloutProcessorConfig) -> Agent:
model_name = config.completion_params["model"]
provider = config.completion_params["provider"]
model = OpenAIModel(model_name, provider=provider)
model = OpenAIChatModel(model_name, provider=provider)
return setup_agent(model)


Expand Down Expand Up @@ -57,7 +57,7 @@ async def test_pydantic_complex_queries(row: EvaluationRow) -> EvaluationRow:
reason="No assistant message found",
)
else:
model = OpenAIModel(
model = OpenAIChatModel(
"accounts/fireworks/models/kimi-k2-instruct",
provider="fireworks",
)
Expand Down
7 changes: 4 additions & 3 deletions tests/pytest/test_pydantic_agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from pydantic_ai.agent import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.openai import OpenAIChatModel
import pytest

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.models import EvaluationRow, Message, Status
from eval_protocol.pytest import evaluation_test

from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig


def agent_factory(config: RolloutProcessorConfig) -> Agent:
model = OpenAIModel(config.completion_params["model"], provider="fireworks")
model = OpenAIChatModel(config.completion_params["model"], provider="fireworks")
return Agent(model=model)


Expand All @@ -27,4 +27,5 @@ async def test_pydantic_agent(row: EvaluationRow) -> EvaluationRow:
"""
Super simple hello world test for Pydantic AI.
"""
assert row.rollout_status.code == Status.Code.FINISHED
return row
8 changes: 5 additions & 3 deletions tests/pytest/test_pydantic_multi_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
to agent based on key.
"""

from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.openai import OpenAIChatModel
import pytest

from eval_protocol.models import EvaluationRow, Message
Expand Down Expand Up @@ -49,10 +49,12 @@ async def joke_factory(ctx: RunContext[None], count: int) -> list[str]: # pyrig


def agent_factory(config: RolloutProcessorConfig) -> Agent:
joke_generation_model = OpenAIModel(
joke_generation_model = OpenAIChatModel(
config.completion_params["model"]["joke_generation_model"], provider="fireworks"
)
joke_selection_model = OpenAIModel(config.completion_params["model"]["joke_selection_model"], provider="fireworks")
joke_selection_model = OpenAIChatModel(
config.completion_params["model"]["joke_selection_model"], provider="fireworks"
)
return setup_agent(
joke_generation_model,
joke_selection_model,
Expand Down
Loading
Loading