Skip to content

Commit 579db3c

Browse files
committed
make example shorter
1 parent 23cca87 commit 579db3c

File tree

3 files changed: 199 additions (+199) and 153 deletions (-153).

eval_protocol/adapters/langfuse.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -220,6 +220,10 @@ def _extract_messages_from_trace(
220220
else:
221221
# Fallback: convert entire output to string
222222
messages.append(Message(role="assistant", content=str(trace.output)))
223+
elif isinstance(trace.output, list):
224+
# Direct list of message dicts (same as input handling)
225+
for msg in trace.output:
226+
messages.append(self._dict_to_message(msg, include_tool_calls))
223227
elif isinstance(trace.output, str):
224228
messages.append(Message(role="assistant", content=trace.output))
225229

eval_protocol/quickstart/llm_judge.py

Lines changed: 16 additions & 152 deletions
Original file line number | Diff line number | Diff line change
@@ -3,65 +3,26 @@
33
"""
44

55
import os
6-
from datetime import datetime, timedelta
76
from typing import List, Dict, Any, Optional
8-
import pandas as pd
97
from tqdm import tqdm
108

119
import pytest
1210

1311
from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
1412
from eval_protocol.pytest import evaluation_test
1513
from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
16-
from eval_protocol.quickstart.utils import pairwise_judgment, split_multi_turn_rows, serialize_message
17-
from eval_protocol.adapters.langfuse import create_langfuse_adapter
14+
from eval_protocol.quickstart.utils import (
15+
split_multi_turn_rows,
16+
JUDGE_CONFIGS,
17+
fetch_langfuse_traces_as_evaluation_rows,
18+
calculate_bootstrap_scores,
19+
push_scores_to_langfuse,
20+
run_judgment,
21+
)
1822

1923
import concurrent.futures
2024
from concurrent.futures import ThreadPoolExecutor
2125

22-
# Judge configs from the original Arena-Hard-Auto paper, feel free to add your own judge!
23-
JUDGE_CONFIGS = {
24-
"gpt-4.1": {
25-
"model": "gpt-4.1",
26-
"temperature": 0.0,
27-
"max_tokens": 16000,
28-
"max_concurrency": 64,
29-
},
30-
"gemini-2.5-pro": {
31-
"model": "gemini-2.5-pro",
32-
"temperature": 1.0,
33-
"max_tokens": 32000,
34-
"api_key": os.getenv("GEMINI_API_KEY"),
35-
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
36-
"max_concurrency": 32,
37-
},
38-
}
39-
40-
41-
def fetch_langfuse_traces_as_evaluation_rows(
42-
limit: int = 100,
43-
tags: Optional[List[str]] = None,
44-
user_id: Optional[str] = None,
45-
session_id: Optional[str] = None,
46-
hours_back: Optional[int] = None,
47-
include_tool_calls: bool = True,
48-
) -> List[EvaluationRow]:
49-
try:
50-
adapter = create_langfuse_adapter()
51-
52-
return adapter.get_evaluation_rows(
53-
limit=limit,
54-
tags=tags,
55-
user_id=user_id,
56-
session_id=session_id,
57-
hours_back=hours_back,
58-
include_tool_calls=include_tool_calls,
59-
)
60-
61-
except Exception as e:
62-
print(f"❌ LangfuseAdapter failed: {e}")
63-
return []
64-
6526

6627
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
6728
@pytest.mark.asyncio
@@ -83,15 +44,15 @@ def fetch_langfuse_traces_as_evaluation_rows(
8344
)
8445
async def test_llm_judge(rows: list[EvaluationRow]) -> list[EvaluationRow]:
8546
"""
86-
Simplified LLM Judge for Arena-Hard-Auto style pairwise comparisons.
47+
Simplified LLM Judge for Arena-Hard-Auto pairwise comparisons.
8748
8849
Each row contains:
8950
- messages[:-1]: Question/prompt (conversation context)
9051
- messages[-1]: Model B's answer (comparison model response)
9152
- ground_truth: Model A's answer (original assistant response)
9253
"""
9354

94-
judge_name = "gemini-2.5-pro" # Edit to which judge you'd like to use. Configs at top of file.
55+
judge_name = "gemini-2.5-pro" # Edit to which judge you'd like to use. Configs are in utils.py.
9556

9657
if not rows:
9758
print("❌ No evaluation rows provided")
@@ -101,59 +62,11 @@ async def test_llm_judge(rows: list[EvaluationRow]) -> list[EvaluationRow]:
10162

10263
model_name = rows[0].input_metadata.completion_params.get("model", "unknown_model")
10364

104-
def run_judgment(row: EvaluationRow) -> Optional[Dict[str, Any]]:
105-
"""Run pairwise judgment for a single evaluation row."""
106-
if not row.messages:
107-
return None
108-
109-
question_text = "\n".join([serialize_message(msg) for msg in row.messages[:-1]])
110-
model_a_answer = row.ground_truth
111-
model_b_answer = serialize_message(row.messages[-1])
112-
113-
games = []
114-
115-
# Round 1: A vs B (original vs comparison)
116-
result1 = pairwise_judgment(
117-
question_text=question_text,
118-
answer_a=model_a_answer,
119-
answer_b=model_b_answer,
120-
tools=row.tools,
121-
judge_config=JUDGE_CONFIGS[judge_name],
122-
)
123-
games.append(result1)
124-
125-
# Round 2: B vs A (comparison vs original)
126-
result2 = pairwise_judgment(
127-
question_text=question_text,
128-
answer_a=model_b_answer,
129-
answer_b=model_a_answer,
130-
tools=row.tools,
131-
judge_config=JUDGE_CONFIGS[judge_name],
132-
)
133-
games.append(result2)
134-
135-
row.evaluation_result = EvaluateResult(
136-
score=0.0,
137-
reason=f"LLM Judge comparison: Round 1: {result1['score']}, Round 2: {result2['score']}"
138-
if result1 and result2
139-
else "Failed to get judgement scores",
140-
metrics={
141-
"round1_judgment": MetricResult(
142-
score=0.0, reason=result1["judgment"] if result1 else "Failed to get judgment reason"
143-
),
144-
"round2_judgment": MetricResult(
145-
score=0.0, reason=result2["judgment"] if result2 else "Failed to get judgment reason"
146-
),
147-
},
148-
)
149-
150-
return {"model": model_name, "games": games}
151-
15265
judgments = []
15366
max_concurrency = JUDGE_CONFIGS[judge_name]["max_concurrency"]
15467

15568
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
156-
futures = [executor.submit(run_judgment, row) for row in rows]
69+
futures = [executor.submit(run_judgment, row, model_name, judge_name) for row in rows]
15770

15871
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Generating judgments"):
15972
result = future.result()
@@ -166,46 +79,13 @@ def run_judgment(row: EvaluationRow) -> Optional[Dict[str, Any]]:
16679

16780
print(f"✅ Generated {len(judgments)} valid judgments")
16881

169-
# Convert to scores for leaderboard
170-
label_to_score = {
171-
"A>B": [1],
172-
"A>>B": [1] * 3,
173-
"A=B": [0.5],
174-
"A<<B": [0] * 3,
175-
"A<B": [0],
176-
"B>A": [0],
177-
"B>>A": [0] * 3,
178-
"B=A": [0.5],
179-
"B<<A": [1] * 3,
180-
"B<A": [1],
181-
}
182-
183-
# Extract scores from judgments
184-
scores_data = []
185-
for judgment in judgments:
186-
game1, game2 = judgment["games"]
187-
if game1 and game2 and game1.get("score") and game2.get("score"):
188-
# Convert judgment scores to numerical scores
189-
scores = label_to_score[game2["score"]] + [1 - s for s in label_to_score[game1["score"]]]
190-
for score in scores:
191-
scores_data.append(score)
192-
193-
if not scores_data:
82+
# Calculate bootstrap scores
83+
mean_score, lower_score, upper_score = calculate_bootstrap_scores(judgments)
84+
85+
if mean_score == 0.0:
19486
print("❌ No valid scores extracted")
19587
return rows
19688

197-
# Create DataFrame (single column of scores)
198-
battles = pd.DataFrame({"score": scores_data})
199-
200-
# Bootstrap sampling for calculating relative performance to original model at fixed 50%
201-
bootstrap_means = [battles.sample(frac=1.0, replace=True)["score"].mean() for _ in range(100)]
202-
203-
# Calculate final scores
204-
bootstraps = pd.Series(bootstrap_means)
205-
mean_score = bootstraps.mean()
206-
lower_score = bootstraps.quantile(0.05)
207-
upper_score = bootstraps.quantile(0.95)
208-
20989
# Print leaderboard
21090
print("\n##### LLM Judge Results (90th percentile CI) #####")
21191

@@ -222,22 +102,6 @@ def run_judgment(row: EvaluationRow) -> Optional[Dict[str, Any]]:
222102
) # Standard error approximation from 90% CI
223103

224104
# Optional, push scores back to Langfuse. Note that one score per model will be pushed back onto same trace.
225-
try:
226-
langfuse = create_langfuse_adapter().client
227-
except Exception:
228-
langfuse = None
229-
230-
if langfuse:
231-
for trace_id in set(
232-
row.input_metadata.session_data["langfuse_trace_id"]
233-
for row in rows
234-
if row.evaluation_result and row.input_metadata and row.input_metadata.session_data
235-
):
236-
if trace_id:
237-
langfuse.create_score(
238-
trace_id=trace_id,
239-
name=model_name,
240-
value=mean_score,
241-
)
105+
push_scores_to_langfuse(rows, model_name, mean_score)
242106

243107
return rows

0 commit comments

Comments
 (0)