Skip to content

Commit e7b09f9

Browse files
author
Dylan Huang
committed
save
1 parent fa5a70d commit e7b09f9

File tree

3 files changed

+44
-12
lines changed

3 files changed

+44
-12
lines changed

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
1414
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
1515
from pydantic import TypeAdapter
16-
from pydantic_ai import Agent
16+
from pydantic_ai import Agent, ModelSettings
1717
from pydantic_ai._utils import generate_tool_call_id
1818
from pydantic_ai.messages import ModelMessage
1919
from pydantic_ai.messages import (
@@ -46,7 +46,6 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
4646
"""Create agent rollout tasks and return them for external handling."""
4747

4848
semaphore = config.semaphore
49-
5049
agent = self._setup_agent(config)
5150

5251
async def process_row(row: EvaluationRow) -> EvaluationRow:
@@ -70,7 +69,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
7069
row.tools = tools
7170

7271
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
73-
response = await agent.run(message_history=model_messages, usage_limits=config.kwargs.get("usage_limits"))
72+
settings = self.construct_model_settings(agent, row)
73+
response = await agent.run(
74+
message_history=model_messages, usage_limits=config.kwargs.get("usage_limits"), model_settings=settings
75+
)
7476
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
7577

7678
# TODO: pydantic ai accumulates usage info across all models in multi-agent setup, so this simple tracking doesn't work for cost. to discuss with @dphuang2 when he's back.
@@ -98,6 +100,24 @@ async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage])
98100
oai_messages: list[ChatCompletionMessageParam] = await self._util._map_messages(messages)
99101
return [Message(**m) for m in oai_messages] # pyright: ignore[reportArgumentType]
100102

103+
def construct_model_settings(self, agent: Agent, row: EvaluationRow) -> ModelSettings:
    """Build per-row model settings that tag the request with rollout metadata.

    The agent model's own settings (when it has any) are used as the base, and
    the row's identifiers are written under ``extra_body["metadata"]``.

    Args:
        agent: Agent whose ``model.settings`` serve as the base settings. A
            model given as a plain string, a missing model, or a model with
            empty settings results in a fresh ``ModelSettings``.
        row: Row supplying ``input_metadata.row_id`` and the
            ``execution_metadata`` identifiers.

    Returns:
        A new settings mapping. Neither the model's settings nor its nested
        ``extra_body`` / ``metadata`` dicts are modified.
    """
    model = agent.model
    base_settings = None
    if model and not isinstance(model, str):
        base_settings = model.settings

    # Always bind `settings`; the model may be a string or carry no settings.
    settings = base_settings.copy() if base_settings else ModelSettings()

    extra_body = settings.get("extra_body", {})
    if not isinstance(extra_body, dict):
        # A non-dict extra_body cannot carry metadata; leave it as provided.
        return settings

    # `.copy()` above is shallow, so the nested dicts are still the ones held
    # by the model's settings. Clone them before writing so this call does not
    # mutate the shared objects in place.
    extra_body = dict(extra_body)
    existing_metadata = extra_body.get("metadata")
    metadata = dict(existing_metadata) if isinstance(existing_metadata, dict) else {}
    metadata.update(
        {
            "row_id": row.input_metadata.row_id,
            "invocation_id": row.execution_metadata.invocation_id,
            "rollout_id": row.execution_metadata.rollout_id,
            "run_id": row.execution_metadata.run_id,
            "experiment_id": row.execution_metadata.experiment_id,
        }
    )
    extra_body["metadata"] = metadata
    settings["extra_body"] = extra_body
    return settings
120+
101121
def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
102122
if message.role == "assistant":
103123
type_adapter = TypeAdapter(ChatCompletionMessage)

tests/chinook/pydantic/test_pydantic_complex_queries.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from pydantic import BaseModel
33
from pydantic_ai import Agent
4-
from pydantic_ai.models.openai import OpenAIChatModel
4+
from pydantic_ai.models.openai import OpenAIChatModel, OpenAIChatModelSettings
55
import pytest
66

77
from eval_protocol.models import EvaluateResult, EvaluationRow
@@ -14,15 +14,18 @@
1414
LLM_JUDGE_PROMPT = (
1515
"Your job is to compare the response to the expected answer.\n"
1616
"The response will be a narrative report of the query results.\n"
17-
"If the response contains the same or well summarized information as the expected answer, return 1.0.\n"
18-
"If the response does not contain the same information or is missing information, return 0.0."
17+
"Return a score between 0.0 and 1.0, where 1.0 means the response contains all or well summarized information as the expected answer, "
18+
"0.0 means the response does not contain the same information or is missing all key information, "
19+
"and values in between represent partial credit for responses that are partially correct or contain some but not all of the expected information."
1920
)
2021

2122

2223
def agent_factory(config: RolloutProcessorConfig) -> Agent:
2324
model_name = config.completion_params["model"]
24-
provider = config.completion_params["provider"]
25-
model = OpenAIChatModel(model_name, provider=provider)
25+
provider = config.completion_params.get("provider")
26+
reasoning = config.completion_params.get("reasoning")
27+
settings = OpenAIChatModelSettings(openai_reasoning_effort=reasoning)
28+
model = OpenAIChatModel(model_name, provider=provider or "openai", settings=settings)
2629
return setup_agent(model)
2730

2831

@@ -38,8 +41,19 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent:
3841
"model": "accounts/fireworks/models/kimi-k2-instruct",
3942
"provider": "fireworks",
4043
},
44+
{
45+
"model": "accounts/fireworks/models/deepseek-v3p1",
46+
"provider": "fireworks",
47+
},
48+
{
49+
"model": "accounts/fireworks/models/kimi-k2-instruct-0905",
50+
"provider": "fireworks",
51+
},
52+
{"model": "gpt-5"},
53+
{"model": "gpt-5", "reasoning": "high"},
4154
],
4255
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
56+
num_runs=2,
4357
)
4458
async def test_pydantic_complex_queries(row: EvaluationRow) -> EvaluationRow:
4559
"""
@@ -58,8 +72,7 @@ async def test_pydantic_complex_queries(row: EvaluationRow) -> EvaluationRow:
5872
)
5973
else:
6074
model = OpenAIChatModel(
61-
"accounts/fireworks/models/kimi-k2-instruct",
62-
provider="fireworks",
75+
"gpt-5",
6376
)
6477

6578
class Response(BaseModel):

tests/chinook/pydantic/test_pydantic_complex_queries_responses.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
def agent_factory(config: RolloutProcessorConfig) -> Agent:
2121
model_name = config.completion_params["model"]
22-
reasoning = config.completion_params["reasoning"]
22+
reasoning = config.completion_params.get("reasoning")
2323
settings = OpenAIResponsesModelSettings(
2424
openai_reasoning_effort=reasoning,
2525
)
@@ -37,7 +37,6 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent:
3737
completion_params=[
3838
{
3939
"model": "gpt-5",
40-
"reasoning": "high",
4140
},
4241
],
4342
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),

0 commit comments

Comments
 (0)