Skip to content

Commit da4023d

Browse files
committed
Arena hard auto inspired quick start
1 parent c8918fd commit da4023d

File tree

7 files changed

+353
-33
lines changed

7 files changed

+353
-33
lines changed

eval_protocol/adapters/langfuse.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -155,24 +155,12 @@ def _convert_trace_to_evaluation_row(self, trace: Any, include_tool_calls: bool
155155
observations_response.data if hasattr(observations_response, "data") else list(observations_response)
156156
)
157157

158-
# Look for conversation history in trace output or observations
159158
messages = []
160-
conversation_found = False
161-
162-
# Look for complete conversation in observations
163-
if not conversation_found:
164-
for obs in observations:
165-
# Check each observation's output for complete conversation array
166-
if hasattr(obs, "output") and obs.output:
167-
conversation = self._extract_conversation_from_output(obs.output)
168-
if conversation:
169-
messages = conversation
170-
conversation_found = True
171-
break
172-
173-
# Fallback: try extracting from observations using old method
174-
if not conversation_found:
175-
messages = self._extract_messages_from_observations(observations, include_tool_calls)
159+
160+
for obs in observations:
161+
if obs.name == "agent run":
162+
messages = self._extract_conversation_from_output(obs.output)
163+
break
176164

177165
if not messages:
178166
return None
@@ -359,10 +347,16 @@ def _extract_conversation_from_output(self, output: Any) -> Optional[List[Messag
359347

360348
# Handle tool responses
361349
name = None
350+
tool_call_id = None
362351
if role == "tool":
363352
name = msg_data.get("name")
353+
tool_call_id = msg_data.get("id")
364354

365-
messages.append(Message(role=role, content=content, name=name, tool_calls=tool_calls))
355+
messages.append(
356+
Message(
357+
role=role, content=content, name=name, tool_calls=tool_calls, tool_call_id=tool_call_id
358+
)
359+
)
366360

367361
return messages if messages else None
368362

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
3030
if len(row.messages) == 0:
3131
raise ValueError("Messages is empty. Please provide a non-empty dataset")
3232

33-
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
33+
messages_payload = [message.model_dump() for message in row.messages]
3434

3535
request_params = {"messages": messages_payload, **config.completion_params}
3636
# Ensure caching is disabled only for this request (review feedback)

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
parse_ep_num_runs,
5959
parse_ep_passed_threshold,
6060
rollout_processor_with_retry,
61+
split_multi_turn_rows,
6162
)
6263

6364
from ..common_utils import load_jsonl
@@ -84,6 +85,7 @@ def evaluation_test(
8485
steps: int = 30,
8586
mode: EvaluationTestMode = "pointwise",
8687
combine_datasets: bool = True,
88+
split_multi_turn: bool = False,
8789
logger: DatasetLogger | None = None,
8890
exception_handler_config: ExceptionHandlerConfig | None = None,
8991
) -> Callable[[TestFunction], TestFunction]:
@@ -150,6 +152,9 @@ def evaluation_test(
150152
mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result).
151153
"groupwise" applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo).
152154
"all" applies test function to the whole dataset.
155+
split_multi_turn: If True, splits multi-turn conversations into individual evaluation rows
156+
for each assistant response. Each row will contain the conversation context up to that point
157+
and the assistant's response as ground truth. Useful for Arena-Hard-Auto style evaluations.
153158
logger: DatasetLogger to use for logging. If not provided, a default logger will be used.
154159
exception_handler_config: Configuration for exception handling and backoff retry logic.
155160
If not provided, a default configuration will be used with common retryable exceptions.
@@ -244,6 +249,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
244249
else:
245250
raise ValueError("No input dataset, input messages, or input rows provided")
246251

252+
if split_multi_turn:
253+
data = split_multi_turn_rows(data)
254+
247255
for row in data:
248256
# generate a stable row_id for each row
249257
if row.input_metadata.row_id is None:
@@ -266,11 +274,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
266274
passed=None,
267275
)
268276
for row in data:
269-
# Only set completion_params if they don't already exist
270-
if not row.input_metadata.completion_params:
271-
row.input_metadata.completion_params = (
272-
completion_params if completion_params is not None else {}
273-
)
277+
row.input_metadata.completion_params = (
278+
completion_params if completion_params is not None else {}
279+
)
274280
# Add mode to session_data
275281
if row.input_metadata.session_data is None:
276282
row.input_metadata.session_data = {}

eval_protocol/pytest/evaluation_test_postprocess.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,17 @@ def postprocess(
6262
passed = success_passed and standard_error_passed
6363

6464
# Update eval metadata passed field for all results
65-
for result in all_results:
66-
for r in result:
67-
if r.eval_metadata is not None:
68-
r.eval_metadata.passed = passed
69-
if r.evaluation_result is not None:
70-
r.evaluation_result.agg_score = agg_score
71-
r.evaluation_result.standard_error = standard_error
72-
r.execution_metadata.experiment_duration_seconds = experiment_duration_seconds
73-
active_logger.log(r)
65+
for results in all_results:
66+
for result in results:
67+
if result.eval_metadata is not None:
68+
result.eval_metadata.passed = passed
69+
if result.evaluation_result is not None:
70+
if result.evaluation_result.agg_score is None:
71+
result.evaluation_result.agg_score = agg_score
72+
if result.evaluation_result.standard_error is None:
73+
result.evaluation_result.standard_error = standard_error
74+
result.execution_metadata.experiment_duration_seconds = experiment_duration_seconds
75+
active_logger.log(result)
7476

7577
# Optional: print and/or persist a summary artifact for CI
7678
try:

eval_protocol/pytest/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,42 @@ def add_cost_metrics(row: EvaluationRow) -> None:
352352
output_cost=output_cost,
353353
total_cost=total_cost,
354354
)
355+
356+
357+
def split_multi_turn_rows(data: list[EvaluationRow]) -> list[EvaluationRow]:
    """
    Split multi-turn conversation rows into individual evaluation rows for each assistant message.

    For every assistant turn in a conversation, a new row is created containing the
    conversation context up to (but not including) that turn, with the assistant's
    original response stored as ``ground_truth``. Useful for Arena-Hard-Auto style
    pairwise comparisons against a comparison model's responses.

    Args:
        data: List of EvaluationRow objects (possibly multi-turn conversations).

    Returns:
        List of expanded EvaluationRow objects, one per assistant message. Rows
        containing no assistant messages contribute nothing to the output.
    """
    expanded_rows: list[EvaluationRow] = []

    for row in data:
        messages = row.messages

        # Create a separate evaluation row at each assistant message (where the
        # comparison model will respond).
        for assistant_pos, message in enumerate(messages):
            if message.role != "assistant":
                continue

            # Copy the context list and deep-copy the metadata so expanded rows do
            # not share mutable state: downstream code assigns row_id and overwrites
            # completion_params on input_metadata, and an aliased object would leak
            # those writes into every sibling row.
            # NOTE(review): assumes EvaluationRow fields are pydantic v2 models
            # (model_dump is used elsewhere in this commit) — confirm model_copy.
            expanded_rows.append(
                EvaluationRow(
                    messages=list(messages[:assistant_pos]),
                    tools=row.tools,
                    input_metadata=(
                        row.input_metadata.model_copy(deep=True)
                        if row.input_metadata is not None
                        else None
                    ),
                    ground_truth=message.content,
                )
            )

    return expanded_rows
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
"""
2+
Default LLM judge for Eval Protocol. Inspired by Arena-Hard-Auto.
3+
"""
4+
5+
import os
6+
from datetime import datetime, timedelta
7+
from typing import List, Dict, Any, Optional
8+
import pandas as pd
9+
from tqdm import tqdm
10+
11+
import pytest
12+
13+
from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
14+
from eval_protocol.pytest import evaluation_test
15+
from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
16+
from eval_protocol.quickstart.utils import pairwise_judgment
17+
18+
# Langfuse client setup: import lazily so this module still loads when the
# optional `langfuse` package is not installed. LANGFUSE_AVAILABLE lets callers
# feature-detect instead of re-trying the import.
try:
    from langfuse import get_client  # pyright: ignore[reportPrivateImportUsage]

    LANGFUSE_AVAILABLE = True
    # NOTE(review): get_client() errors other than ImportError (e.g. bad
    # credentials at client construction) are not caught here and will
    # propagate at import time — confirm that is intended.
    langfuse = get_client()
except ImportError:
    LANGFUSE_AVAILABLE = False
    langfuse = None
27+
28+
29+
def fetch_langfuse_traces_as_evaluation_rows(
    hours_back: int = 168, tags: Optional[List[str]] = None
) -> List[EvaluationRow]:
    """Fetch recent Langfuse traces and convert them into EvaluationRow objects.

    Args:
        hours_back: How far back, in hours, to look for traces (default: one week).
        tags: Optional Langfuse tags to filter traces by.

    Returns:
        Up to 20 EvaluationRow objects built from matching traces, or an empty
        list when the adapter cannot be created or the fetch fails.
    """
    try:
        from eval_protocol.adapters.langfuse import create_langfuse_adapter

        public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
        secret_key = os.getenv("LANGFUSE_SECRET_KEY")
        if not public_key or not secret_key:
            raise ValueError("LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY must be set")

        adapter = create_langfuse_adapter(
            public_key=public_key,
            secret_key=secret_key,
            host=os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com"),
        )

        window_end = datetime.now()
        window_start = window_end - timedelta(hours=hours_back)

        return adapter.get_evaluation_rows(
            limit=20,
            from_timestamp=window_start,
            to_timestamp=window_end,
            include_tool_calls=True,
            tags=tags,
        )

    except Exception as e:
        # Best-effort: a quickstart should degrade gracefully rather than crash.
        print(f"❌ LangfuseAdapter failed: {e}")
        return []
54+
55+
56+
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
@pytest.mark.asyncio
@evaluation_test(
    # input_rows is evaluated at decoration/import time, so the Langfuse fetch
    # runs when the module is collected, not when the test executes.
    input_rows=[fetch_langfuse_traces_as_evaluation_rows()],
    completion_params=[{"model": "gpt-4o"}],
    rollout_processor=SingleTurnRolloutProcessor(),
    split_multi_turn=True,
    mode="all",
)
async def test_llm_judge(rows: list[EvaluationRow]) -> list[EvaluationRow]:
    """
    Simplified LLM Judge for Arena-Hard-Auto style pairwise comparisons.

    Each row contains:
    - messages[:-1]: Question/prompt (conversation context)
    - messages[-1]: Model B's answer (comparison model response)
    - ground_truth: Model A's answer (original assistant response)
    """

    if not rows:
        print("❌ No evaluation rows provided")
        return rows

    print(f"🔄 Processing {len(rows)} evaluation rows for LLM judging...")

    # All rows share the same completion params in this quickstart; take the
    # model name from the first row for reporting.
    model_name = rows[0].input_metadata.completion_params.get("model", "unknown_model")

    # Generate judgments directly from rows
    import concurrent.futures
    from concurrent.futures import ThreadPoolExecutor

    def run_judgment(row: EvaluationRow) -> Optional[Dict[str, Any]]:
        """Run pairwise judgment for a single evaluation row."""
        if not row.messages:
            return None

        # Extract question and answers
        question_text = "\n".join([f"{msg.role}: {msg.content}" for msg in row.messages[:-1]])
        model_a_answer = row.ground_truth  # Original response
        model_b_answer = row.messages[-1].content  # Comparison model response

        games = []

        # Two rounds with swapped positions to control for the judge's
        # positional bias (Arena-Hard-Auto style).
        # Round 1: A vs B (original vs comparison)
        result1 = pairwise_judgment(
            question_text=question_text,
            answer_a=model_a_answer,
            answer_b=model_b_answer,
        )
        games.append(result1)

        # Round 2: B vs A (comparison vs original)
        result2 = pairwise_judgment(
            question_text=question_text,
            answer_a=model_b_answer,
            answer_b=model_a_answer,
        )
        games.append(result2)

        # Per-row score is a placeholder (0.0) here; the aggregate bootstrap
        # score is written back onto every row at the end of this function.
        row.evaluation_result = EvaluateResult(
            score=0.0,
            reason=f"LLM Judge comparison: Round 1: {result1['score']}, Round 2: {result2['score']}"
            if result1 and result2
            else "Failed to get judgement scores",
            metrics={
                "round1_judgment": MetricResult(
                    score=0.0, reason=result1["judgment"] if result1 else "Failed to get judgment reason"
                ),
                "round2_judgment": MetricResult(
                    score=0.0, reason=result2["judgment"] if result2 else "Failed to get judgment reason"
                ),
            },
        )

        return {"model": model_name, "games": games}

    judgments = []
    # Judge calls are I/O-bound, so a thread pool fans them out concurrently.
    max_workers = 64

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(run_judgment, row) for row in rows]

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Generating judgments"):
            result = future.result()
            # Keep only rows where both rounds produced a (truthy) judgment.
            if result and result["games"][0] and result["games"][1]:
                judgments.append(result)

    if not judgments:
        print("❌ No valid judgments generated")
        return rows

    print(f"✅ Generated {len(judgments)} valid judgments")

    # Convert to scores for leaderboard. A decisive verdict (">>", "<<") is
    # weighted 3x by repeating its score, mirroring Arena-Hard-Auto weighting.
    label_to_score = {
        "A>B": [1],
        "A>>B": [1] * 3,
        "A=B": [0.5],
        "A<<B": [0] * 3,
        "A<B": [0],
        "B>A": [0],
        "B>>A": [0] * 3,
        "B=A": [0.5],
        "B<<A": [1] * 3,
        "B<A": [1],
    }

    # Extract scores from judgments. Round 2 swapped the answers, so game1's
    # scores are inverted (1 - s) to put both rounds in the same orientation.
    scores_data = []
    for judgment in judgments:
        game1, game2 = judgment["games"]
        if game1 and game2 and game1.get("score") and game2.get("score"):
            # Convert judgment scores to numerical scores
            scores = label_to_score[game2["score"]] + [1 - s for s in label_to_score[game1["score"]]]
            for score in scores:
                scores_data.append(score)

    if not scores_data:
        print("❌ No valid scores extracted")
        return rows

    # Create DataFrame (single column of scores)
    battles = pd.DataFrame({"score": scores_data})

    # Bootstrap sampling for calculating relative performance to original model at fixed 50%
    bootstrap_means = [
        battles.sample(frac=1.0, replace=True)["score"].mean() for _ in tqdm(range(100), desc="Bootstrap sampling")
    ]

    # Calculate final scores
    bootstraps = pd.Series(bootstrap_means)
    mean_score = bootstraps.mean()
    lower_score = bootstraps.quantile(0.05)
    upper_score = bootstraps.quantile(0.95)

    # Print leaderboard
    print("\n##### LLM Judge Results (90th percentile CI) #####")

    clean_model_name = model_name.split("/")[-1]  # Clean model name

    print(f"{clean_model_name}: {mean_score:.1%} (CI: {lower_score:.1%} - {upper_score:.1%})")
    print("original: 50.0% (CI: 50.0% - 50.0%)")

    for row in rows:
        # This is hacky, but it's the only way to get the score into the evaluation result in our current pattern
        if row.evaluation_result:
            row.evaluation_result.score = mean_score
            # Standard error approximation from 90% CI: SE ≈ (upper - lower) / (2 × 1.645), but this is not quite right bc it assumes a normal distribution
            row.evaluation_result.standard_error = (upper_score - lower_score) / (2 * 1.645)

    return rows

0 commit comments

Comments
 (0)