
Commit b556f4e

Commit message: add
1 parent: ff329d8

3 files changed: 19 additions & 21 deletions


eval_protocol/pytest/buffer.py

Lines changed: 4 additions & 4 deletions
@@ -5,14 +5,14 @@

 from eval_protocol.models import EvaluationRow

-class MiniBatchDataBuffer:
+class MicroBatchDataBuffer:
     """
     Buffers evaluation results and writes them to disk in minibatches.
     Waits for all runs of a sample to complete before considering it ready and flush to disk.
     """
-    def __init__(self, num_runs: int, minibatch_size: int, output_path_template: str):
+    def __init__(self, num_runs: int, batch_size: int, output_path_template: str):
         self.num_runs = num_runs
-        self.minibatch_size = minibatch_size
+        self.batch_size = batch_size
         self.output_path_template = output_path_template
         self.pending_samples: Dict[str, List[EvaluationRow]] = defaultdict(list)  # row_id -> list[EvaluationRow]
         self.completed_samples_buffer: List[List[EvaluationRow]] = []  # List[List[EvaluationRow]]

@@ -37,7 +37,7 @@ async def add_result(self, row: EvaluationRow):
             completed_rows = self.pending_samples.pop(row_id)
             self.completed_samples_buffer.append(completed_rows)

-            if len(self.completed_samples_buffer) >= self.minibatch_size:
+            if len(self.completed_samples_buffer) >= self.batch_size:
                 await self._flush_unsafe()

     async def _flush_unsafe(self):
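
For orientation, a minimal usage sketch of the renamed buffer. Only the constructor signature and add_result come from this diff; the event-loop framing, the collect helper, and the empty input list are illustrative, and the completion behavior is inferred from the docstring and the num_runs parameter:

import asyncio

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.buffer import MicroBatchDataBuffer

async def collect(rows: list[EvaluationRow]) -> None:
    # Flush 2 completed samples per write; a sample is presumably treated as
    # complete once all 4 of its runs have been added via add_result.
    buffer = MicroBatchDataBuffer(
        num_runs=4,
        batch_size=2,
        output_path_template="out/buffer_{index}.jsonl",
    )
    for row in rows:
        await buffer.add_result(row)

asyncio.run(collect([]))  # populate with real EvaluationRow results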

eval_protocol/pytest/evaluation_test.py

Lines changed: 6 additions & 7 deletions
@@ -70,7 +70,7 @@
 from eval_protocol.log_utils.init import init_external_logging_from_env
 from eval_protocol.log_utils.rollout_context import rollout_logging_context
 from eval_protocol.utils.browser_utils import is_logs_server_running, open_browser_tab
-from eval_protocol.pytest.buffer import MiniBatchDataBuffer
+from eval_protocol.pytest.buffer import MicroBatchDataBuffer
 from ..common_utils import load_jsonl


@@ -411,25 +411,24 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
         )

         if use_priority_scheduler:
-            print("Using priority scheduler")
-            minibatch_output_size = os.environ.get("EP_MINI_BATCH_OUTPUT_SIZE", None)
+            microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None)
             output_dir = os.environ.get("EP_OUTPUT_DIR", None)
-            if minibatch_output_size and output_dir:
-                output_buffer = MiniBatchDataBuffer(num_runs=num_runs, minibatch_size=int(minibatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"))
+            if microbatch_output_size and output_dir:
+                output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"))
             else:
                 output_buffer = None
+
             priority_results = await execute_priority_rollouts(
                 dataset=data,
                 num_runs=num_runs,
-                micro_batch_size=int(os.environ.get("EP_MICRO_BATCH_SIZE", "1")),
                 rollout_processor=rollout_processor,
                 config=config,
                 max_concurrent_rollouts=max_concurrent_rollouts,
                 active_logger=active_logger,
                 eval_executor=test_func,
                 max_concurrent_evaluations=max_concurrent_evaluations,
                 mode=mode,
-                mini_batch_data_buffer=output_buffer,
+                micro_batch_data_buffer=output_buffer,
                 evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
             )
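
With the print statement and the EP_MICRO_BATCH_SIZE-driven argument gone, the output buffer is now enabled purely by environment variables. A hedged sketch of that gating in isolation; the env var names and the buffer_{index}.jsonl template are from the diff, while the standalone framing and example values are not:

import os

# Both variables must be set, mirroring the check in evaluation_test.py;
# otherwise output_buffer stays None and nothing is micro-batched to disk.
os.environ["EP_MICRO_BATCH_OUTPUT_SIZE"] = "8"   # flush every 8 completed samples
os.environ["EP_OUTPUT_DIR"] = "/tmp/ep_out"      # files land at /tmp/ep_out/buffer_{index}.jsonl

microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE")
output_dir = os.environ.get("EP_OUTPUT_DIR")
enabled = bool(microbatch_output_size and output_dir)
print(f"micro-batched output enabled: {enabled}")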

eval_protocol/pytest/priority_scheduler.py

Lines changed: 9 additions & 10 deletions
@@ -8,7 +8,7 @@
 from eval_protocol.pytest.types import RolloutProcessorConfig, TestFunction
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry, add_cost_metrics
-from eval_protocol.pytest.buffer import MiniBatchDataBuffer
+from eval_protocol.pytest.buffer import MicroBatchDataBuffer
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.human_id import generate_id
 from eval_protocol.log_utils.rollout_context import rollout_logging_context

@@ -49,10 +49,10 @@ def __init__(
         active_logger: DatasetLogger,
         max_concurrent_evaluations: int,
         eval_executor: TestFunction,  # Callback to run evaluation
-        output_buffer: Optional[MiniBatchDataBuffer] = None,
+        output_buffer: Optional[MicroBatchDataBuffer] = None,
         rollout_n: int = 0,
         mode: str = "pointwise",
-        in_group_microbatch_size: int = 0,  # for one sample, how many runs to execute at the same time
+        in_group_minibatch_size: int = 0,  # for one sample, how many runs to execute at the same time
         evaluation_test_kwargs: Dict[str, Any] = {},
     ):
         self.rollout_processor = rollout_processor

@@ -77,7 +77,7 @@ def __init__(
         self.background_tasks = set()  # run evaluations in the background asynchronously

         self.rollout_n = rollout_n
-        self.in_group_microbatch_size = in_group_microbatch_size if in_group_microbatch_size > 0 else rollout_n
+        self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n
         self.evaluation_test_kwargs = evaluation_test_kwargs

     async def schedule_dataset(

@@ -91,7 +91,7 @@ async def schedule_dataset(
         for i, row in enumerate(dataset):
             # Calculate ranges for the first in-group minibatch
             batch_start = 0
-            batch_end = min(self.in_group_microbatch_size, self.rollout_n)
+            batch_end = min(self.in_group_minibatch_size, self.rollout_n)
             run_indices = list(range(batch_start, batch_end))

             # Initial priority: Low (1), ordered by dataset index

@@ -243,7 +243,7 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
             next_start = last_run_idx + 1

             if next_start < self.rollout_n:
-                next_end = min(next_start + self.in_group_microbatch_size, self.rollout_n)
+                next_end = min(next_start + self.in_group_minibatch_size, self.rollout_n)
                 next_indices = list(range(next_start, next_end))
                 new_history = task.history + current_batch_history_updates

@@ -327,27 +327,26 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz
 async def execute_priority_rollouts(
     dataset: List[EvaluationRow],
     num_runs: int,
-    micro_batch_size: int,
     rollout_processor: RolloutProcessor,
     config: RolloutProcessorConfig,
     max_concurrent_rollouts: int,
     active_logger: DatasetLogger,
     eval_executor: TestFunction,
     max_concurrent_evaluations: int = 96,
     mode: str = "pointwise",
-    mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None,
+    micro_batch_data_buffer: Optional[MicroBatchDataBuffer] = None,
     evaluation_test_kwargs: Dict[str, Any] = {},
 ):
     scheduler = PriorityRolloutScheduler(
         rollout_processor=rollout_processor,
         max_concurrent_rollouts=max_concurrent_rollouts,
         active_logger=active_logger,
         eval_executor=eval_executor,
-        output_buffer=mini_batch_data_buffer,
+        output_buffer=micro_batch_data_buffer,
         max_concurrent_evaluations=max_concurrent_evaluations,
         rollout_n=num_runs,
         mode=mode,
-        in_group_microbatch_size=micro_batch_size,
+        in_group_minibatch_size=(num_runs // 2),
         evaluation_test_kwargs=evaluation_test_kwargs,
     )
     return await scheduler.run(dataset, num_runs, micro_batch_size, config)
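
Two behavioral changes land here besides the rename: execute_priority_rollouts no longer accepts micro_batch_size, and the in-group window is hardcoded to num_runs // 2. Note that the final return await scheduler.run(dataset, num_runs, micro_batch_size, config) line is left unchanged by the commit, so it still references the parameter that was just removed. A standalone sketch of the window arithmetic used by schedule_dataset and _run_eval; the helper name and example values are illustrative, the clamping logic is from the diff:

def in_group_windows(rollout_n: int, in_group_minibatch_size: int = 0) -> list[list[int]]:
    # Mirrors the scheduler: a non-positive size falls back to rollout_n,
    # i.e. all runs of a sample are scheduled in one window.
    size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n
    windows, next_start = [], 0
    while next_start < rollout_n:
        next_end = min(next_start + size, rollout_n)  # same clamp as batch_end / next_end
        windows.append(list(range(next_start, next_end)))
        next_start = next_end
    return windows

# With num_runs = 8 and the hardcoded num_runs // 2 = 4, a sample's runs
# split into two waves (odd num_runs leaves a short trailing window):
assert in_group_windows(8, 8 // 2) == [[0, 1, 2, 3], [4, 5, 6, 7]]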
