Skip to content

Commit 8b5c3c1

Browse files
authored
Update quickstart (#203)
* Update quickstart * update * update * update braintrust * test * remove * add new braintrust key * remove braintrust keys * remove id stuff and comment out braintrust * move tqdm to utils file
1 parent 9d7040d commit 8b5c3c1

File tree

7 files changed

+150
-65
lines changed

7 files changed

+150
-65
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from collections.abc import Sequence
1010

1111
import pytest
12-
from tqdm import tqdm
1312

1413
from eval_protocol.dataset_logger import default_logger
1514
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
@@ -58,6 +57,8 @@
5857
parse_ep_num_runs,
5958
parse_ep_passed_threshold,
6059
rollout_processor_with_retry,
60+
run_tasks_with_eval_progress,
61+
run_tasks_with_run_progress,
6162
)
6263
from eval_protocol.utils.show_results_url import store_local_ui_results_url
6364

@@ -382,7 +383,9 @@ async def _execute_groupwise_eval_with_semaphore(
382383
pointwise_tasks.append(
383384
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
384385
)
385-
results = await asyncio.gather(*pointwise_tasks)
386+
387+
# Run evaluation tasks with progress bar
388+
results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx)
386389

387390
all_results[run_idx] = results
388391
elif mode == "groupwise":
@@ -500,27 +503,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
500503
else:
501504
# For other processors, create all tasks at once and run in parallel
502505
# Concurrency is now controlled by the shared semaphore in each rollout processor
503-
with tqdm(
504-
total=num_runs,
505-
desc="Runs (Parallel)",
506-
unit="run",
507-
file=sys.__stderr__,
508-
position=0,
509-
leave=True,
510-
dynamic_ncols=True,
511-
miniters=1,
512-
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
513-
) as run_pbar:
514-
515-
async def execute_run_with_progress(run_idx: int, config):
516-
result = await execute_run(run_idx, config)
517-
run_pbar.update(1)
518-
return result
519-
520-
tasks = []
521-
for run_idx in range(num_runs):
522-
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
523-
await asyncio.gather(*tasks) # pyright: ignore[reportUnknownArgumentType]
506+
await run_tasks_with_run_progress(execute_run, num_runs, config)
524507

525508
experiment_duration_seconds = time.perf_counter() - experiment_start_time
526509

eval_protocol/pytest/utils.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,89 @@
3333
AggregationMethod = Literal["mean", "max", "min", "bootstrap"]
3434

3535

36+
async def run_tasks_with_eval_progress(pointwise_tasks: list, run_idx: int):
37+
"""
38+
Run evaluation tasks with a progress bar and proper cancellation handling.
39+
40+
Args:
41+
pointwise_tasks: List of asyncio tasks to execute
42+
run_idx: Run index for progress bar positioning and naming
43+
44+
Returns:
45+
Results from all tasks
46+
"""
47+
eval_position = run_idx + 2 # Position after rollout progress bar
48+
with tqdm(
49+
total=len(pointwise_tasks),
50+
desc=f" Eval {run_idx + 1}",
51+
unit="eval",
52+
file=sys.__stderr__,
53+
leave=False,
54+
position=eval_position,
55+
dynamic_ncols=True,
56+
miniters=1,
57+
mininterval=0.1,
58+
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
59+
) as eval_pbar:
60+
61+
async def task_with_progress(task):
62+
try:
63+
result = await task
64+
return result
65+
finally:
66+
eval_pbar.update(1)
67+
68+
wrapped_tasks = [task_with_progress(task) for task in pointwise_tasks]
69+
try:
70+
results = await asyncio.gather(*wrapped_tasks)
71+
return results
72+
except Exception:
73+
# Propagate cancellation to the real tasks and await them to quiesce
74+
for task in pointwise_tasks:
75+
task.cancel()
76+
await asyncio.gather(*pointwise_tasks, return_exceptions=True)
77+
raise
78+
79+
80+
async def run_tasks_with_run_progress(execute_run_func, num_runs, config):
81+
"""
82+
Run tasks with a parallel runs progress bar, preserving original logic.
83+
84+
Args:
85+
execute_run_func: The execute_run function to call
86+
num_runs: Number of runs to execute
87+
config: Configuration to pass to execute_run_func
88+
"""
89+
with tqdm(
90+
total=num_runs,
91+
desc="Runs (Parallel)",
92+
unit="run",
93+
file=sys.__stderr__,
94+
position=0,
95+
leave=True,
96+
dynamic_ncols=True,
97+
miniters=1,
98+
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
99+
) as run_pbar:
100+
101+
async def execute_run_with_progress(run_idx: int, config):
102+
result = await execute_run_func(run_idx, config)
103+
run_pbar.update(1)
104+
return result
105+
106+
tasks = []
107+
for run_idx in range(num_runs):
108+
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
109+
try:
110+
await asyncio.gather(*tasks)
111+
except Exception:
112+
# Propagate cancellation to tasks and await them to quiesce
113+
for task in tasks:
114+
task.cancel()
115+
await asyncio.gather(*tasks, return_exceptions=True)
116+
raise
117+
118+
36119
def calculate_bootstrap_scores(all_scores: list[float]) -> float:
37120
"""
38121
Calculate bootstrap confidence intervals for individual scores.
@@ -277,7 +360,7 @@ async def execute_row_with_backoff_and_log(task: asyncio.Task, row: EvaluationRo
277360
position = run_idx + 1 # Position 0 is reserved for main run bar, so shift up by 1
278361
with tqdm(
279362
total=len(retry_tasks),
280-
desc=f" Run {position}",
363+
desc=f" Run {run_idx + 1}",
281364
unit="rollout",
282365
file=sys.__stderr__,
283366
leave=False,

eval_protocol/quickstart/llm_judge.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,14 @@ async def aha_judge(
4747
model_a_answer = str(row.ground_truth)
4848
model_b_answer = serialize_message(row.messages[-1])
4949

50-
client = AsyncOpenAI(api_key=judge_config.get("api_key"), base_url=judge_config.get("base_url"))
51-
52-
# Run two judgment rounds in sequence (A vs B, then B vs A)
53-
result1 = await run_single_judgment(question_text, model_a_answer, model_b_answer, row.tools, judge_config, client)
54-
result2 = await run_single_judgment(question_text, model_b_answer, model_a_answer, row.tools, judge_config, client)
50+
async with AsyncOpenAI(api_key=judge_config.get("api_key"), base_url=judge_config.get("base_url")) as client:
51+
# Run two judgment rounds in sequence (A vs B, then B vs A)
52+
result1 = await run_single_judgment(
53+
question_text, model_a_answer, model_b_answer, row.tools, judge_config, client
54+
)
55+
result2 = await run_single_judgment(
56+
question_text, model_b_answer, model_a_answer, row.tools, judge_config, client
57+
)
5558

5659
if not result1 or not result2 or not result1.get("score") or not result2.get("score"):
5760
# If either judgment failed, mark as invalid (don't include in distribution)

eval_protocol/quickstart/llm_judge_braintrust.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,39 @@
66

77
import pytest
88

9+
# Skip entire module in CI to prevent import-time side effects
10+
if os.environ.get("CI") == "true":
11+
pytest.skip("Skip quickstart in CI", allow_module_level=True)
12+
913
from eval_protocol import (
1014
evaluation_test,
1115
aha_judge,
1216
multi_turn_assistant_to_ground_truth,
1317
EvaluationRow,
1418
SingleTurnRolloutProcessor,
1519
create_braintrust_adapter,
20+
DefaultParameterIdGenerator,
1621
)
22+
1723
# adapter = create_braintrust_adapter()
24+
# input_rows = [
25+
# adapter.get_evaluation_rows(
26+
# btql_query=f"""
27+
# select: *
28+
# from: project_logs('{os.getenv("BRAINTRUST_PROJECT_ID")}') traces
29+
# filter: is_root = true
30+
# limit: 10
31+
# """
32+
# )
33+
# ]
34+
input_rows = []
35+
# uncomment when dataloader is fixed
1836

1937

2038
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
21-
@pytest.mark.asyncio
22-
@evaluation_test(
23-
input_rows=[
24-
# adapter.get_evaluation_rows(
25-
# btql_query=f"""
26-
# select: *
27-
# from: project_logs('{os.getenv("BRAINTRUST_PROJECT_ID")}') traces
28-
# filter: is_root = true
29-
# limit: 10
30-
# """
31-
# )
32-
[]
33-
],
34-
completion_params=[
39+
@pytest.mark.parametrize(
40+
"completion_params",
41+
[
3542
{"model": "gpt-4.1"},
3643
{
3744
"max_tokens": 131000,
@@ -44,10 +51,12 @@
4451
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-20b",
4552
},
4653
],
54+
)
55+
@evaluation_test(
56+
input_rows=[input_rows],
4757
rollout_processor=SingleTurnRolloutProcessor(),
4858
preprocess_fn=multi_turn_assistant_to_ground_truth,
49-
max_concurrent_rollouts=64,
50-
aggregation_method="bootstrap",
59+
max_concurrent_evaluations=2,
5160
)
5261
async def test_llm_judge(row: EvaluationRow) -> EvaluationRow:
5362
return await aha_judge(row)

eval_protocol/quickstart/llm_judge_langfuse.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,25 @@
1414
EvaluationRow,
1515
SingleTurnRolloutProcessor,
1616
create_langfuse_adapter,
17+
DefaultParameterIdGenerator,
1718
)
1819

1920
from eval_protocol.quickstart import aha_judge
2021

2122
adapter = create_langfuse_adapter()
23+
input_rows = adapter.get_evaluation_rows(
24+
to_timestamp=datetime(2025, 9, 12, 0, 11, 18),
25+
limit=711,
26+
sample_size=50,
27+
sleep_between_gets=3.0,
28+
max_retries=5,
29+
)
2230

2331

2432
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
25-
@evaluation_test(
26-
input_rows=[
27-
adapter.get_evaluation_rows(
28-
to_timestamp=datetime(2025, 9, 12, 0, 11, 18),
29-
limit=711,
30-
sample_size=50,
31-
sleep_between_gets=3.0,
32-
max_retries=5,
33-
)
34-
],
35-
completion_params=[
33+
@pytest.mark.parametrize(
34+
"completion_params",
35+
[
3636
{"model": "gpt-4.1"},
3737
{
3838
"max_tokens": 131000,
@@ -45,10 +45,12 @@
4545
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-20b",
4646
},
4747
],
48+
)
49+
@evaluation_test(
50+
input_rows=[input_rows],
4851
rollout_processor=SingleTurnRolloutProcessor(),
4952
preprocess_fn=multi_turn_assistant_to_ground_truth,
50-
max_concurrent_rollouts=2,
51-
aggregation_method="bootstrap",
53+
max_concurrent_evaluations=2,
5254
)
5355
async def test_llm_judge(row: EvaluationRow) -> EvaluationRow:
5456
return await aha_judge(row)

eval_protocol/quickstart/llm_judge_langsmith.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
EvaluationRow,
3232
SingleTurnRolloutProcessor,
3333
LangSmithAdapter,
34+
DefaultParameterIdGenerator,
3435
)
3536

3637

@@ -53,11 +54,13 @@ def fetch_langsmith_traces_as_evaluation_rows(
5354
return []
5455

5556

57+
input_rows = fetch_langsmith_traces_as_evaluation_rows()
58+
59+
5660
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
57-
@pytest.mark.asyncio
58-
@evaluation_test(
59-
input_rows=[fetch_langsmith_traces_as_evaluation_rows()],
60-
completion_params=[
61+
@pytest.mark.parametrize(
62+
"completion_params",
63+
[
6164
{
6265
"model": "fireworks_ai/accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
6366
},
@@ -67,9 +70,12 @@ def fetch_langsmith_traces_as_evaluation_rows(
6770
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
6871
},
6972
],
73+
)
74+
@evaluation_test(
75+
input_rows=[input_rows],
7076
rollout_processor=SingleTurnRolloutProcessor(),
7177
preprocess_fn=multi_turn_assistant_to_ground_truth,
72-
aggregation_method="bootstrap",
78+
max_concurrent_evaluations=2,
7379
)
7480
async def test_llm_judge_langsmith(row: EvaluationRow) -> EvaluationRow:
7581
"""LLM Judge evaluation over LangSmith-sourced rows, persisted locally by Eval Protocol.

eval_protocol/quickstart/llm_judge_openai_responses.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,12 @@
5151
"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct-0905",
5252
},
5353
],
54-
ids=DefaultParameterIdGenerator.generate_id_from_dict,
5554
)
5655
@evaluation_test(
5756
input_rows=[input_rows],
5857
rollout_processor=SingleTurnRolloutProcessor(),
5958
preprocess_fn=multi_turn_assistant_to_ground_truth,
60-
aggregation_method="bootstrap",
59+
max_concurrent_evaluations=2,
6160
)
6261
async def test_llm_judge_openai_responses(row: EvaluationRow) -> EvaluationRow:
6362
return await aha_judge(row)

0 commit comments

Comments
 (0)