Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions eval_protocol/pytest/default_pydantic_ai_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from eval_protocol.pytest.types import RolloutProcessorConfig
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice

from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.google import GoogleModel
from pydantic import TypeAdapter
from pydantic_ai.messages import ModelMessage
from pydantic_ai._utils import generate_tool_call_id
Expand Down Expand Up @@ -61,10 +62,19 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
)
kwargs = {}
for k, v in config.completion_params["model"].items():
kwargs[k] = OpenAIModel(
v["model"],
provider=v["provider"],
)
if v["model"] and v["model"].startswith("anthropic:"):
kwargs[k] = AnthropicModel(
v["model"].removeprefix("anthropic:"),
)
elif v["model"] and v["model"].startswith("google:"):
kwargs[k] = GoogleModel(
v["model"].removeprefix("google:"),
)
else:
kwargs[k] = OpenAIModel(
v["model"],
provider=v["provider"],
)
agent = setup_agent(**kwargs)
model = None
else:
Expand Down
33 changes: 29 additions & 4 deletions tests/chinook/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ def setup_agent(orchestrator_agent_model: Model):
introspection_result_str = "\n".join([",".join(map(str, item)) for item in introspection_result])

SYSTEM_PROMPT = f"""You are a helpful assistant that has access to the
Chinook database. You have access to a tool to execute SQL queries. Your job
is to answer questions about the database. Here is the schema of the database:
Chinook database stored in a Postgres database. You have access to a tool to
execute SQL queries that you should use to answer questions. Your job is to
answer questions about the database. If you run into an error, you should try to
fix the query and try again. Here is the schema of the database:

Schema:
table_name,column_name,data_type,is_nullable
Expand All @@ -26,10 +28,33 @@ def setup_agent(orchestrator_agent_model: Model):
)

@agent.tool(retries=5)
def execute_sql(ctx: RunContext, query: str) -> str:
    """Execute ``query`` on the enclosing scope's cursor and render the result as markdown.

    Args:
        ctx: Run context supplied by the agent framework; not used here.
        query: SQL text to execute, passed to ``cursor.execute`` unchanged.

    Returns:
        A markdown table (header row, ``---`` separator row, one line per
        result row), or the literal string "No results found." when the
        cursor reports no columns or the query produced no rows.

    Raises:
        ModelRetry: for any exception raised while executing or formatting;
            the connection is rolled back first and the original error text
            is embedded in the message.
    """
    try:
        cursor.execute(query)
        # Column headers come from the cursor description; it may be empty/None
        # (presumably for statements that return no result set — confirm
        # against the DB driver in use).
        columns = [desc[0] for desc in cursor.description] if cursor.description else []
        rows = cursor.fetchall()

        if not columns or not rows:
            return "No results found."

        # Header row followed by the markdown separator row.
        table_lines = [
            "| " + " | ".join(columns) + " |",
            "| " + " | ".join(["---"] * len(columns)) + " |",
        ]

        for row in rows:
            # None renders as an empty cell; literal pipes are escaped so they
            # do not split the cell.
            formatted_row = [str(cell).replace("|", "\\|") if cell is not None else "" for cell in row]
            table_lines.append("| " + " | ".join(formatted_row) + " |")

        return "\n".join(table_lines)
    except Exception as e:
        # Roll back before asking for a retry, so the next attempt does not run
        # on a connection left in a failed state (presumed intent).
        connection.rollback()
        raise ModelRetry("Please try again with a different query. Here is the error: " + str(e)) from e
Expand Down
50 changes: 50 additions & 0 deletions tests/chinook/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import List
import os
import glob

from eval_protocol.models import EvaluationRow, Message


def _task_sort_key(task_folder: str) -> tuple:
    """Order task folders by numeric suffix so that task_2 precedes task_10."""
    name = os.path.basename(task_folder)
    suffix = name[len("task_"):]
    if suffix.isascii() and suffix.isdigit():
        return (0, int(suffix), name)
    # Folders without a purely numeric suffix go last, in name order.
    return (1, 0, name)


def _read_required_text(path: str, description: str) -> str:
    """Return the stripped UTF-8 contents of ``path``; fail clearly if it is missing."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        raise FileNotFoundError(f"{description} not found: {path}") from None


def collect_dataset() -> List[EvaluationRow]:
    """
    Iterate through the dataset folder and create EvaluationRow objects.

    For each directory named "task_<n>" inside the "dataset" folder next to
    this module, reads "task.txt" and "ground_truth.md" and creates an
    EvaluationRow where:
    - messages contains a single user message with the task content
    - ground_truth contains the contents of ground_truth.md

    Both files are read as UTF-8 and stripped of surrounding whitespace.

    Returns:
        One row per task directory, ordered by task number (numeric, not
        lexicographic). Empty if no task directories exist.

    Raises:
        FileNotFoundError: if a task directory lacks "task.txt" or
            "ground_truth.md".
    """
    dataset_path = os.path.join(os.path.dirname(__file__), "dataset")

    # The glob also matches plain files such as "task_notes.txt"; only
    # directories are tasks.
    task_folders = [
        path
        for path in glob.glob(os.path.join(dataset_path, "task_*"))
        if os.path.isdir(path)
    ]

    dataset_rows = []
    for task_folder in sorted(task_folders, key=_task_sort_key):
        task_content = _read_required_text(
            os.path.join(task_folder, "task.txt"), "Task file"
        )
        ground_truth_content = _read_required_text(
            os.path.join(task_folder, "ground_truth.md"), "Ground truth file"
        )

        # The task text becomes the sole user message of the row.
        user_message = Message(role="user", content=task_content)
        dataset_rows.append(
            EvaluationRow(messages=[user_message], ground_truth=ground_truth_content)
        )

    return dataset_rows
7 changes: 7 additions & 0 deletions tests/chinook/dataset/task_1/ground_truth.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
| customer_name | favorite_genre | total_invoices | total_spent | spending_rank |
| ------------------ | -------------- | -------------- | ----------- | ------------- |
| Helena Holý | Rock | 7 | 49.62 | 1 |
| Richard Cunningham | Rock | 7 | 47.62 | 2 |
| Luis Rojas | Rock | 7 | 46.62 | 3 |
| Ladislav Kovács | Rock | 7 | 45.62 | 4 |
| Hugh O'Reilly | Rock | 7 | 45.62 | 4 |
1 change: 1 addition & 0 deletions tests/chinook/dataset/task_1/task.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Find the top 5 customers by total spending, including their favorite genre. Show customer name, favorite genre, total invoices, total spent, and spending rank.
10 changes: 10 additions & 0 deletions tests/chinook/dataset/task_2/ground_truth.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
| name | level | hierarchy_path |
| ---------------- | ----- | -------------------------------------------------- |
| Andrew Adams | 0 | Andrew Adams |
| Michael Mitchell | 1 | Andrew Adams -> Michael Mitchell |
| Laura Callahan | 2 | Andrew Adams -> Michael Mitchell -> Laura Callahan |
| Robert King | 2 | Andrew Adams -> Michael Mitchell -> Robert King |
| Nancy Edwards | 1 | Andrew Adams -> Nancy Edwards |
| Jane Peacock | 2 | Andrew Adams -> Nancy Edwards -> Jane Peacock |
| Margaret Park | 2 | Andrew Adams -> Nancy Edwards -> Margaret Park |
| Steve Johnson | 2 | Andrew Adams -> Nancy Edwards -> Steve Johnson |
1 change: 1 addition & 0 deletions tests/chinook/dataset/task_2/task.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Find all employees and their reporting hierarchy levels using a recursive CTE. Show employee name, level, and the complete hierarchy path from top to bottom.
3 changes: 3 additions & 0 deletions tests/chinook/dataset/task_3/ground_truth.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
| artist_name | genre_count | longest_track_name_length | total_tracks | most_popular_track |
| ------------- | ----------- | ------------------------- | ------------ | ------------------ |
| Lenny Kravitz | 3 | 32 | 57 | Mr. Cab Driver |
1 change: 1 addition & 0 deletions tests/chinook/dataset/task_3/task.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Find artists who have albums in multiple genres and their most popular track. Include genre count, longest track name, and total tracks. Only include artists who have at least one album with 'Greatest' in the title.
86 changes: 86 additions & 0 deletions tests/chinook/dataset/task_4/ground_truth.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
| month | day_of_week | unique_customers | total_invoices | average_order_value | total_revenue | day_over_day_growth_percentage |
| ----- | ----------- | ---------------- | -------------- | ------------------- | ------------- | ------------------------------ |
| 1 | Friday | 2 | 2 | 5.45 | 10.89 | null |
| 1 | Monday | 7 | 7 | 6.22 | 43.57 | 300.09 |
| 1 | Saturday | 4 | 4 | 4.21 | 16.83 | -61.37 |
| 1 | Sunday | 4 | 4 | 3.47 | 13.86 | -17.65 |
| 1 | Thursday | 3 | 3 | 10.92 | 32.76 | 136.36 |
| 1 | Tuesday | 8 | 8 | 6.07 | 48.56 | 48.23 |
| 1 | Wednesday | 6 | 6 | 5.78 | 34.65 | -28.64 |
| 2 | Friday | 8 | 8 | 6.45 | 51.56 | null |
| 2 | Monday | 3 | 3 | 4.29 | 12.87 | -75.04 |
| 2 | Saturday | 5 | 5 | 6.34 | 31.69 | 146.23 |
| 2 | Sunday | 2 | 2 | 8.91 | 17.82 | -43.77 |
| 2 | Thursday | 7 | 7 | 6.08 | 42.57 | 138.89 |
| 2 | Tuesday | 4 | 4 | 4.21 | 16.83 | -60.47 |
| 2 | Wednesday | 4 | 4 | 3.47 | 13.86 | -17.65 |
| 3 | Friday | 4 | 4 | 5.21 | 20.83 | null |
| 3 | Monday | 8 | 8 | 5.70 | 45.56 | 118.72 |
| 3 | Saturday | 5 | 5 | 3.56 | 17.82 | -60.89 |
| 3 | Sunday | 8 | 8 | 6.06 | 48.51 | 172.22 |
| 3 | Thursday | 3 | 3 | 4.29 | 12.87 | -73.47 |
| 3 | Tuesday | 5 | 5 | 6.14 | 30.69 | 138.46 |
| 3 | Wednesday | 2 | 2 | 9.41 | 18.82 | -38.68 |
| 4 | Friday | 5 | 5 | 7.74 | 38.69 | null |
| 4 | Monday | 4 | 4 | 4.21 | 16.83 | -56.50 |
| 4 | Saturday | 2 | 2 | 8.91 | 17.82 | 5.88 |
| 4 | Sunday | 3 | 3 | 6.29 | 18.87 | 5.89 |
| 4 | Thursday | 6 | 6 | 6.60 | 39.60 | 109.86 |
| 4 | Tuesday | 5 | 5 | 3.56 | 17.82 | -55.00 |
| 4 | Wednesday | 8 | 8 | 6.06 | 48.51 | 172.22 |
| 5 | Friday | 5 | 5 | 3.56 | 17.82 | null |
| 5 | Monday | 5 | 5 | 7.14 | 35.69 | 100.28 |
| 5 | Saturday | 8 | 8 | 6.06 | 48.51 | 35.92 |
| 5 | Sunday | 6 | 6 | 6.60 | 39.60 | -18.37 |
| 5 | Thursday | 6 | 6 | 3.47 | 20.79 | -47.50 |
| 5 | Tuesday | 2 | 2 | 8.91 | 17.82 | -14.29 |
| 5 | Wednesday | 3 | 3 | 4.29 | 12.87 | -27.78 |
| 6 | Friday | 2 | 2 | 8.91 | 17.82 | null |
| 6 | Monday | 5 | 5 | 4.16 | 20.82 | 16.84 |
| 6 | Saturday | 3 | 3 | 4.29 | 12.87 | -38.18 |
| 6 | Sunday | 6 | 6 | 3.47 | 20.79 | 61.54 |
| 6 | Thursday | 5 | 5 | 6.54 | 32.69 | 57.24 |
| 6 | Tuesday | 8 | 8 | 6.69 | 53.51 | 63.69 |
| 6 | Wednesday | 6 | 6 | 7.10 | 42.60 | -20.39 |
| 7 | Friday | 8 | 8 | 6.06 | 48.51 | null |
| 7 | Monday | 2 | 2 | 8.91 | 17.82 | -63.27 |
| 7 | Saturday | 6 | 6 | 6.60 | 39.60 | 122.22 |
| 7 | Sunday | 5 | 5 | 6.14 | 30.69 | -22.50 |
| 7 | Thursday | 5 | 5 | 3.56 | 17.82 | -41.94 |
| 7 | Tuesday | 3 | 3 | 4.29 | 12.87 | -27.78 |
| 7 | Wednesday | 6 | 6 | 3.80 | 22.79 | 77.08 |
| 8 | Friday | 3 | 3 | 4.29 | 12.87 | null |
| 8 | Monday | 8 | 8 | 7.31 | 58.51 | 354.62 |
| 8 | Saturday | 6 | 6 | 3.47 | 20.79 | -64.47 |
| 8 | Sunday | 5 | 5 | 3.56 | 17.82 | -14.29 |
| 8 | Thursday | 2 | 2 | 8.91 | 17.82 | 0.00 |
| 8 | Tuesday | 6 | 6 | 6.60 | 39.60 | 122.22 |
| 8 | Wednesday | 5 | 5 | 6.14 | 30.69 | -22.50 |
| 9 | Friday | 6 | 6 | 7.43 | 44.60 | null |
| 9 | Monday | 3 | 3 | 4.29 | 12.87 | -71.14 |
| 9 | Saturday | 4 | 4 | 8.93 | 35.70 | 177.39 |
| 9 | Sunday | 2 | 2 | 8.91 | 17.82 | -50.08 |
| 9 | Thursday | 8 | 8 | 6.94 | 55.51 | 211.50 |
| 9 | Tuesday | 5 | 5 | 2.38 | 11.88 | -78.60 |
| 9 | Wednesday | 5 | 5 | 3.56 | 17.82 | 50.00 |
| 10 | Friday | 5 | 5 | 2.38 | 11.88 | null |
| 10 | Monday | 6 | 6 | 6.60 | 39.60 | 233.33 |
| 10 | Saturday | 6 | 6 | 3.14 | 18.81 | -52.50 |
| 10 | Sunday | 8 | 8 | 6.44 | 51.51 | 173.84 |
| 10 | Thursday | 3 | 3 | 4.29 | 12.87 | -75.01 |
| 10 | Tuesday | 5 | 5 | 8.12 | 40.61 | 215.54 |
| 10 | Wednesday | 2 | 2 | 8.91 | 17.82 | -56.12 |
| 11 | Friday | 4 | 4 | 6.19 | 24.75 | null |
| 11 | Monday | 5 | 5 | 2.38 | 11.88 | -52.00 |
| 11 | Saturday | 2 | 2 | 8.91 | 17.82 | 50.00 |
| 11 | Sunday | 3 | 3 | 4.29 | 12.87 | -27.78 |
| 11 | Thursday | 6 | 6 | 8.60 | 51.60 | 300.93 |
| 11 | Tuesday | 6 | 6 | 3.14 | 18.81 | -63.55 |
| 11 | Wednesday | 8 | 8 | 6.06 | 48.51 | 157.89 |
| 12 | Friday | 7 | 7 | 4.67 | 32.67 | null |
| 12 | Monday | 4 | 4 | 6.44 | 25.75 | -21.18 |
| 12 | Saturday | 8 | 8 | 6.06 | 48.51 | 88.39 |
| 12 | Sunday | 6 | 6 | 6.60 | 39.60 | -18.37 |
| 12 | Thursday | 5 | 5 | 2.38 | 11.88 | -70.00 |
| 12 | Tuesday | 2 | 2 | 8.91 | 17.82 | 50.00 |
| 12 | Wednesday | 3 | 3 | 4.29 | 12.87 | -27.78 |
1 change: 1 addition & 0 deletions tests/chinook/dataset/task_4/task.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Analyze customer purchasing patterns by month and day of week. Show month, day of week, unique customers, total invoices, average order value, total revenue, and day-over-day growth percentage for invoices from 2010 onwards.
Loading
Loading