Skip to content

Commit 824a998

Browse files
committed
move tqdm to utils file
1 parent 9200ecd commit 824a998

File tree

2 files changed

+88
-62
lines changed

2 files changed

+88
-62
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 5 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from collections.abc import Sequence
1010

1111
import pytest
12-
from tqdm import tqdm
1312

1413
from eval_protocol.dataset_logger import default_logger
1514
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
@@ -58,6 +57,8 @@
5857
parse_ep_num_runs,
5958
parse_ep_passed_threshold,
6059
rollout_processor_with_retry,
60+
run_tasks_with_eval_progress,
61+
run_tasks_with_run_progress,
6162
)
6263
from eval_protocol.utils.show_results_url import show_results_url
6364

@@ -380,37 +381,8 @@ async def _execute_groupwise_eval_with_semaphore(
380381
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
381382
)
382383

383-
# Add tqdm progress bar for evaluations with proper cleanup
384-
eval_position = run_idx + 2 # Position after rollout progress bar
385-
with tqdm(
386-
total=len(pointwise_tasks),
387-
desc=f" Eval {run_idx + 1}",
388-
unit="eval",
389-
file=sys.__stderr__,
390-
leave=False,
391-
position=eval_position,
392-
dynamic_ncols=True,
393-
miniters=1,
394-
mininterval=0.1,
395-
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
396-
) as eval_pbar:
397-
398-
async def task_with_progress(task):
399-
try:
400-
result = await task
401-
return result
402-
finally:
403-
eval_pbar.update(1)
404-
405-
wrapped_tasks = [task_with_progress(task) for task in pointwise_tasks]
406-
try:
407-
results = await asyncio.gather(*wrapped_tasks)
408-
except Exception:
409-
# Propagate cancellation to the real tasks and await them to quiesce
410-
for task in pointwise_tasks:
411-
task.cancel()
412-
await asyncio.gather(*pointwise_tasks, return_exceptions=True)
413-
raise
384+
# Run evaluation tasks with progress bar
385+
results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx)
414386

415387
all_results[run_idx] = results
416388
elif mode == "groupwise":
@@ -528,36 +500,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
528500
else:
529501
# For other processors, create all tasks at once and run in parallel
530502
# Concurrency is now controlled by the shared semaphore in each rollout processor
531-
with tqdm(
532-
total=num_runs,
533-
desc="Runs (Parallel)",
534-
unit="run",
535-
file=sys.__stderr__,
536-
position=0,
537-
leave=True,
538-
dynamic_ncols=True,
539-
miniters=1,
540-
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
541-
) as run_pbar:
542-
543-
async def execute_run_with_progress(run_idx: int, config):
544-
try:
545-
result = await execute_run(run_idx, config)
546-
return result
547-
finally:
548-
run_pbar.update(1)
549-
550-
tasks = []
551-
for run_idx in range(num_runs):
552-
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
553-
try:
554-
await asyncio.gather(*tasks)
555-
except Exception:
556-
# Propagate cancellation to tasks and await them to quiesce
557-
for task in tasks:
558-
task.cancel()
559-
await asyncio.gather(*tasks, return_exceptions=True)
560-
raise
503+
await run_tasks_with_run_progress(execute_run, num_runs, config)
561504

562505
experiment_duration_seconds = time.perf_counter() - experiment_start_time
563506

eval_protocol/pytest/utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,89 @@
3333
AggregationMethod = Literal["mean", "max", "min", "bootstrap"]
3434

3535

36+
async def run_tasks_with_eval_progress(pointwise_tasks: list, run_idx: int):
37+
"""
38+
Run evaluation tasks with a progress bar and proper cancellation handling.
39+
40+
Args:
41+
pointwise_tasks: List of asyncio tasks to execute
42+
run_idx: Run index for progress bar positioning and naming
43+
44+
Returns:
45+
Results from all tasks
46+
"""
47+
eval_position = run_idx + 2 # Position after rollout progress bar
48+
with tqdm(
49+
total=len(pointwise_tasks),
50+
desc=f" Eval {run_idx + 1}",
51+
unit="eval",
52+
file=sys.__stderr__,
53+
leave=False,
54+
position=eval_position,
55+
dynamic_ncols=True,
56+
miniters=1,
57+
mininterval=0.1,
58+
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
59+
) as eval_pbar:
60+
61+
async def task_with_progress(task):
62+
try:
63+
result = await task
64+
return result
65+
finally:
66+
eval_pbar.update(1)
67+
68+
wrapped_tasks = [task_with_progress(task) for task in pointwise_tasks]
69+
try:
70+
results = await asyncio.gather(*wrapped_tasks)
71+
return results
72+
except Exception:
73+
# Propagate cancellation to the real tasks and await them to quiesce
74+
for task in pointwise_tasks:
75+
task.cancel()
76+
await asyncio.gather(*pointwise_tasks, return_exceptions=True)
77+
raise
78+
79+
80+
async def run_tasks_with_run_progress(execute_run_func, num_runs, config):
81+
"""
82+
Run tasks with a parallel runs progress bar, preserving original logic.
83+
84+
Args:
85+
execute_run_func: The execute_run function to call
86+
num_runs: Number of runs to execute
87+
config: Configuration to pass to execute_run_func
88+
"""
89+
with tqdm(
90+
total=num_runs,
91+
desc="Runs (Parallel)",
92+
unit="run",
93+
file=sys.__stderr__,
94+
position=0,
95+
leave=True,
96+
dynamic_ncols=True,
97+
miniters=1,
98+
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
99+
) as run_pbar:
100+
101+
async def execute_run_with_progress(run_idx: int, config):
102+
result = await execute_run_func(run_idx, config)
103+
run_pbar.update(1)
104+
return result
105+
106+
tasks = []
107+
for run_idx in range(num_runs):
108+
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
109+
try:
110+
await asyncio.gather(*tasks)
111+
except Exception:
112+
# Propagate cancellation to tasks and await them to quiesce
113+
for task in tasks:
114+
task.cancel()
115+
await asyncio.gather(*tasks, return_exceptions=True)
116+
raise
117+
118+
36119
def calculate_bootstrap_scores(all_scores: list[float]) -> float:
37120
"""
38121
Calculate bootstrap confidence intervals for individual scores.

0 commit comments

Comments
 (0)