Skip to content

Commit 81fbc70

Browse files
committed
resolve comments and fix bugs
1 parent f785514 commit 81fbc70

File tree

2 files changed

+33
-31
lines changed

2 files changed

+33
-31
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -418,37 +418,42 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
418418
else:
419419
output_buffer = None
420420

421-
priority_results = await execute_priority_rollouts(
422-
dataset=data,
423-
num_runs=num_runs,
424-
rollout_processor=rollout_processor,
425-
config=config,
426-
max_concurrent_rollouts=max_concurrent_rollouts,
427-
active_logger=active_logger,
428-
eval_executor=test_func,
429-
max_concurrent_evaluations=max_concurrent_evaluations,
430-
mode=mode,
431-
micro_batch_data_buffer=output_buffer,
432-
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
433-
)
421+
try:
422+
priority_results = await execute_priority_rollouts(
423+
dataset=data,
424+
num_runs=num_runs,
425+
rollout_processor=rollout_processor,
426+
config=config,
427+
max_concurrent_rollouts=max_concurrent_rollouts,
428+
active_logger=active_logger,
429+
eval_executor=test_func,
430+
max_concurrent_evaluations=max_concurrent_evaluations,
431+
mode=mode,
432+
micro_batch_data_buffer=output_buffer,
433+
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
434+
)
435+
finally:
436+
if output_buffer:
437+
await output_buffer.close()
434438

435439
for res in priority_results:
436440
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
437441
if run_idx < len(all_results):
438442
all_results[run_idx].append(res)
439443

440444
processed_rows_in_run.append(res)
441-
postprocess(
442-
all_results,
443-
aggregation_method,
444-
passed_threshold,
445-
active_logger,
446-
mode,
447-
completion_params, # pyright: ignore[reportArgumentType]
448-
test_func.__name__,
449-
num_runs,
450-
time.perf_counter() - experiment_start_time,
451-
)
445+
446+
postprocess(
447+
all_results,
448+
aggregation_method,
449+
passed_threshold,
450+
active_logger,
451+
mode,
452+
completion_params, # pyright: ignore[reportArgumentType]
453+
test_func.__name__,
454+
num_runs,
455+
time.perf_counter() - experiment_start_time,
456+
)
452457

453458
else:
454459
async def execute_run(run_idx: int, config: RolloutProcessorConfig):

eval_protocol/pytest/priority_scheduler.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import logging
23
import os
34
from collections import defaultdict
45
from dataclasses import dataclass, field
@@ -67,7 +68,6 @@ def __init__(
6768
self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue()
6869

6970
# Concurrency Control
70-
self.rollout_sem = asyncio.Semaphore(max_concurrent_rollouts)
7171
self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations)
7272

7373
# Results storage
@@ -112,16 +112,13 @@ async def worker(self):
112112
Worker loop: fetch task -> execute micro-batch -> schedule next batch (if any).
113113
"""
114114
while True:
115-
try:
116-
# Get a task from the priority queue
117-
task: RolloutTask = await self.queue.get()
118-
except asyncio.QueueEmpty:
119-
break
115+
# Get a task from the priority queue
116+
task: RolloutTask = await self.queue.get()
120117

121118
try:
122119
await self._process_task(task)
123120
except Exception as e:
124-
print(f"Error processing task for row {task.row.input_metadata.row_id}: {e}")
121+
logging.error(f"Error processing task for row {task.row.input_metadata.row_id}: {e}", exc_info=True)
125122
finally:
126123
self.queue.task_done()
127124

0 commit comments

Comments
 (0)