Skip to content

Commit 3d79e6f

Browse files
author
Dylan Huang
committed
refactor: move test_pydantic_complex_queries to its own file
1 parent c6e0a5a commit 3d79e6f

File tree

3 files changed

+93
-67
lines changed

3 files changed

+93
-67
lines changed

tests/chinook/pydantic/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def setup_agent(orchestrator_agent_model: Model):
3333
)
3434

3535
@agent.tool(retries=5)
36-
def execute_sql(ctx: RunContext, query: str) -> dict:
36+
def execute_sql(ctx: RunContext, query: str) -> str:
3737
try:
3838
cursor.execute(query)
3939
# Get column headers from cursor description
@@ -69,7 +69,7 @@ def execute_sql(ctx: RunContext, query: str) -> dict:
6969

7070
async def main():
7171
model = OpenAIModel(
72-
model="accounts/fireworks/models/kimi-k2-instruct",
72+
"accounts/fireworks/models/kimi-k2-instruct",
7373
provider="fireworks",
7474
)
7575
agent = setup_agent(model)

tests/chinook/pydantic/test_pydantic_chinook.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -84,68 +84,3 @@ class Response(BaseModel):
8484
reason=result.output.reason,
8585
)
8686
return row
87-
88-
89-
@pytest.mark.skipif(
90-
os.environ.get("CI") == "true",
91-
reason="Only run this test locally (skipped in CI)",
92-
)
93-
@pytest.mark.asyncio
94-
@evaluation_test(
95-
input_rows=[collect_dataset()],
96-
completion_params=[
97-
{
98-
"model": "accounts/fireworks/models/kimi-k2-instruct",
99-
"provider": "fireworks",
100-
},
101-
],
102-
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
103-
mode="pointwise",
104-
)
105-
async def test_complex_queries(row: EvaluationRow) -> EvaluationRow:
106-
"""
107-
Complex queries for the Chinook database
108-
"""
109-
last_assistant_message = row.last_assistant_message()
110-
if last_assistant_message is None:
111-
row.evaluation_result = EvaluateResult(
112-
score=0.0,
113-
reason="No assistant message found",
114-
)
115-
elif not last_assistant_message.content:
116-
row.evaluation_result = EvaluateResult(
117-
score=0.0,
118-
reason="No assistant message found",
119-
)
120-
else:
121-
model = OpenAIModel(
122-
"accounts/fireworks/models/kimi-k2-instruct",
123-
provider="fireworks",
124-
)
125-
126-
class Response(BaseModel):
127-
"""
128-
A score between 0.0 and 1.0 indicating whether the response is correct.
129-
"""
130-
131-
score: float
132-
133-
"""
134-
A short explanation of why the response is correct or incorrect.
135-
"""
136-
reason: str
137-
138-
comparison_agent = Agent(
139-
model=model,
140-
system_prompt=LLM_JUDGE_PROMPT,
141-
output_type=Response,
142-
output_retries=5,
143-
)
144-
result = await comparison_agent.run(
145-
f"Expected answer: {row.ground_truth}\nResponse: {last_assistant_message.content}"
146-
)
147-
row.evaluation_result = EvaluateResult(
148-
score=result.output.score,
149-
reason=result.output.reason,
150-
)
151-
return row
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import os
2+
from pydantic import BaseModel
3+
from pydantic_ai import Agent
4+
from pydantic_ai.models.openai import OpenAIModel
5+
import pytest
6+
7+
from eval_protocol.models import EvaluateResult, EvaluationRow
8+
from eval_protocol.pytest import evaluation_test
9+
from eval_protocol.pytest.types import RolloutProcessorConfig
10+
from tests.chinook.dataset import collect_dataset
11+
from tests.chinook.pydantic.agent import setup_agent
12+
from tests.pytest.test_pydantic_agent import PydanticAgentRolloutProcessor
13+
14+
# System prompt for the LLM judge: it compares a narrative answer against the
# expected answer and returns a binary correctness score (1.0 or 0.0).
LLM_JUDGE_PROMPT = (
    "Your job is to compare the response to the expected answer.\n"
    "The response will be a narrative report of the query results.\n"
    "If the response contains the same or well summarized information as the expected answer, return 1.0.\n"
    "If the response does not contain the same information or is missing information, return 0.0."
)
20+
21+
22+
def agent_factory(config: RolloutProcessorConfig) -> Agent:
    """Construct the Chinook agent from the rollout config's completion params.

    Reads ``model`` and ``provider`` from ``config.completion_params`` and
    hands the resulting ``OpenAIModel`` to :func:`setup_agent`.
    """
    params = config.completion_params
    judge_model = OpenAIModel(params["model"], provider=params["provider"])
    return setup_agent(judge_model)
27+
28+
29+
@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Only run this test locally (skipped in CI)",
)
@pytest.mark.asyncio
@evaluation_test(
    input_rows=[collect_dataset()],
    completion_params=[
        {
            "model": "accounts/fireworks/models/kimi-k2-instruct",
            "provider": "fireworks",
        },
    ],
    rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
    mode="pointwise",
)
async def test_pydantic_complex_queries(row: EvaluationRow) -> EvaluationRow:
    """
    Evaluation of complex queries for the Chinook database using PydanticAI.

    Scores the rollout's final assistant message against ``row.ground_truth``
    with an LLM judge and stores the verdict on ``row.evaluation_result``.
    """
    last_assistant_message = row.last_assistant_message()

    # A missing or content-less final message cannot be judged: score it 0
    # immediately. (Previously two separate branches with identical
    # score/reason handling — merged into one guard.)
    if last_assistant_message is None or not last_assistant_message.content:
        row.evaluation_result = EvaluateResult(
            score=0.0,
            reason="No assistant message found",
        )
        return row

    # Judge model — kept the same as the rollout model in completion_params.
    model = OpenAIModel(
        "accounts/fireworks/models/kimi-k2-instruct",
        provider="fireworks",
    )

    class Response(BaseModel):
        """Structured verdict produced by the LLM judge."""

        # Score between 0.0 and 1.0 indicating whether the response is correct.
        score: float
        # Short explanation of why the response is correct or incorrect.
        reason: str

    comparison_agent = Agent(
        model=model,
        system_prompt=LLM_JUDGE_PROMPT,
        output_type=Response,
        output_retries=5,
    )
    result = await comparison_agent.run(
        f"Expected answer: {row.ground_truth}\nResponse: {last_assistant_message.content}"
    )
    row.evaluation_result = EvaluateResult(
        score=result.output.score,
        reason=result.output.reason,
    )
    return row

0 commit comments

Comments
 (0)