
Commit 579d048

Authored by xzrderek and Dylan Huang

Arena hard auto inspired quick start (#170)

* Arena hard auto inspired quick start
* working for my own chinook trace, changing adapter now
* Responses api example (part 1) (#172)
* pass through system message properly
* save test_pydantic_complex_queries_responses
* finished

Co-authored-by: Dylan Huang <dhuang@fireworks.ai>

1 parent 8101180 · commit 579d048

File tree

9 files changed: +526 −354 lines


eval_protocol/adapters/langfuse.py

Lines changed: 76 additions & 338 deletions (large diff not rendered by default)
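The adapter's diff isn't rendered here, but its public entry points appear in the quickstart file further down. A minimal usage sketch based only on the calls shown there (any parameter semantics beyond that are assumptions):

```python
from eval_protocol.adapters.langfuse import create_langfuse_adapter

# Fetch recent Langfuse traces as EvaluationRow objects; kwargs mirror
# the quickstart file's fetch helper below.
adapter = create_langfuse_adapter()
rows = adapter.get_evaluation_rows(limit=10, hours_back=24)
print(f"Fetched {len(rows)} evaluation rows from Langfuse")
```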

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -30,7 +30,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
     if len(row.messages) == 0:
         raise ValueError("Messages is empty. Please provide a non-empty dataset")
 
-    messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
+    messages_payload = [message.model_dump() for message in row.messages]
 
     request_params = {"messages": messages_payload, **config.completion_params}
     # Ensure caching is disabled only for this request (review feedback)
```
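The switch to `model_dump()` matters because the hand-built dict kept only `role` and `content`. A minimal sketch of the difference, assuming a Pydantic v2-style message model (the `tool_calls` field here is illustrative, not taken from this commit):

```python
from typing import Optional
from pydantic import BaseModel


class Message(BaseModel):
    """Simplified stand-in for eval_protocol's message model (hypothetical)."""

    role: str
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None  # illustrative extra field


msg = Message(role="assistant", content=None, tool_calls=[{"id": "call_1", "type": "function"}])

# Old payload: hand-picked fields, everything else silently dropped.
old_payload = {"role": msg.role, "content": msg.content}

# New payload: Pydantic v2's model_dump() serializes every field.
new_payload = msg.model_dump()

assert "tool_calls" not in old_payload
assert new_payload["tool_calls"] == [{"id": "call_1", "type": "function"}]
```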

eval_protocol/pytest/evaluation_test.py

Lines changed: 10 additions & 5 deletions

```diff
@@ -84,6 +84,7 @@ def evaluation_test(
     steps: int = 30,
     mode: EvaluationTestMode = "pointwise",
     combine_datasets: bool = True,
+    preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None = None,
     logger: DatasetLogger | None = None,
     exception_handler_config: ExceptionHandlerConfig | None = None,
 ) -> Callable[[TestFunction], TestFunction]:
@@ -150,6 +151,9 @@ def evaluation_test(
         mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result).
             "groupwise" applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo).
             "all" applies test function to the whole dataset.
+        preprocess_fn: Optional preprocessing function that takes a list of EvaluationRow objects
+            and returns a modified list. Useful for transformations like splitting multi-turn conversations,
+            filtering data, or other preprocessing steps before rollout execution.
         logger: DatasetLogger to use for logging. If not provided, a default logger will be used.
         exception_handler_config: Configuration for exception handling and backoff retry logic.
             If not provided, a default configuration will be used with common retryable exceptions.
@@ -244,6 +248,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
     else:
         raise ValueError("No input dataset, input messages, or input rows provided")
 
+    if preprocess_fn:
+        data = preprocess_fn(data)
+
     for row in data:
         # generate a stable row_id for each row
         if row.input_metadata.row_id is None:
@@ -266,11 +273,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
                 passed=None,
             )
         for row in data:
-            # Only set completion_params if they don't already exist
-            if not row.input_metadata.completion_params:
-                row.input_metadata.completion_params = (
-                    completion_params if completion_params is not None else {}
-                )
+            row.input_metadata.completion_params = (
+                completion_params if completion_params is not None else {}
+            )
             # Add mode to session_data
             if row.input_metadata.session_data is None:
                 row.input_metadata.session_data = {}
```
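The new `preprocess_fn` hook runs on the assembled dataset before rollout execution. A minimal sketch of a custom hook matching the documented contract (the filter itself is hypothetical, not part of this commit):

```python
from eval_protocol.models import EvaluationRow


def drop_rows_without_user_turn(rows: list[EvaluationRow]) -> list[EvaluationRow]:
    """Hypothetical preprocess_fn: keep only rows with at least one user message."""
    return [row for row in rows if any(m.role == "user" for m in row.messages)]


# Wired up via the decorator, alongside the other options:
# @evaluation_test(..., preprocess_fn=drop_rows_without_user_turn, ...)
```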

eval_protocol/pytest/evaluation_test_postprocess.py

Lines changed: 11 additions & 9 deletions

```diff
@@ -62,15 +62,17 @@ def postprocess(
     passed = success_passed and standard_error_passed
 
     # Update eval metadata passed field for all results
-    for result in all_results:
-        for r in result:
-            if r.eval_metadata is not None:
-                r.eval_metadata.passed = passed
-            if r.evaluation_result is not None:
-                r.evaluation_result.agg_score = agg_score
-                r.evaluation_result.standard_error = standard_error
-            r.execution_metadata.experiment_duration_seconds = experiment_duration_seconds
-            active_logger.log(r)
+    for results in all_results:
+        for result in results:
+            if result.eval_metadata is not None:
+                result.eval_metadata.passed = passed
+            if result.evaluation_result is not None:
+                if result.evaluation_result.agg_score is None:
+                    result.evaluation_result.agg_score = agg_score
+                if result.evaluation_result.standard_error is None:
+                    result.evaluation_result.standard_error = standard_error
+            result.execution_metadata.experiment_duration_seconds = experiment_duration_seconds
+            active_logger.log(result)
 
     # Optional: print and/or persist a summary artifact for CI
     try:
```
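Renamed loop variables aside, the substantive change is the pair of `is None` guards: experiment-wide aggregates now only fill in values the test function left unset. A small self-contained illustration of that behavior, using stand-in types rather than the library's own:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Result:
    """Stand-in for the evaluation_result object (hypothetical)."""

    agg_score: Optional[float] = None
    standard_error: Optional[float] = None


def fill_defaults(r: Result, agg_score: float, standard_error: float) -> None:
    # Mirrors the new postprocess behavior: fill, never overwrite.
    if r.agg_score is None:
        r.agg_score = agg_score
    if r.standard_error is None:
        r.standard_error = standard_error


judge_row = Result(agg_score=0.73)       # set by the test function itself
fill_defaults(judge_row, agg_score=0.55, standard_error=0.02)
assert judge_row.agg_score == 0.73       # preserved, no longer overwritten
assert judge_row.standard_error == 0.02  # filled from the experiment default
```

Presumably this is what lets the quickstart judge below write its own bootstrap-derived `standard_error` per row without postprocess clobbering it.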
Lines changed: 221 additions & 0 deletions (new file)

```python
"""
Default LLM judge for Eval Protocol. Inspired by Arena-Hard-Auto.
"""

import os
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
import pandas as pd
from tqdm import tqdm

import pytest

from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
from eval_protocol.quickstart.utils import pairwise_judgment, split_multi_turn_rows, serialize_message
from eval_protocol.adapters.langfuse import create_langfuse_adapter

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

JUDGE_CONFIGS = {
    "gpt-4.1": {
        "model": "gpt-4.1",
        "temperature": 0.0,
        "max_tokens": 16000,
        "max_concurrency": 64,
    },
    "gemini-2.5-pro": {
        "model": "gemini-2.5-pro",
        "temperature": 1.0,
        "max_tokens": 32000,
        "api_key": os.getenv("GEMINI_API_KEY"),
        "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
        "max_concurrency": 32,
    },
}


def fetch_langfuse_traces_as_evaluation_rows(
    limit: int = 100,
    tags: Optional[List[str]] = None,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
    hours_back: Optional[int] = None,
    include_tool_calls: bool = True,
) -> List[EvaluationRow]:
    try:
        adapter = create_langfuse_adapter()

        return adapter.get_evaluation_rows(
            limit=limit,
            tags=tags,
            user_id=user_id,
            session_id=session_id,
            hours_back=hours_back,
            include_tool_calls=include_tool_calls,
        )

    except Exception as e:
        print(f"❌ LangfuseAdapter failed: {e}")
        return []


@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
@pytest.mark.asyncio
@evaluation_test(
    input_rows=[fetch_langfuse_traces_as_evaluation_rows()],
    completion_params=[
        {"model": "gpt-5"},
        {
            # "max_tokens": 131000,
            # "extra_body": {"reasoning_effort": "low"},
            "model": "fireworks_ai/accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
        },
    ],
    rollout_processor=SingleTurnRolloutProcessor(),
    preprocess_fn=split_multi_turn_rows,
    mode="all",
)
async def test_llm_judge(rows: list[EvaluationRow]) -> list[EvaluationRow]:
    """
    Simplified LLM Judge for Arena-Hard-Auto style pairwise comparisons.

    Each row contains:
    - messages[:-1]: Question/prompt (conversation context)
    - messages[-1]: Model B's answer (comparison model response)
    - ground_truth: Model A's answer (original assistant response)
    """

    judge_name = "gemini-2.5-pro"  # Edit to choose which judge to use; configs at top of file.

    if not rows:
        print("❌ No evaluation rows provided")
        return rows

    print(f"🔄 Processing {len(rows)} evaluation rows for LLM judging...")

    model_name = rows[0].input_metadata.completion_params.get("model", "unknown_model")

    def run_judgment(row: EvaluationRow) -> Optional[Dict[str, Any]]:
        """Run pairwise judgment for a single evaluation row."""
        if not row.messages:
            return None

        question_text = "\n".join([serialize_message(msg) for msg in row.messages[:-1]])
        model_a_answer = row.ground_truth
        model_b_answer = serialize_message(row.messages[-1])

        games = []

        # Round 1: A vs B (original vs comparison)
        result1 = pairwise_judgment(
            question_text=question_text,
            answer_a=model_a_answer,
            answer_b=model_b_answer,
            tools=row.tools,
            judge_config=JUDGE_CONFIGS[judge_name],
        )
        games.append(result1)

        # Round 2: B vs A (comparison vs original)
        result2 = pairwise_judgment(
            question_text=question_text,
            answer_a=model_b_answer,
            answer_b=model_a_answer,
            tools=row.tools,
            judge_config=JUDGE_CONFIGS[judge_name],
        )
        games.append(result2)

        row.evaluation_result = EvaluateResult(
            score=0.0,
            reason=f"LLM Judge comparison: Round 1: {result1['score']}, Round 2: {result2['score']}"
            if result1 and result2
            else "Failed to get judgment scores",
            metrics={
                "round1_judgment": MetricResult(
                    score=0.0, reason=result1["judgment"] if result1 else "Failed to get judgment reason"
                ),
                "round2_judgment": MetricResult(
                    score=0.0, reason=result2["judgment"] if result2 else "Failed to get judgment reason"
                ),
            },
        )

        return {"model": model_name, "games": games}

    judgments = []
    max_concurrency = JUDGE_CONFIGS[judge_name]["max_concurrency"]

    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        futures = [executor.submit(run_judgment, row) for row in rows]

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Generating judgments"):
            result = future.result()
            if result and result["games"][0] and result["games"][1]:
                judgments.append(result)

    if not judgments:
        print("❌ No valid judgments generated")
        return rows

    print(f"✅ Generated {len(judgments)} valid judgments")

    # Convert to scores for leaderboard
    label_to_score = {
        "A>B": [1],
        "A>>B": [1] * 3,
        "A=B": [0.5],
        "A<<B": [0] * 3,
        "A<B": [0],
        "B>A": [0],
        "B>>A": [0] * 3,
        "B=A": [0.5],
        "B<<A": [1] * 3,
        "B<A": [1],
    }

    # Extract scores from judgments
    scores_data = []
    for judgment in judgments:
        game1, game2 = judgment["games"]
        if game1 and game2 and game1.get("score") and game2.get("score"):
            # Convert judgment labels to numerical scores
            scores = label_to_score[game2["score"]] + [1 - s for s in label_to_score[game1["score"]]]
            for score in scores:
                scores_data.append(score)

    if not scores_data:
        print("❌ No valid scores extracted")
        return rows

    # Create DataFrame (single column of scores)
    battles = pd.DataFrame({"score": scores_data})

    # Bootstrap sampling for calculating performance relative to the original model, fixed at 50%
    bootstrap_means = [battles.sample(frac=1.0, replace=True)["score"].mean() for _ in range(100)]

    # Calculate final scores
    bootstraps = pd.Series(bootstrap_means)
    mean_score = bootstraps.mean()
    lower_score = bootstraps.quantile(0.05)
    upper_score = bootstraps.quantile(0.95)

    # Print leaderboard
    print("\n##### LLM Judge Results (90% CI) #####")

    clean_model_name = model_name.split("/")[-1]  # Clean model name

    print(f"{clean_model_name}: {mean_score:.1%} (CI: {lower_score:.1%} - {upper_score:.1%})")
    print("original: 50.0% (CI: 50.0% - 50.0%)")

    for row in rows:
        if row.evaluation_result:
            row.evaluation_result.score = mean_score
            row.evaluation_result.standard_error = (upper_score - lower_score) / (
                2 * 1.645
            )  # Standard error approximation from 90% CI

    return rows
```
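Two bits of the scoring math above are easy to misread. First, because Round 2 swaps the answer positions, Round 1 labels must be inverted before pooling; a hypothetical worked example using labels and values from the table above:

```python
# Subset of the label table above; decisive labels (>>) count triple.
label_to_score = {"A>B": [1], "A>>B": [1] * 3, "A=B": [0.5], "A<<B": [0] * 3, "A<B": [0]}

game1 = "A<B"  # Round 1: A = original, B = comparison -> comparison wins
game2 = "A>B"  # Round 2: A = comparison, B = original -> comparison wins again

# Round 2 counts directly; Round 1 is inverted because positions were swapped.
scores = label_to_score[game2] + [1 - s for s in label_to_score[game1]]
assert scores == [1, 1]  # two wins recorded for the comparison model
```

Second, the standard-error line relies on a normal approximation: a 90% bootstrap interval spans roughly 2 × 1.645 standard errors, so SE ≈ (upper − lower) / 3.29. With bounds of, say, 48% and 62%, that gives an SE of about 4.3%.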
