Skip to content

Commit 5f8eea1

Browse files
author
Dylan Huang
authored
more complex example for chinook (#124)
* more complex example * use kimi k2 for llm as a judge * add more tasks * improve LLM judge prompt * improve LLM judge prompt * add experiment/run/row id to table * allow evaluationtable to be sorted * vite build * vite build * reduce time for complex queries example * increase number of result retries * add task 6 * remove duplicate line * improve system prompt * fix more or less hooks before previous render * vite build * revert * support anthropic and google * skip in CI * sort by aggregate score * vite build
1 parent fcff843 commit 5f8eea1

29 files changed

+868
-148
lines changed

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
from eval_protocol.pytest.types import RolloutProcessorConfig
1212
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
1313
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
14-
14+
from pydantic_ai.models.anthropic import AnthropicModel
1515
from pydantic_ai.models.openai import OpenAIModel
16+
from pydantic_ai.models.google import GoogleModel
1617
from pydantic import TypeAdapter
1718
from pydantic_ai.messages import ModelMessage
1819
from pydantic_ai._utils import generate_tool_call_id
@@ -61,10 +62,19 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
6162
)
6263
kwargs = {}
6364
for k, v in config.completion_params["model"].items():
64-
kwargs[k] = OpenAIModel(
65-
v["model"],
66-
provider=v["provider"],
67-
)
65+
if v["model"] and v["model"].startswith("anthropic:"):
66+
kwargs[k] = AnthropicModel(
67+
v["model"].removeprefix("anthropic:"),
68+
)
69+
elif v["model"] and v["model"].startswith("google:"):
70+
kwargs[k] = GoogleModel(
71+
v["model"].removeprefix("google:"),
72+
)
73+
else:
74+
kwargs[k] = OpenAIModel(
75+
v["model"],
76+
provider=v["provider"],
77+
)
6878
agent = setup_agent(**kwargs)
6979
model = None
7080
else:

tests/chinook/agent.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ def setup_agent(orchestrator_agent_model: Model):
1212
introspection_result_str = "\n".join([",".join(map(str, item)) for item in introspection_result])
1313

1414
SYSTEM_PROMPT = f"""You are a helpful assistant that has access to the
15-
Chinook database. You have access to a tool to execute SQL queries. Your job
16-
is to answer questions about the database. Here is the schema of the database:
15+
Chinook database stored in a Postgres database. You have access to a tool to
16+
execute SQL queries that you should use to answer questions. Your job is to
17+
answer questions about the database. If you run into an error, you should try to
18+
fix the query and try again. Here is the schema of the database:
1719
1820
Schema:
1921
table_name,column_name,data_type,is_nullable
@@ -26,10 +28,33 @@ def setup_agent(orchestrator_agent_model: Model):
2628
)
2729

2830
@agent.tool(retries=5)
29-
def execute_sql(ctx: RunContext, query: str) -> tuple[any, ...]:
31+
def execute_sql(ctx: RunContext, query: str) -> dict:
3032
try:
3133
cursor.execute(query)
32-
return cursor.fetchall()
34+
# Get column headers from cursor description
35+
columns = [desc[0] for desc in cursor.description] if cursor.description else []
36+
# Get data rows
37+
rows = cursor.fetchall()
38+
39+
if not columns or not rows:
40+
return "No results found."
41+
42+
# Create markdown table
43+
table_lines = []
44+
45+
# Header row
46+
table_lines.append("| " + " | ".join(columns) + " |")
47+
48+
# Separator row
49+
table_lines.append("| " + " | ".join(["---"] * len(columns)) + " |")
50+
51+
# Data rows
52+
for row in rows:
53+
# Convert all values to strings and escape pipes
54+
formatted_row = [str(cell).replace("|", "\\|") if cell is not None else "" for cell in row]
55+
table_lines.append("| " + " | ".join(formatted_row) + " |")
56+
57+
return "\n".join(table_lines)
3358
except Exception as e:
3459
connection.rollback()
3560
raise ModelRetry("Please try again with a different query. Here is the error: " + str(e))

tests/chinook/dataset.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import List
2+
import os
3+
import glob
4+
5+
from eval_protocol.models import EvaluationRow, Message
6+
7+
8+
def collect_dataset() -> List[EvaluationRow]:
9+
"""
10+
Iterate through the dataset folder and create EvaluationRow objects.
11+
12+
For each folder named "task_<n>", reads "task.txt" and "ground_truth.md"
13+
and creates an EvaluationRow where:
14+
- messages contains a user message with the task content
15+
- ground_truth contains the contents of ground_truth.md
16+
"""
17+
dataset_rows = []
18+
dataset_path = os.path.join(os.path.dirname(__file__), "dataset")
19+
20+
# Find all task folders (task_<n>)
21+
task_folders = glob.glob(os.path.join(dataset_path, "task_*"))
22+
23+
for task_folder in sorted(task_folders):
24+
task_name = os.path.basename(task_folder)
25+
26+
# Read task.txt
27+
task_file = os.path.join(task_folder, "task.txt")
28+
if not os.path.exists(task_file):
29+
raise FileNotFoundError(f"Task file not found: {task_file}")
30+
31+
with open(task_file, "r", encoding="utf-8") as f:
32+
task_content = f.read().strip()
33+
34+
# Read ground_truth.md
35+
ground_truth_file = os.path.join(task_folder, "ground_truth.md")
36+
if not os.path.exists(ground_truth_file):
37+
raise FileNotFoundError(f"Ground truth file not found: {ground_truth_file}")
38+
39+
with open(ground_truth_file, "r", encoding="utf-8") as f:
40+
ground_truth_content = f.read().strip()
41+
42+
# Create user message with the task
43+
user_message = Message(role="user", content=task_content)
44+
45+
# Create EvaluationRow
46+
evaluation_row = EvaluationRow(messages=[user_message], ground_truth=ground_truth_content)
47+
48+
dataset_rows.append(evaluation_row)
49+
50+
return dataset_rows
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
| customer_name | favorite_genre | total_invoices | total_spent | spending_rank |
2+
| ------------------ | -------------- | -------------- | ----------- | ------------- |
3+
| Helena Holý | Rock | 7 | 49.62 | 1 |
4+
| Richard Cunningham | Rock | 7 | 47.62 | 2 |
5+
| Luis Rojas | Rock | 7 | 46.62 | 3 |
6+
| Ladislav Kovács | Rock | 7 | 45.62 | 4 |
7+
| Hugh O'Reilly | Rock | 7 | 45.62 | 4 |
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Find the top 5 customers by total spending, including their favorite genre. Show customer name, favorite genre, total invoices, total spent, and spending rank.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
| name | level | hierarchy_path |
2+
| ---------------- | ----- | -------------------------------------------------- |
3+
| Andrew Adams | 0 | Andrew Adams |
4+
| Michael Mitchell | 1 | Andrew Adams -> Michael Mitchell |
5+
| Laura Callahan | 2 | Andrew Adams -> Michael Mitchell -> Laura Callahan |
6+
| Robert King | 2 | Andrew Adams -> Michael Mitchell -> Robert King |
7+
| Nancy Edwards | 1 | Andrew Adams -> Nancy Edwards |
8+
| Jane Peacock | 2 | Andrew Adams -> Nancy Edwards -> Jane Peacock |
9+
| Margaret Park | 2 | Andrew Adams -> Nancy Edwards -> Margaret Park |
10+
| Steve Johnson | 2 | Andrew Adams -> Nancy Edwards -> Steve Johnson |
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Find all employees and their reporting hierarchy levels using a recursive CTE. Show employee name, level, and the complete hierarchy path from top to bottom.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
| artist_name | genre_count | longest_track_name_length | total_tracks | most_popular_track |
2+
| ------------- | ----------- | ------------------------- | ------------ | ------------------ |
3+
| Lenny Kravitz | 3 | 32 | 57 | Mr. Cab Driver |
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Find artists who have albums in multiple genres and their most popular track. Include genre count, longest track name, and total tracks. Only include artists who have at least one album with 'Greatest' in the title.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
| month | day_of_week | unique_customers | total_invoices | average_order_value | total_revenue | day_over_day_growth_percentage |
2+
| ----- | ----------- | ---------------- | -------------- | ------------------- | ------------- | ------------------------------ |
3+
| 1 | Friday | 2 | 2 | 5.45 | 10.89 | null |
4+
| 1 | Monday | 7 | 7 | 6.22 | 43.57 | 300.09 |
5+
| 1 | Saturday | 4 | 4 | 4.21 | 16.83 | -61.37 |
6+
| 1 | Sunday | 4 | 4 | 3.47 | 13.86 | -17.65 |
7+
| 1 | Thursday | 3 | 3 | 10.92 | 32.76 | 136.36 |
8+
| 1 | Tuesday | 8 | 8 | 6.07 | 48.56 | 48.23 |
9+
| 1 | Wednesday | 6 | 6 | 5.78 | 34.65 | -28.64 |
10+
| 2 | Friday | 8 | 8 | 6.45 | 51.56 | null |
11+
| 2 | Monday | 3 | 3 | 4.29 | 12.87 | -75.04 |
12+
| 2 | Saturday | 5 | 5 | 6.34 | 31.69 | 146.23 |
13+
| 2 | Sunday | 2 | 2 | 8.91 | 17.82 | -43.77 |
14+
| 2 | Thursday | 7 | 7 | 6.08 | 42.57 | 138.89 |
15+
| 2 | Tuesday | 4 | 4 | 4.21 | 16.83 | -60.47 |
16+
| 2 | Wednesday | 4 | 4 | 3.47 | 13.86 | -17.65 |
17+
| 3 | Friday | 4 | 4 | 5.21 | 20.83 | null |
18+
| 3 | Monday | 8 | 8 | 5.70 | 45.56 | 118.72 |
19+
| 3 | Saturday | 5 | 5 | 3.56 | 17.82 | -60.89 |
20+
| 3 | Sunday | 8 | 8 | 6.06 | 48.51 | 172.22 |
21+
| 3 | Thursday | 3 | 3 | 4.29 | 12.87 | -73.47 |
22+
| 3 | Tuesday | 5 | 5 | 6.14 | 30.69 | 138.46 |
23+
| 3 | Wednesday | 2 | 2 | 9.41 | 18.82 | -38.68 |
24+
| 4 | Friday | 5 | 5 | 7.74 | 38.69 | null |
25+
| 4 | Monday | 4 | 4 | 4.21 | 16.83 | -56.50 |
26+
| 4 | Saturday | 2 | 2 | 8.91 | 17.82 | 5.88 |
27+
| 4 | Sunday | 3 | 3 | 6.29 | 18.87 | 5.89 |
28+
| 4 | Thursday | 6 | 6 | 6.60 | 39.60 | 109.86 |
29+
| 4 | Tuesday | 5 | 5 | 3.56 | 17.82 | -55.00 |
30+
| 4 | Wednesday | 8 | 8 | 6.06 | 48.51 | 172.22 |
31+
| 5 | Friday | 5 | 5 | 3.56 | 17.82 | null |
32+
| 5 | Monday | 5 | 5 | 7.14 | 35.69 | 100.28 |
33+
| 5 | Saturday | 8 | 8 | 6.06 | 48.51 | 35.92 |
34+
| 5 | Sunday | 6 | 6 | 6.60 | 39.60 | -18.37 |
35+
| 5 | Thursday | 6 | 6 | 3.47 | 20.79 | -47.50 |
36+
| 5 | Tuesday | 2 | 2 | 8.91 | 17.82 | -14.29 |
37+
| 5 | Wednesday | 3 | 3 | 4.29 | 12.87 | -27.78 |
38+
| 6 | Friday | 2 | 2 | 8.91 | 17.82 | null |
39+
| 6 | Monday | 5 | 5 | 4.16 | 20.82 | 16.84 |
40+
| 6 | Saturday | 3 | 3 | 4.29 | 12.87 | -38.18 |
41+
| 6 | Sunday | 6 | 6 | 3.47 | 20.79 | 61.54 |
42+
| 6 | Thursday | 5 | 5 | 6.54 | 32.69 | 57.24 |
43+
| 6 | Tuesday | 8 | 8 | 6.69 | 53.51 | 63.69 |
44+
| 6 | Wednesday | 6 | 6 | 7.10 | 42.60 | -20.39 |
45+
| 7 | Friday | 8 | 8 | 6.06 | 48.51 | null |
46+
| 7 | Monday | 2 | 2 | 8.91 | 17.82 | -63.27 |
47+
| 7 | Saturday | 6 | 6 | 6.60 | 39.60 | 122.22 |
48+
| 7 | Sunday | 5 | 5 | 6.14 | 30.69 | -22.50 |
49+
| 7 | Thursday | 5 | 5 | 3.56 | 17.82 | -41.94 |
50+
| 7 | Tuesday | 3 | 3 | 4.29 | 12.87 | -27.78 |
51+
| 7 | Wednesday | 6 | 6 | 3.80 | 22.79 | 77.08 |
52+
| 8 | Friday | 3 | 3 | 4.29 | 12.87 | null |
53+
| 8 | Monday | 8 | 8 | 7.31 | 58.51 | 354.62 |
54+
| 8 | Saturday | 6 | 6 | 3.47 | 20.79 | -64.47 |
55+
| 8 | Sunday | 5 | 5 | 3.56 | 17.82 | -14.29 |
56+
| 8 | Thursday | 2 | 2 | 8.91 | 17.82 | 0.00 |
57+
| 8 | Tuesday | 6 | 6 | 6.60 | 39.60 | 122.22 |
58+
| 8 | Wednesday | 5 | 5 | 6.14 | 30.69 | -22.50 |
59+
| 9 | Friday | 6 | 6 | 7.43 | 44.60 | null |
60+
| 9 | Monday | 3 | 3 | 4.29 | 12.87 | -71.14 |
61+
| 9 | Saturday | 4 | 4 | 8.93 | 35.70 | 177.39 |
62+
| 9 | Sunday | 2 | 2 | 8.91 | 17.82 | -50.08 |
63+
| 9 | Thursday | 8 | 8 | 6.94 | 55.51 | 211.50 |
64+
| 9 | Tuesday | 5 | 5 | 2.38 | 11.88 | -78.60 |
65+
| 9 | Wednesday | 5 | 5 | 3.56 | 17.82 | 50.00 |
66+
| 10 | Friday | 5 | 5 | 2.38 | 11.88 | null |
67+
| 10 | Monday | 6 | 6 | 6.60 | 39.60 | 233.33 |
68+
| 10 | Saturday | 6 | 6 | 3.14 | 18.81 | -52.50 |
69+
| 10 | Sunday | 8 | 8 | 6.44 | 51.51 | 173.84 |
70+
| 10 | Thursday | 3 | 3 | 4.29 | 12.87 | -75.01 |
71+
| 10 | Tuesday | 5 | 5 | 8.12 | 40.61 | 215.54 |
72+
| 10 | Wednesday | 2 | 2 | 8.91 | 17.82 | -56.12 |
73+
| 11 | Friday | 4 | 4 | 6.19 | 24.75 | null |
74+
| 11 | Monday | 5 | 5 | 2.38 | 11.88 | -52.00 |
75+
| 11 | Saturday | 2 | 2 | 8.91 | 17.82 | 50.00 |
76+
| 11 | Sunday | 3 | 3 | 4.29 | 12.87 | -27.78 |
77+
| 11 | Thursday | 6 | 6 | 8.60 | 51.60 | 300.93 |
78+
| 11 | Tuesday | 6 | 6 | 3.14 | 18.81 | -63.55 |
79+
| 11 | Wednesday | 8 | 8 | 6.06 | 48.51 | 157.89 |
80+
| 12 | Friday | 7 | 7 | 4.67 | 32.67 | null |
81+
| 12 | Monday | 4 | 4 | 6.44 | 25.75 | -21.18 |
82+
| 12 | Saturday | 8 | 8 | 6.06 | 48.51 | 88.39 |
83+
| 12 | Sunday | 6 | 6 | 6.60 | 39.60 | -18.37 |
84+
| 12 | Thursday | 5 | 5 | 2.38 | 11.88 | -70.00 |
85+
| 12 | Tuesday | 2 | 2 | 8.91 | 17.82 | 50.00 |
86+
| 12 | Wednesday | 3 | 3 | 4.29 | 12.87 | -27.78 |

0 commit comments

Comments
 (0)