|
9 | 9 | from collections.abc import Sequence |
10 | 10 |
|
11 | 11 | import pytest |
12 | | -from tqdm import tqdm |
13 | 12 |
|
14 | 13 | from eval_protocol.dataset_logger import default_logger |
15 | 14 | from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
|
58 | 57 | parse_ep_num_runs, |
59 | 58 | parse_ep_passed_threshold, |
60 | 59 | rollout_processor_with_retry, |
| 60 | + run_tasks_with_eval_progress, |
| 61 | + run_tasks_with_run_progress, |
61 | 62 | ) |
62 | 63 | from eval_protocol.utils.show_results_url import show_results_url |
63 | 64 |
|
@@ -380,37 +381,8 @@ async def _execute_groupwise_eval_with_semaphore( |
380 | 381 | asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row)) |
381 | 382 | ) |
382 | 383 |
|
383 | | - # Add tqdm progress bar for evaluations with proper cleanup |
384 | | - eval_position = run_idx + 2 # Position after rollout progress bar |
385 | | - with tqdm( |
386 | | - total=len(pointwise_tasks), |
387 | | - desc=f" Eval {run_idx + 1}", |
388 | | - unit="eval", |
389 | | - file=sys.__stderr__, |
390 | | - leave=False, |
391 | | - position=eval_position, |
392 | | - dynamic_ncols=True, |
393 | | - miniters=1, |
394 | | - mininterval=0.1, |
395 | | - bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", |
396 | | - ) as eval_pbar: |
397 | | - |
398 | | - async def task_with_progress(task): |
399 | | - try: |
400 | | - result = await task |
401 | | - return result |
402 | | - finally: |
403 | | - eval_pbar.update(1) |
404 | | - |
405 | | - wrapped_tasks = [task_with_progress(task) for task in pointwise_tasks] |
406 | | - try: |
407 | | - results = await asyncio.gather(*wrapped_tasks) |
408 | | - except Exception: |
409 | | - # Propagate cancellation to the real tasks and await them to quiesce |
410 | | - for task in pointwise_tasks: |
411 | | - task.cancel() |
412 | | - await asyncio.gather(*pointwise_tasks, return_exceptions=True) |
413 | | - raise |
| 384 | + # Run evaluation tasks with progress bar |
| 385 | + results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx) |
414 | 386 |
|
415 | 387 | all_results[run_idx] = results |
416 | 388 | elif mode == "groupwise": |
@@ -528,36 +500,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete |
528 | 500 | else: |
529 | 501 | # For other processors, create all tasks at once and run in parallel |
530 | 502 | # Concurrency is now controlled by the shared semaphore in each rollout processor |
531 | | - with tqdm( |
532 | | - total=num_runs, |
533 | | - desc="Runs (Parallel)", |
534 | | - unit="run", |
535 | | - file=sys.__stderr__, |
536 | | - position=0, |
537 | | - leave=True, |
538 | | - dynamic_ncols=True, |
539 | | - miniters=1, |
540 | | - bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", |
541 | | - ) as run_pbar: |
542 | | - |
543 | | - async def execute_run_with_progress(run_idx: int, config): |
544 | | - try: |
545 | | - result = await execute_run(run_idx, config) |
546 | | - return result |
547 | | - finally: |
548 | | - run_pbar.update(1) |
549 | | - |
550 | | - tasks = [] |
551 | | - for run_idx in range(num_runs): |
552 | | - tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config))) |
553 | | - try: |
554 | | - await asyncio.gather(*tasks) |
555 | | - except Exception: |
556 | | - # Propagate cancellation to tasks and await them to quiesce |
557 | | - for task in tasks: |
558 | | - task.cancel() |
559 | | - await asyncio.gather(*tasks, return_exceptions=True) |
560 | | - raise |
| 503 | + await run_tasks_with_run_progress(execute_run, num_runs, config) |
561 | 504 |
|
562 | 505 | experiment_duration_seconds = time.perf_counter() - experiment_start_time |
563 | 506 |
|
|
0 commit comments