From d9ab3d483a28149f12b0bde38d2b3f14b552c301 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 27 Nov 2025 09:51:39 -0800 Subject: [PATCH 01/11] add --- eval_protocol/pytest/buffer.py | 76 ++++++ eval_protocol/pytest/evaluation_test.py | 1 + eval_protocol/pytest/priority_scheduler.py | 257 +++++++++++++++++++++ 3 files changed, 334 insertions(+) create mode 100644 eval_protocol/pytest/buffer.py create mode 100644 eval_protocol/pytest/priority_scheduler.py diff --git a/eval_protocol/pytest/buffer.py b/eval_protocol/pytest/buffer.py new file mode 100644 index 00000000..51771e62 --- /dev/null +++ b/eval_protocol/pytest/buffer.py @@ -0,0 +1,76 @@ +import asyncio +import os +from collections import defaultdict +from typing import List, Dict + +from eval_protocol.models import EvaluationRow + +class MiniBatchDataBuffer: + """ + Buffers evaluation results and writes them to disk in minibatches. + Waits for all runs of a sample to complete before considering it ready and flush to disk. + """ + def __init__(self, num_runs: int, minibatch_size: int, output_path_template: str): + self.num_runs = num_runs + self.minibatch_size = minibatch_size + self.output_path_template = output_path_template + self.pending_samples: Dict[str, List[EvaluationRow]] = defaultdict(list) # row_id -> list[EvaluationRow] + self.completed_samples_buffer: List[List[EvaluationRow]] = [] # List[List[EvaluationRow]] + self.batch_index = 0 + self.lock = asyncio.Lock() + + async def add_result(self, row: EvaluationRow): + """ + Add a single evaluation result. + Thread-safe/Coroutine-safe. 
+        """
+        async with self.lock:
+            row_id = row.input_metadata.row_id
+            if not row_id:
+                # Should not happen in valid EP workflow, unique row_id is required to group things together properly
+                return
+
+            self.pending_samples[row_id].append(row)
+
+            if len(self.pending_samples[row_id]) >= self.num_runs:
+                # Sample completed (all runs finished)
+                completed_rows = self.pending_samples.pop(row_id)
+                self.completed_samples_buffer.append(completed_rows)
+
+                if len(self.completed_samples_buffer) >= self.minibatch_size:
+                    await self._flush_unsafe()
+
+    async def _flush_unsafe(self):
+        """
+        not thread safe, assumes lock is held by caller
+        """
+        if not self.completed_samples_buffer:
+            return
+
+        if "{index}" in self.output_path_template:
+            output_path = self.output_path_template.format(index=self.batch_index)
+            mode = "w"
+        else:
+            output_path = self.output_path_template
+            mode = "a"  # Append if no index placeholder
+
+        # Ensure directory exists
+        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
+
+        # Write flattened rows
+        with open(output_path, mode) as f:
+            for sample_rows in self.completed_samples_buffer:
+                for row in sample_rows:
+                    f.write(row.model_dump_json() + "\n")
+
+        self.completed_samples_buffer = []
+        self.batch_index += 1
+
+    async def close(self):
+        """
+        Flush any remaining samples in the buffer.
+ """ + async with self.lock: + if self.completed_samples_buffer: + await self._flush_unsafe() + diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 67cc096e..db4a17a1 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import asyncio import inspect import os diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py new file mode 100644 index 00000000..19e5a827 --- /dev/null +++ b/eval_protocol/pytest/priority_scheduler.py @@ -0,0 +1,257 @@ +import asyncio +import os +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, List, Dict, Optional, Union, Awaitable + +from eval_protocol.models import EvaluationRow, Status +from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.pytest.rollout_processor import RolloutProcessor +from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry +from eval_protocol.pytest.buffer import MiniBatchDataBuffer +from eval_protocol.dataset_logger.dataset_logger import DatasetLogger +from eval_protocol.human_id import generate_id + +@dataclass(order=True) +class RolloutTask: + """ + Represents a single unit of work for the worker pool. 
+ Priority tuple structure: (status, row_index) + - status: 0 = High Priority (e.g., subsequent micro-batches of an already started sample) + 1 = Low Priority (e.g., starting a new sample) + - row_index: Used to maintain dataset order for initial scheduling + """ + priority: tuple[int, int] + + # Payload (excluded from comparison) + row: EvaluationRow = field(compare=False) + run_indices: List[int] = field(compare=False) # Which runs to execute in this task + config: RolloutProcessorConfig = field(compare=False) + row_index: int = field(compare=False) # To track which sample this belongs to + + # History for speculation (injected from previous micro-batches) + history: List[str] = field(compare=False, default_factory=list) + +class PriorityRolloutScheduler: + """ + Manages a priority queue of rollout tasks and a pool of workers. + Ensures that once a sample starts processing, its subsequent micro-batches + are prioritized to complete the sample as quickly as possible. + """ + def __init__( + self, + rollout_processor: RolloutProcessor, + max_concurrent_rollouts: int, + active_logger: DatasetLogger, + eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], # Callback to run evaluation + mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None, + ): + self.rollout_processor = rollout_processor + self.max_concurrent_rollouts = max_concurrent_rollouts + self.active_logger = active_logger + self.eval_executor = eval_executor + self.mini_batch_data_buffer = mini_batch_data_buffer + + # Priority Queue: Stores RolloutTask + self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue() + + self.num_runs = 0 + self.micro_batch_size = 0 + + async def schedule_dataset( + self, + dataset: List[EvaluationRow], + base_config: RolloutProcessorConfig, + ): + """ + Populates the queue with initial tasks (the first micro-batch for each sample). 
+ """ + for i, row in enumerate(dataset): + # Calculate ranges for the first micro-batch + batch_start = 0 + # Ensure micro_batch_size is at least 1 to avoid infinite loop or stuck tasks + safe_batch_size = self.micro_batch_size if self.micro_batch_size > 0 else self.num_runs + batch_end = min(safe_batch_size, self.num_runs) + run_indices = list(range(batch_start, batch_end)) + + # Initial priority: Low (1), ordered by dataset index + priority = (1, i) + + task = RolloutTask( + priority=priority, + row=row, + run_indices=run_indices, + config=base_config, + row_index=i, + history=[] # Initial batch has no history + ) + self.queue.put_nowait(task) + + async def worker(self): + """ + Worker loop: fetch task -> execute micro-batch -> schedule next batch (if any). + """ + while True: + try: + # Get a task from the priority queue + task: RolloutTask = await self.queue.get() + except asyncio.QueueEmpty: + break + + try: + await self._process_task(task) + except Exception as e: + print(f"Error processing task for row {task.row.input_metadata.row_id}: {e}") + finally: + self.queue.task_done() + + async def _process_task(self, task: RolloutTask): + """ + Executes a single micro-batch task. + """ + # 1. 
Prepare Config & Row for this micro-batch + current_batch_rows = [] + for run_idx in task.run_indices: + row_copy = task.row.model_copy(deep=True) + + row_copy.execution_metadata.run_id = generate_id() + row_copy.execution_metadata.rollout_id = generate_id() + + # Inject Speculation History + if task.history: + cp = row_copy.input_metadata.completion_params + # Ensure safe dict access + if not isinstance(cp, dict): + cp = {} + # Need to check and initialize nested dicts + extra_body = cp.get("extra_body") + if extra_body is None or not isinstance(extra_body, dict): + extra_body = {} + + extra_body["prediction"] = task.history + cp["extra_body"] = extra_body + row_copy.input_metadata.completion_params = cp + + current_batch_rows.append(row_copy) + self.active_logger.log(row_copy) + + # 2. Execute Rollout + batch_results: List[EvaluationRow] = [] + if task.run_indices: + representative_run_idx = task.run_indices[0] + + async for result_row in rollout_processor_with_retry( + self.rollout_processor, current_batch_rows, task.config, representative_run_idx + ): + batch_results.append(result_row) + + # 3. Evaluate and Collect History + current_batch_history_updates = [] + + for res in batch_results: + # Run Evaluation + eval_res = await self.eval_executor(res) + + # Depending on the execution mode, eval_executor might return a single row or a list + # For pointwise, it's a single row. For groupwise, it's a list. + # Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back + # But to be safe and type-correct, we handle both. 
+ + if isinstance(eval_res, list): + # Should not happen in pointwise mode which is typically used with this scheduler + # But if it does, we process each result + for r in eval_res: + if self.mini_batch_data_buffer: + await self.mini_batch_data_buffer.add_result(r) + + last_msg = r.last_assistant_message() + if last_msg and last_msg.content: + content = last_msg.content + if isinstance(content, list): + text_parts = [p["text"] for p in content if p["type"] == "text"] + current_batch_history_updates.append("".join(text_parts)) + else: + current_batch_history_updates.append(str(content)) + else: + current_batch_history_updates.append("") + else: + if self.mini_batch_data_buffer: + await self.mini_batch_data_buffer.add_result(eval_res) + + # Extract prediction for history + last_msg = eval_res.last_assistant_message() + if last_msg and last_msg.content: + content = last_msg.content + if isinstance(content, list): + text_parts = [p["text"] for p in content if p["type"] == "text"] + current_batch_history_updates.append("".join(text_parts)) + else: + current_batch_history_updates.append(str(content)) + else: + current_batch_history_updates.append("") # Empty string for failed turns + + # 4. 
Schedule Next Micro-batch (High Priority) + last_run_idx = task.run_indices[-1] + next_start = last_run_idx + 1 + + if next_start < self.num_runs: + next_end = min(next_start + self.micro_batch_size, self.num_runs) + next_indices = list(range(next_start, next_end)) + new_history = task.history + current_batch_history_updates + + # Priority 0 (High) to ensure we finish this sample ASAP + new_priority = (0, task.row_index) + + new_task = RolloutTask( + priority=new_priority, + row=task.row, + run_indices=next_indices, + config=task.config, + row_index=task.row_index, + history=new_history + ) + self.queue.put_nowait(new_task) + + async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_size: int, base_config: RolloutProcessorConfig): + self.num_runs = num_runs + self.micro_batch_size = micro_batch_size + + # 1. Schedule initial tasks + await self.schedule_dataset(dataset, base_config) + + # 2. Start Workers + workers = [asyncio.create_task(self.worker()) for _ in range(self.max_concurrent_rollouts)] + + # 3. Wait for completion + await self.queue.join() + + # 4. 
Cleanup + for w in workers: + w.cancel() + + # Ensure cancellation is complete + if workers: + await asyncio.gather(*workers, return_exceptions=True) + + # Return empty dict as we rely on side effects (streaming buffer) + return {} + +async def execute_priority_rollouts( + dataset: List[EvaluationRow], + num_runs: int, + micro_batch_size: int, + rollout_processor: RolloutProcessor, + config: RolloutProcessorConfig, + max_concurrent_rollouts: int, + active_logger: DatasetLogger, + eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], + mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None, +): + scheduler = PriorityRolloutScheduler( + rollout_processor=rollout_processor, + max_concurrent_rollouts=max_concurrent_rollouts, + active_logger=active_logger, + eval_executor=eval_executor, + mini_batch_data_buffer=mini_batch_data_buffer + ) + return await scheduler.run(dataset, num_runs, micro_batch_size, config) From 2865b790a9cffc27cbb47861c1cdcaec15a553fd Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Tue, 2 Dec 2025 16:57:00 -0800 Subject: [PATCH 02/11] add priority rolluot scheduler --- eval_protocol/pytest/priority_scheduler.py | 95 ++++--- tests/test_priority_scheduler.py | 282 +++++++++++++++++++++ 2 files changed, 340 insertions(+), 37 deletions(-) create mode 100644 tests/test_priority_scheduler.py diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index 19e5a827..ff9706f7 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -45,9 +45,11 @@ def __init__( active_logger: DatasetLogger, eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], # Callback to run evaluation mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None, + max_concurrent_evaluations: Optional[int] = None, ): self.rollout_processor = 
rollout_processor self.max_concurrent_rollouts = max_concurrent_rollouts + self.max_concurrent_evaluations = max_concurrent_evaluations self.active_logger = active_logger self.eval_executor = eval_executor self.mini_batch_data_buffer = mini_batch_data_buffer @@ -55,6 +57,10 @@ def __init__( # Priority Queue: Stores RolloutTask self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue() + # Concurrency Control + self.rollout_sem = asyncio.Semaphore(max_concurrent_rollouts) + self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) if max_concurrent_evaluations else None + self.num_runs = 0 self.micro_batch_size = 0 @@ -140,31 +146,48 @@ async def _process_task(self, task: RolloutTask): if task.run_indices: representative_run_idx = task.run_indices[0] - async for result_row in rollout_processor_with_retry( - self.rollout_processor, current_batch_rows, task.config, representative_run_idx - ): - batch_results.append(result_row) + async with self.rollout_sem: + async for result_row in rollout_processor_with_retry( + self.rollout_processor, current_batch_rows, task.config, representative_run_idx + ): + batch_results.append(result_row) # 3. Evaluate and Collect History current_batch_history_updates = [] - for res in batch_results: - # Run Evaluation - eval_res = await self.eval_executor(res) - - # Depending on the execution mode, eval_executor might return a single row or a list - # For pointwise, it's a single row. For groupwise, it's a list. - # Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back - # But to be safe and type-correct, we handle both. 
- - if isinstance(eval_res, list): - # Should not happen in pointwise mode which is typically used with this scheduler - # But if it does, we process each result - for r in eval_res: + async def _run_eval(): + for res in batch_results: + # Run Evaluation + eval_res = await self.eval_executor(res) + + # Depending on the execution mode, eval_executor might return a single row or a list + # For pointwise, it's a single row. For groupwise, it's a list. + # Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back + # But to be safe and type-correct, we handle both. + + if isinstance(eval_res, list): + # Should not happen in pointwise mode which is typically used with this scheduler + # But if it does, we process each result + for r in eval_res: + if self.mini_batch_data_buffer: + await self.mini_batch_data_buffer.add_result(r) + + last_msg = r.last_assistant_message() + if last_msg and last_msg.content: + content = last_msg.content + if isinstance(content, list): + text_parts = [p["text"] for p in content if p["type"] == "text"] + current_batch_history_updates.append("".join(text_parts)) + else: + current_batch_history_updates.append(str(content)) + else: + current_batch_history_updates.append("") + else: if self.mini_batch_data_buffer: - await self.mini_batch_data_buffer.add_result(r) - - last_msg = r.last_assistant_message() + await self.mini_batch_data_buffer.add_result(eval_res) + + # Extract prediction for history + last_msg = eval_res.last_assistant_message() if last_msg and last_msg.content: content = last_msg.content if isinstance(content, list): @@ -173,22 +196,13 @@ async def _process_task(self, task: RolloutTask): else: current_batch_history_updates.append(str(content)) else: - current_batch_history_updates.append("") - else: - if self.mini_batch_data_buffer: - await self.mini_batch_data_buffer.add_result(eval_res) + current_batch_history_updates.append("") # Empty string for failed turns - # Extract prediction for history 
- last_msg = eval_res.last_assistant_message() - if last_msg and last_msg.content: - content = last_msg.content - if isinstance(content, list): - text_parts = [p["text"] for p in content if p["type"] == "text"] - current_batch_history_updates.append("".join(text_parts)) - else: - current_batch_history_updates.append(str(content)) - else: - current_batch_history_updates.append("") # Empty string for failed turns + if self.eval_sem: + async with self.eval_sem: + await _run_eval() + else: + await _run_eval() # 4. Schedule Next Micro-batch (High Priority) last_run_idx = task.run_indices[-1] @@ -220,7 +234,12 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz await self.schedule_dataset(dataset, base_config) # 2. Start Workers - workers = [asyncio.create_task(self.worker()) for _ in range(self.max_concurrent_rollouts)] + # If we have separate limits, we need enough workers to saturate both stages + num_workers = self.max_concurrent_rollouts + if self.max_concurrent_evaluations: + num_workers += self.max_concurrent_evaluations + + workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)] # 3. 
Wait for completion await self.queue.join() @@ -246,12 +265,14 @@ async def execute_priority_rollouts( active_logger: DatasetLogger, eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None, + max_concurrent_evaluations: Optional[int] = None, ): scheduler = PriorityRolloutScheduler( rollout_processor=rollout_processor, max_concurrent_rollouts=max_concurrent_rollouts, active_logger=active_logger, eval_executor=eval_executor, - mini_batch_data_buffer=mini_batch_data_buffer + mini_batch_data_buffer=mini_batch_data_buffer, + max_concurrent_evaluations=max_concurrent_evaluations ) return await scheduler.run(dataset, num_runs, micro_batch_size, config) diff --git a/tests/test_priority_scheduler.py b/tests/test_priority_scheduler.py new file mode 100644 index 00000000..d4def976 --- /dev/null +++ b/tests/test_priority_scheduler.py @@ -0,0 +1,282 @@ +import pytest +import asyncio +import time +from unittest.mock import MagicMock, AsyncMock +from typing import List, Union + +from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata +from eval_protocol.pytest.priority_scheduler import PriorityRolloutScheduler, execute_priority_rollouts, RolloutTask +from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.dataset_logger.dataset_logger import DatasetLogger + +# Mock models +def create_mock_row(row_id: str = "test-row") -> EvaluationRow: + return EvaluationRow( + input_metadata=InputMetadata( + row_id=row_id, + completion_params={"model": "test-model"} + ), + execution_metadata=ExecutionMetadata() + ) + +@pytest.fixture +def mock_rollout_processor(): + processor = MagicMock() + # Mocking the rollout to be an async generator + async def mock_rollout_gen(rows, config, run_idx): + for row in rows: + # Simulate some work + yield row + processor.side_effect = mock_rollout_gen + return processor + 
+@pytest.fixture +def mock_logger(): + return MagicMock(spec=DatasetLogger) + +@pytest.fixture +def mock_eval_executor(): + return AsyncMock() + +@pytest.fixture +def base_config(): + return RolloutProcessorConfig( + completion_params={"model": "test-model"}, + mcp_config_path="test_config.yaml", + semaphore=asyncio.Semaphore(10), + steps=10 + ) + +@pytest.mark.asyncio +async def test_scheduler_basic_execution( + mock_logger, mock_eval_executor, base_config +): + """Test that the scheduler processes all rows and completes.""" + dataset = [create_mock_row(f"row-{i}") for i in range(5)] + num_runs = 2 + micro_batch_size = 1 + + # Mock rollout processor with delay + async def delayed_rollout(rows, config, run_idx): + await asyncio.sleep(0.01) + for row in rows: + yield row + + mock_processor = MagicMock() + mock_processor.side_effect = delayed_rollout # This is wrong usage for call, rollout_processor is passed as instance + # But wait, PriorityRolloutScheduler calls rollout_processor_with_retry which calls processor.process_batch or similar? + # Looking at code: rollout_processor_with_retry(self.rollout_processor, ...) + # rollout_processor_with_retry expects the processor instance. + + # Let's look at how rollout_processor_with_retry is implemented or usage. + # Assuming rollout_processor is an object with a method or it's a callable? + # In priority_scheduler.py: rollout_processor_with_retry(self.rollout_processor, ...) + + # Let's actually mock rollout_processor_with_retry since we want to test the scheduler logic, + # not the processor retry logic. + # But we can't easily mock the import inside the module without patching. + pass + +# We will rely on patching 'eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry' +from unittest.mock import patch + +@pytest.mark.asyncio +async def test_concurrency_control( + mock_logger, mock_eval_executor, base_config +): + """ + Verify that max_concurrent_rollouts and max_concurrent_evaluations are respected. 
+ """ + dataset = [create_mock_row(f"row-{i}") for i in range(10)] + num_runs = 1 + micro_batch_size = 1 + + max_rollouts = 4 + max_evals = 2 + + active_rollouts = 0 + max_active_rollouts_seen = 0 + + active_evals = 0 + max_active_evals_seen = 0 + + rollout_lock = asyncio.Lock() + eval_lock = asyncio.Lock() + + async def mock_rollout_gen(processor, rows, config, run_idx): + nonlocal active_rollouts, max_active_rollouts_seen + async with rollout_lock: + active_rollouts += 1 + max_active_rollouts_seen = max(max_active_rollouts_seen, active_rollouts) + + # Simulate slow rollout + await asyncio.sleep(0.05) + + for row in rows: + yield row + + async with rollout_lock: + active_rollouts -= 1 + + async def mock_eval(row): + nonlocal active_evals, max_active_evals_seen + async with eval_lock: + active_evals += 1 + max_active_evals_seen = max(max_active_evals_seen, active_evals) + + # Simulate evaluation + await asyncio.sleep(0.05) + + async with eval_lock: + active_evals -= 1 + return row + + with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): + mock_eval_executor.side_effect = mock_eval + + # Mock processor instance (can be anything since we patched the wrapper) + processor_instance = MagicMock() + + scheduler = PriorityRolloutScheduler( + rollout_processor=processor_instance, + max_concurrent_rollouts=max_rollouts, + active_logger=mock_logger, + eval_executor=mock_eval_executor, + max_concurrent_evaluations=max_evals + ) + + await scheduler.run(dataset, num_runs, micro_batch_size, base_config) + + # Verify limits were respected + assert max_active_rollouts_seen <= max_rollouts, f"Rollout concurrency exceeded: {max_active_rollouts_seen} > {max_rollouts}" + assert max_active_evals_seen <= max_evals, f"Eval concurrency exceeded: {max_active_evals_seen} > {max_evals}" + + # Verify everything ran + # 10 rows * 1 run = 10 rollouts called + # 10 evaluations + assert mock_eval_executor.call_count == 10 + 
+@pytest.mark.asyncio +async def test_priority_scheduling( + mock_logger, mock_eval_executor, base_config +): + """ + Test that subsequent micro-batches are prioritized. + This is tricky to test deterministically with asyncio, but we can try to observe order + or ensure that a task that spawns new parts gets priority. + + We'll simulate a case where we have 2 samples, each needing 2 micro-batches. + We want to see if Sample 1 Batch 2 runs before Sample 2 Batch 1 is finished if possible, + but actually the scheduler puts Sample 1 Batch 2 with Priority 0 (High) and Sample 2 Batch 1 starts with Priority 1 (Low). + + If we limit concurrency to 1, we should see: + S1_B1 -> S1_B2 -> S2_B1 -> S2_B2 + + Wait, if concurrency is 1: + 1. Queue: [S1_B1 (Low), S2_B1 (Low)] + 2. Worker picks S1_B1. Queue: [S2_B1 (Low)] + 3. S1_B1 finishes. Puts S1_B2 (High). Queue: [S1_B2 (High), S2_B1 (Low)] + 4. Worker picks S1_B2. Queue: [S2_B1 (Low)] + 5. S1_B2 finishes. Queue: [S2_B1 (Low)] + 6. Worker picks S2_B1. ... + + So yes, strictly sequential per sample if concurrency=1. 
+ """ + dataset = [create_mock_row(f"row-{i}") for i in range(2)] + num_runs = 2 + micro_batch_size = 1 + + execution_order = [] + + async def mock_rollout_gen(processor, rows, config, run_idx): + row_id = rows[0].input_metadata.row_id + execution_order.append(f"{row_id}_run_{run_idx}") + for row in rows: + yield row + + async def mock_eval(row): + return row + + with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): + mock_eval_executor.side_effect = mock_eval + processor_instance = MagicMock() + + scheduler = PriorityRolloutScheduler( + rollout_processor=processor_instance, + max_concurrent_rollouts=1, # Force serial execution to test priority + active_logger=mock_logger, + eval_executor=mock_eval_executor, + ) + + await scheduler.run(dataset, num_runs, micro_batch_size, base_config) + + # Expected order: row-0_run_0, row-0_run_1, row-1_run_0, row-1_run_1 + # Or at least row-0_run_1 should come before row-1_run_0 finishes if parallel? + # With concurrency 1, it should be strictly: + # row-0 run 0 + # row-0 run 1 (high priority injected) + # row-1 run 0 + # row-1 run 1 + + expected = [ + "row-0_run_0", + "row-0_run_1", + "row-1_run_0", + "row-1_run_1" + ] + + assert execution_order == expected + +@pytest.mark.asyncio +async def test_worker_scaling( + mock_logger, mock_eval_executor, base_config +): + """ + Test that the number of workers scales with the sum of limits. 
+ """ + dataset = [create_mock_row("row-0")] + max_rollouts = 5 + max_evals = 3 + expected_workers = max_rollouts + max_evals + + worker_start_count = 0 + + class InstrumentedScheduler(PriorityRolloutScheduler): + async def worker(self): + nonlocal worker_start_count + worker_start_count += 1 + try: + await self.queue.get() + self.queue.task_done() + except asyncio.CancelledError: + pass + except Exception: + pass + + async def schedule_dataset(self, *args): + # Put enough items to ensure all workers wake up and grab one + for i in range(expected_workers): + task = RolloutTask( + priority=(1, i), + row=dataset[0], + run_indices=[], + config=base_config, + row_index=0, + history=[] + ) + await self.queue.put(task) + + processor_instance = MagicMock() + scheduler = InstrumentedScheduler( + rollout_processor=processor_instance, + max_concurrent_rollouts=max_rollouts, + active_logger=mock_logger, + eval_executor=mock_eval_executor, + max_concurrent_evaluations=max_evals + ) + + await scheduler.run(dataset, 1, 1, base_config) + + assert worker_start_count == expected_workers + + From 37e0210f54edddbf9d687d55e5fa4d6365241fdb Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 4 Dec 2025 11:34:23 -0800 Subject: [PATCH 03/11] groupwise --- eval_protocol/pytest/priority_scheduler.py | 133 ++++++++++++++------- tests/test_priority_scheduler.py | 50 ++++++++ 2 files changed, 142 insertions(+), 41 deletions(-) diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index ff9706f7..aa0305ae 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -44,15 +44,17 @@ def __init__( max_concurrent_rollouts: int, active_logger: DatasetLogger, eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], # Callback to run evaluation - mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None, + output_buffer: 
Optional[MiniBatchDataBuffer] = None, max_concurrent_evaluations: Optional[int] = None, + mode: str = "pointwise", ): self.rollout_processor = rollout_processor self.max_concurrent_rollouts = max_concurrent_rollouts self.max_concurrent_evaluations = max_concurrent_evaluations self.active_logger = active_logger self.eval_executor = eval_executor - self.mini_batch_data_buffer = mini_batch_data_buffer + self.output_buffer = output_buffer + self.mode = mode # Priority Queue: Stores RolloutTask self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue() @@ -61,6 +63,10 @@ def __init__( self.rollout_sem = asyncio.Semaphore(max_concurrent_rollouts) self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) if max_concurrent_evaluations else None + # Results storage + self.results: List[EvaluationRow] = [] # for backward compatibility reason, we save all results here to return + self.groups_buffer: Dict[int, List[EvaluationRow]] = defaultdict(list) # buffer for group results. only flush to output buffer when a whole group is ready + self.num_runs = 0 self.micro_batch_size = 0 @@ -155,24 +161,85 @@ async def _process_task(self, task: RolloutTask): # 3. Evaluate and Collect History current_batch_history_updates = [] - async def _run_eval(): - for res in batch_results: - # Run Evaluation - eval_res = await self.eval_executor(res) - - # Depending on the execution mode, eval_executor might return a single row or a list - # For pointwise, it's a single row. For groupwise, it's a list. - # Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back - # But to be safe and type-correct, we handle both. 
+ if self.mode == "groupwise": + # Collect all results from this batch + for res in batch_results: + self.groupwise_buffer[task.row_index].append(res) - if isinstance(eval_res, list): - # Should not happen in pointwise mode which is typically used with this scheduler - # But if it does, we process each result - for r in eval_res: + # Update history from rollout result (assuming eval doesn't change content needed for history) + last_msg = res.last_assistant_message() + if last_msg and last_msg.content: + content = last_msg.content + if isinstance(content, list): + text_parts = [p["text"] for p in content if p["type"] == "text"] + current_batch_history_updates.append("".join(text_parts)) + else: + current_batch_history_updates.append(str(content)) + else: + current_batch_history_updates.append("") + + # Check if this is the last batch for this sample + last_run_idx = task.run_indices[-1] + if last_run_idx + 1 >= self.num_runs: + # Last batch: Execute Groupwise Evaluation + full_group = self.groupwise_buffer[task.row_index] + + async def _run_group_eval(): + eval_res = await self.eval_executor(full_group) + # Handle result (could be list or single row wrapping list?) 
+ # Usually groupwise returns list of scored rows + if isinstance(eval_res, list): + self.results.extend(eval_res) + if self.mini_batch_data_buffer: + # Push the whole group at once if possible, or iterate + for r in eval_res: + await self.mini_batch_data_buffer.add_result(r) + else: + self.results.append(eval_res) + if self.mini_batch_data_buffer: + await self.mini_batch_data_buffer.add_result(eval_res) + + if self.eval_sem: + async with self.eval_sem: + await _run_group_eval() + else: + await _run_group_eval() + + # Clear buffer to free memory + del self.groupwise_buffer[task.row_index] + + else: + # Pointwise: Process each result individually + async def _run_eval(): + for res in batch_results: + # Run Evaluation + eval_res = await self.eval_executor(res) + + if isinstance(eval_res, list): + # Should not happen in pointwise mode which is typically used with this scheduler + # But if it does, we process each result + self.results.extend(eval_res) + for r in eval_res: + if self.mini_batch_data_buffer: + await self.mini_batch_data_buffer.add_result(r) + + last_msg = r.last_assistant_message() + if last_msg and last_msg.content: + content = last_msg.content + if isinstance(content, list): + text_parts = [p["text"] for p in content if p["type"] == "text"] + current_batch_history_updates.append("".join(text_parts)) + else: + current_batch_history_updates.append(str(content)) + else: + current_batch_history_updates.append("") + else: + self.results.append(eval_res) if self.mini_batch_data_buffer: - await self.mini_batch_data_buffer.add_result(r) - - last_msg = r.last_assistant_message() + await self.mini_batch_data_buffer.add_result(eval_res) + + # Extract prediction for history + last_msg = eval_res.last_assistant_message() if last_msg and last_msg.content: content = last_msg.content if isinstance(content, list): @@ -181,28 +248,13 @@ async def _run_eval(): else: current_batch_history_updates.append(str(content)) else: - current_batch_history_updates.append("") - 
else: - if self.mini_batch_data_buffer: - await self.mini_batch_data_buffer.add_result(eval_res) + current_batch_history_updates.append("") # Empty string for failed turns - # Extract prediction for history - last_msg = eval_res.last_assistant_message() - if last_msg and last_msg.content: - content = last_msg.content - if isinstance(content, list): - text_parts = [p["text"] for p in content if p["type"] == "text"] - current_batch_history_updates.append("".join(text_parts)) - else: - current_batch_history_updates.append(str(content)) - else: - current_batch_history_updates.append("") # Empty string for failed turns - - if self.eval_sem: - async with self.eval_sem: + if self.eval_sem: + async with self.eval_sem: + await _run_eval() + else: await _run_eval() - else: - await _run_eval() # 4. Schedule Next Micro-batch (High Priority) last_run_idx = task.run_indices[-1] @@ -248,12 +300,11 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz for w in workers: w.cancel() - # Ensure cancellation is complete if workers: await asyncio.gather(*workers, return_exceptions=True) - # Return empty dict as we rely on side effects (streaming buffer) - return {} + # Return collected results + return self.results async def execute_priority_rollouts( dataset: List[EvaluationRow], diff --git a/tests/test_priority_scheduler.py b/tests/test_priority_scheduler.py index d4def976..b5778a68 100644 --- a/tests/test_priority_scheduler.py +++ b/tests/test_priority_scheduler.py @@ -279,4 +279,54 @@ async def schedule_dataset(self, *args): assert worker_start_count == expected_workers +@pytest.mark.asyncio +async def test_groupwise_mode( + mock_logger, mock_eval_executor, base_config +): + """ + Test that groupwise mode collects all runs before evaluating. + """ + dataset = [create_mock_row("row-0")] + num_runs = 4 + micro_batch_size = 2 + + # We expect 2 batches of 2 runs each. + # Batch 1 (Runs 0,1): Should buffer and update history, NOT call eval. 
+ # Batch 2 (Runs 2,3): Should buffer, update history, AND call eval with all 4 runs. + + eval_calls = [] + + async def mock_eval(rows): + eval_calls.append(rows) + return rows # Pass through + + async def mock_rollout_gen(processor, rows, config, run_idx): + for row in rows: + yield row + + mock_eval_executor.side_effect = mock_eval + + with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): + processor_instance = MagicMock() + + scheduler = PriorityRolloutScheduler( + rollout_processor=processor_instance, + max_concurrent_rollouts=1, + active_logger=mock_logger, + eval_executor=mock_eval_executor, + mode="groupwise" + ) + + results = await scheduler.run(dataset, num_runs, micro_batch_size, base_config) + + # Verify evaluation was called EXACTLY ONCE + assert len(eval_calls) == 1, f"Expected 1 eval call, got {len(eval_calls)}" + + # Verify it was called with ALL 4 rows + evaluated_rows = eval_calls[0] + assert len(evaluated_rows) == 4, f"Expected 4 rows in group eval, got {len(evaluated_rows)}" + + # Verify results contains all 4 rows + assert len(results) == 4 + From ff329d8113c31ffcc3f8a707cf4a0edaecdf8c41 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 4 Dec 2025 23:20:51 -0800 Subject: [PATCH 04/11] add --- eval_protocol/pytest/evaluation_test.py | 496 +++++++++++---------- eval_protocol/pytest/priority_scheduler.py | 274 ++++++------ eval_protocol/pytest/validate_signature.py | 2 - tests/pytest/test_rollout_scheduler.py | 50 +++ tests/test_priority_scheduler.py | 128 +++--- 5 files changed, 527 insertions(+), 423 deletions(-) create mode 100644 tests/pytest/test_rollout_scheduler.py diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index db4a17a1..5dc5ca5e 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -26,6 +26,7 @@ from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper from 
eval_protocol.pytest.evaluation_test_postprocess import postprocess from eval_protocol.pytest.execution import execute_pytest, execute_pytest_with_exception_handling +from eval_protocol.pytest.priority_scheduler import execute_priority_rollouts from eval_protocol.pytest.generate_parameter_combinations import ( ParameterizedTestKwargs, generate_parameter_combinations, @@ -69,7 +70,7 @@ from eval_protocol.log_utils.init import init_external_logging_from_env from eval_protocol.log_utils.rollout_context import rollout_logging_context from eval_protocol.utils.browser_utils import is_logs_server_running, open_browser_tab - +from eval_protocol.pytest.buffer import MiniBatchDataBuffer from ..common_utils import load_jsonl @@ -402,59 +403,192 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo rollout_processor.setup() - async def execute_run(run_idx: int, config: RolloutProcessorConfig): - nonlocal all_results - - # Regenerate outputs each run by deep-copying the pristine dataset - # so model responses are not reused across runs. 
- run_id = generate_id() - fresh_dataset = [r.model_copy(deep=True) for r in data] - - # apply new run_id to fresh_dataset - for row in fresh_dataset: - row.execution_metadata.run_id = run_id - - # generate new rollout_id for each row - for row in fresh_dataset: - row.execution_metadata.rollout_id = generate_id() + use_priority_scheduler = ( + ( + os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1" + and not isinstance(rollout_processor, MCPGymRolloutProcessor) + ) + ) - # log the fresh_dataset - for row in fresh_dataset: - active_logger.log(row) - processed_rows_in_run.append(row) + if use_priority_scheduler: + print("Using priority scheduler") + minibatch_output_size = os.environ.get("EP_MINI_BATCH_OUTPUT_SIZE", None) + output_dir = os.environ.get("EP_OUTPUT_DIR", None) + if minibatch_output_size and output_dir: + output_buffer = MiniBatchDataBuffer(num_runs=num_runs, minibatch_size=int(minibatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl")) + else: + output_buffer = None + priority_results = await execute_priority_rollouts( + dataset=data, + num_runs=num_runs, + micro_batch_size=int(os.environ.get("EP_MICRO_BATCH_SIZE", "1")), + rollout_processor=rollout_processor, + config=config, + max_concurrent_rollouts=max_concurrent_rollouts, + active_logger=active_logger, + eval_executor=test_func, + max_concurrent_evaluations=max_concurrent_evaluations, + mode=mode, + mini_batch_data_buffer=output_buffer, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + ) + + for res in priority_results: + run_idx = (res.execution_metadata.extra or {}).get("run_index", 0) + if run_idx < len(all_results): + all_results[run_idx].append(res) + + processed_rows_in_run.append(res) - # prepare parallel eval helper function - semaphore = asyncio.Semaphore(max_concurrent_evaluations) + else: + async def execute_run(run_idx: int, config: RolloutProcessorConfig): + nonlocal all_results + + # Regenerate outputs each run by 
deep-copying the pristine dataset + # so model responses are not reused across runs. + run_id = generate_id() + fresh_dataset = [r.model_copy(deep=True) for r in data] + + # apply new run_id to fresh_dataset + for row in fresh_dataset: + row.execution_metadata.run_id = run_id + + # generate new rollout_id for each row + for row in fresh_dataset: + row.execution_metadata.rollout_id = generate_id() + + # log the fresh_dataset + for row in fresh_dataset: + active_logger.log(row) + processed_rows_in_run.append(row) + + # prepare parallel eval helper function + semaphore = asyncio.Semaphore(max_concurrent_evaluations) + + async def _execute_pointwise_eval_with_semaphore( + row: EvaluationRow, + ) -> EvaluationRow: + async with semaphore: + evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} + async with rollout_logging_context( + row.execution_metadata.rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + ): + result = await execute_pytest_with_exception_handling( + test_func=test_func, + evaluation_test_kwargs=evaluation_test_kwargs, + processed_row=row, + ) + if not isinstance(result, EvaluationRow): + raise ValueError( + f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." 
+ ) + return result - async def _execute_pointwise_eval_with_semaphore( - row: EvaluationRow, - ) -> EvaluationRow: - async with semaphore: - evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} - async with rollout_logging_context( - row.execution_metadata.rollout_id or "", - experiment_id=experiment_id, - run_id=run_id, + async def _execute_groupwise_eval_with_semaphore( + rows: list[EvaluationRow], + ) -> list[EvaluationRow]: + async with semaphore: + evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} + primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None + group_rollout_ids = [ + r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id + ] + async with rollout_logging_context( + primary_rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + rollout_ids=group_rollout_ids or None, + ): + results = await execute_pytest_with_exception_handling( + test_func=test_func, + evaluation_test_kwargs=evaluation_test_kwargs, + processed_dataset=rows, + ) + if not isinstance(results, list): + raise ValueError( + f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." 
+ ) + return results + + if mode == "pointwise": + # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution + pointwise_tasks: list[asyncio.Task[EvaluationRow]] = [] + # Use wrapper that handles retry logic internally + async for row in rollout_processor_with_retry( + rollout_processor, fresh_dataset, config, run_idx ): - result = await execute_pytest_with_exception_handling( - test_func=test_func, - evaluation_test_kwargs=evaluation_test_kwargs, - processed_row=row, + pointwise_tasks.append( + asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row)) ) - if not isinstance(result, EvaluationRow): - raise ValueError( - f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." + + # Run evaluation tasks with progress bar + results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx) + + all_results[run_idx] = results + elif mode == "groupwise": + # rollout all the completion_params for the same row at once, and then send the output to the test_func + row_groups = defaultdict(list) # key: row_id, value: list of rollout_result + tasks: list[asyncio.Task[list[EvaluationRow]]] = [] + # completion_groups = [] + for idx, cp in enumerate(original_completion_params): + config = RolloutProcessorConfig( + completion_params=cp if cp is not None else {}, + mcp_config_path=mcp_config_path or "", + server_script_path=server_script_path, + steps=steps, + logger=active_logger, + semaphore=shared_semaphore, + kwargs=rollout_processor_kwargs or {}, ) - return result - - async def _execute_groupwise_eval_with_semaphore( - rows: list[EvaluationRow], - ) -> list[EvaluationRow]: - async with semaphore: - evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} - primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None + lst = [] + + async def _collect_result(config, lst): + 
result = [] + async for row in rollout_processor_with_retry( + rollout_processor, lst, config, run_idx + ): # pyright: ignore[reportUnknownArgumentType] + result.append(row) + return result + + for ori_row in fresh_dataset: + copied_row = ori_row.model_copy(deep=True) + # overwrite the rollout_id to the index of the completion_params + copied_row.execution_metadata.rollout_id = ( + str(ori_row.execution_metadata.rollout_id) + "_" + str(idx) + ) + copied_row.input_metadata.completion_params = cp if cp is not None else {} + lst.append(copied_row) + tasks.append(asyncio.create_task(_collect_result(config, lst))) + rollout_results = await asyncio.gather(*tasks) + for result in rollout_results: + for row in result: + row_groups[row.input_metadata.row_id].append(row) + tasks = [] + for _, rows in row_groups.items(): + tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) + results = [] + for task in tasks: + res = await task + results.extend(res) + all_results[run_idx] = results + else: + # Batch mode: collect all results first, then evaluate (no pipelining) + input_dataset = [] + async for row in rollout_processor_with_retry( + rollout_processor, fresh_dataset, config, run_idx + ): + input_dataset.append(row) + # NOTE: we will still evaluate errored rows (give users control over this) + # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func + primary_rollout_id = ( + input_dataset[0].execution_metadata.rollout_id if input_dataset else None + ) group_rollout_ids = [ - r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id + r.execution_metadata.rollout_id + for r in input_dataset + if r.execution_metadata.rollout_id ] async with rollout_logging_context( primary_rollout_id or "", @@ -464,205 +598,113 @@ async def _execute_groupwise_eval_with_semaphore( ): results = await execute_pytest_with_exception_handling( test_func=test_func, - evaluation_test_kwargs=evaluation_test_kwargs, 
- processed_dataset=rows, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + processed_dataset=input_dataset, ) - if not isinstance(results, list): + if ( + results is None + or not isinstance(results, list) + or not all(isinstance(r, EvaluationRow) for r in results) + ): raise ValueError( f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." ) - return results - - if mode == "pointwise": - # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution - pointwise_tasks: list[asyncio.Task[EvaluationRow]] = [] - # Use wrapper that handles retry logic internally - async for row in rollout_processor_with_retry( - rollout_processor, fresh_dataset, config, run_idx - ): - pointwise_tasks.append( - asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row)) - ) - - # Run evaluation tasks with progress bar - results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx) - - all_results[run_idx] = results - elif mode == "groupwise": - # rollout all the completion_params for the same row at once, and then send the output to the test_func - row_groups = defaultdict(list) # key: row_id, value: list of rollout_result - tasks: list[asyncio.Task[list[EvaluationRow]]] = [] - # completion_groups = [] - for idx, cp in enumerate(original_completion_params): - config = RolloutProcessorConfig( - completion_params=cp if cp is not None else {}, - mcp_config_path=mcp_config_path or "", - server_script_path=server_script_path, - steps=steps, - logger=active_logger, - semaphore=shared_semaphore, - kwargs=rollout_processor_kwargs or {}, - ) - lst = [] - - async def _collect_result(config, lst): - result = [] - async for row in rollout_processor_with_retry( - rollout_processor, lst, config, run_idx - ): # pyright: ignore[reportUnknownArgumentType] - result.append(row) - return 
result - - for ori_row in fresh_dataset: - copied_row = ori_row.model_copy(deep=True) - # overwrite the rollout_id to the index of the completion_params - copied_row.execution_metadata.rollout_id = ( - str(ori_row.execution_metadata.rollout_id) + "_" + str(idx) + if not results: + raise ValueError( + f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test." ) - copied_row.input_metadata.completion_params = cp if cp is not None else {} - lst.append(copied_row) - tasks.append(asyncio.create_task(_collect_result(config, lst))) - rollout_results = await asyncio.gather(*tasks) - for result in rollout_results: - for row in result: - row_groups[row.input_metadata.row_id].append(row) - tasks = [] - for _, rows in row_groups.items(): - tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) - results = [] - for task in tasks: - res = await task - results.extend(res) - all_results[run_idx] = results + all_results[run_idx] = results + + for r in results: + add_cost_metrics(r) + if r.eval_metadata is not None: + if r.rollout_status.is_error(): + r.eval_metadata.status = Status.error( + r.rollout_status.message, r.rollout_status.details + ) + elif not ( + r.eval_metadata.status and r.eval_metadata.status.code != Status.Code.RUNNING + ): + # if the eval_metadata status code has not been set to something else, consider it as finished + r.eval_metadata.status = Status.eval_finished() + # Optional debug print for assistant/tool sequence + if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1": + try: + preview = [ + { + "role": m.role, + "len": len(m.content or "") if isinstance(m.content, str) else None, + "tool_calls": len(m.tool_calls or []) + if hasattr(m, "tool_calls") and isinstance(m.tool_calls, list) + else 0, + "tool_call_id": getattr(m, "tool_call_id", None), + "name": getattr(m, "name", None), + } + for m in r.messages 
+ ] + print("[EP-Log] Row messages:", preview) + except Exception: + pass + active_logger.log(r) + + # if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs + # else, we execute runs in parallel + if isinstance(rollout_processor, MCPGymRolloutProcessor): + # For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts + for run_idx in range(num_runs): + task = asyncio.create_task(execute_run(run_idx, config)) + await task else: - # Batch mode: collect all results first, then evaluate (no pipelining) - input_dataset = [] - async for row in rollout_processor_with_retry( - rollout_processor, fresh_dataset, config, run_idx - ): - input_dataset.append(row) - # NOTE: we will still evaluate errored rows (give users control over this) - # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func - primary_rollout_id = ( - input_dataset[0].execution_metadata.rollout_id if input_dataset else None - ) - group_rollout_ids = [ - r.execution_metadata.rollout_id - for r in input_dataset - if r.execution_metadata.rollout_id + # For other processors, create all tasks at once and run in parallel + # Concurrency is now controlled by the shared semaphore in each rollout processor + await run_tasks_with_run_progress(execute_run, num_runs, config) + + experiment_duration_seconds = time.perf_counter() - experiment_start_time + + # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them + # rollout_id is used to differentiate the result from different completion_params + if mode == "groupwise": + results_by_group = [ + [[] for _ in range(num_runs)] for _ in range(len(original_completion_params)) ] - async with rollout_logging_context( - primary_rollout_id or "", - experiment_id=experiment_id, - run_id=run_id, - rollout_ids=group_rollout_ids or None, - ): - results = await 
execute_pytest_with_exception_handling( - test_func=test_func, - evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, - processed_dataset=input_dataset, - ) - if ( - results is None - or not isinstance(results, list) - or not all(isinstance(r, EvaluationRow) for r in results) - ): - raise ValueError( - f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." + for i_run, result in enumerate(all_results): + for r in result: + completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1]) # pyright: ignore[reportOptionalMemberAccess] + results_by_group[completion_param_idx][i_run].append(r) + for rollout_id, result in enumerate(results_by_group): + postprocess( + result, + aggregation_method, + passed_threshold, + active_logger, + mode, + original_completion_params[rollout_id], # pyright: ignore[reportArgumentType] + test_func.__name__, + num_runs, + experiment_duration_seconds, ) - if not results: - raise ValueError( - f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test." 
- ) - all_results[run_idx] = results - - for r in results: - add_cost_metrics(r) - if r.eval_metadata is not None: - if r.rollout_status.is_error(): - r.eval_metadata.status = Status.error( - r.rollout_status.message, r.rollout_status.details - ) - elif not ( - r.eval_metadata.status and r.eval_metadata.status.code != Status.Code.RUNNING - ): - # if the eval_metadata status code has not been set to something else, consider it as finished - r.eval_metadata.status = Status.eval_finished() - # Optional debug print for assistant/tool sequence - if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1": - try: - preview = [ - { - "role": m.role, - "len": len(m.content or "") if isinstance(m.content, str) else None, - "tool_calls": len(m.tool_calls or []) - if hasattr(m, "tool_calls") and isinstance(m.tool_calls, list) - else 0, - "tool_call_id": getattr(m, "tool_call_id", None), - "name": getattr(m, "name", None), - } - for m in r.messages - ] - print("[EP-Log] Row messages:", preview) - except Exception: - pass - active_logger.log(r) - - # if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs - # else, we execute runs in parallel - if isinstance(rollout_processor, MCPGymRolloutProcessor): - # For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts - for run_idx in range(num_runs): - task = asyncio.create_task(execute_run(run_idx, config)) - await task - else: - # For other processors, create all tasks at once and run in parallel - # Concurrency is now controlled by the shared semaphore in each rollout processor - await run_tasks_with_run_progress(execute_run, num_runs, config) - - experiment_duration_seconds = time.perf_counter() - experiment_start_time - - if not all(r.evaluation_result is not None for run_results in all_results for r in run_results): - raise AssertionError( - "Some EvaluationRow instances are missing evaluation_result. 
" - "Your @evaluation_test function must set `row.evaluation_result`" - ) - - # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them - # rollout_id is used to differentiate the result from different completion_params - if mode == "groupwise": - results_by_group = [ - [[] for _ in range(num_runs)] for _ in range(len(original_completion_params)) - ] - for i_run, result in enumerate(all_results): - for r in result: - completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1]) # pyright: ignore[reportOptionalMemberAccess] - results_by_group[completion_param_idx][i_run].append(r) - for rollout_id, result in enumerate(results_by_group): + else: postprocess( - result, + all_results, aggregation_method, passed_threshold, active_logger, mode, - original_completion_params[rollout_id], # pyright: ignore[reportArgumentType] + completion_params, # pyright: ignore[reportArgumentType] test_func.__name__, num_runs, experiment_duration_seconds, ) - else: - postprocess( - all_results, - aggregation_method, - passed_threshold, - active_logger, - mode, - completion_params, # pyright: ignore[reportArgumentType] - test_func.__name__, - num_runs, - experiment_duration_seconds, + + + + if not all(r.evaluation_result is not None for run_results in all_results for r in run_results): + raise AssertionError( + "Some EvaluationRow instances are missing evaluation_result. 
" + "Your @evaluation_test function must set `row.evaluation_result`" ) + except AssertionError: _log_eval_error( Status.eval_finished(), diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index aa0305ae..f97caff9 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -5,12 +5,16 @@ from typing import Any, Callable, List, Dict, Optional, Union, Awaitable from eval_protocol.models import EvaluationRow, Status -from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.pytest.types import RolloutProcessorConfig, TestFunction from eval_protocol.pytest.rollout_processor import RolloutProcessor -from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry +from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry, add_cost_metrics from eval_protocol.pytest.buffer import MiniBatchDataBuffer from eval_protocol.dataset_logger.dataset_logger import DatasetLogger from eval_protocol.human_id import generate_id +from eval_protocol.log_utils.rollout_context import rollout_logging_context +from eval_protocol.pytest.execution import execute_pytest_with_exception_handling + +ENABLE_SPECULATION = os.getenv("ENABLE_SPECULATION", "0").strip() == "1" @dataclass(order=True) class RolloutTask: @@ -43,17 +47,20 @@ def __init__( rollout_processor: RolloutProcessor, max_concurrent_rollouts: int, active_logger: DatasetLogger, - eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], # Callback to run evaluation + max_concurrent_evaluations: int, + eval_executor: TestFunction, # Callback to run evaluation output_buffer: Optional[MiniBatchDataBuffer] = None, - max_concurrent_evaluations: Optional[int] = None, + rollout_n: int = 0, mode: str = "pointwise", + in_group_microbatch_size: int = 0, # for one sample, how many runs to execute at the same time + 
evaluation_test_kwargs: Dict[str, Any] = {}, ): self.rollout_processor = rollout_processor self.max_concurrent_rollouts = max_concurrent_rollouts self.max_concurrent_evaluations = max_concurrent_evaluations self.active_logger = active_logger self.eval_executor = eval_executor - self.output_buffer = output_buffer + self.output_buffer = output_buffer self.mode = mode # Priority Queue: Stores RolloutTask @@ -61,14 +68,17 @@ def __init__( # Concurrency Control self.rollout_sem = asyncio.Semaphore(max_concurrent_rollouts) - self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) if max_concurrent_evaluations else None + self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) # Results storage self.results: List[EvaluationRow] = [] # for backward compatibility reason, we save all results here to return self.groups_buffer: Dict[int, List[EvaluationRow]] = defaultdict(list) # buffer for group results. only flush to output buffer when a whole group is ready - - self.num_runs = 0 - self.micro_batch_size = 0 + + self.background_tasks = set() # run evaluations in the background asynchronously + + self.rollout_n = rollout_n + self.in_group_microbatch_size = in_group_microbatch_size if in_group_microbatch_size > 0 else rollout_n + self.evaluation_test_kwargs = evaluation_test_kwargs async def schedule_dataset( self, @@ -79,11 +89,9 @@ async def schedule_dataset( Populates the queue with initial tasks (the first micro-batch for each sample). 
""" for i, row in enumerate(dataset): - # Calculate ranges for the first micro-batch + # Calculate ranges for the first in-group minibatch batch_start = 0 - # Ensure micro_batch_size is at least 1 to avoid infinite loop or stuck tasks - safe_batch_size = self.micro_batch_size if self.micro_batch_size > 0 else self.num_runs - batch_end = min(safe_batch_size, self.num_runs) + batch_end = min(self.in_group_microbatch_size, self.rollout_n) run_indices = list(range(batch_start, batch_end)) # Initial priority: Low (1), ordered by dataset index @@ -121,6 +129,48 @@ async def _process_task(self, task: RolloutTask): """ Executes a single micro-batch task. """ + async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): + """Background evaluation task.""" + rollout_id = rows_to_eval[0].execution_metadata.rollout_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.rollout_id + experiment_id = rows_to_eval[0].execution_metadata.experiment_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.experiment_id + run_id = rows_to_eval[0].execution_metadata.run_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.run_id + eval_res = None + + async with self.eval_sem: + async with rollout_logging_context( + rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + ): + if isinstance(rows_to_eval, list): + eval_res = await execute_pytest_with_exception_handling( + test_func=self.eval_executor, + evaluation_test_kwargs=self.evaluation_test_kwargs, + processed_dataset=rows_to_eval, + ) + else: + eval_res = await execute_pytest_with_exception_handling( + test_func=self.eval_executor, + evaluation_test_kwargs=self.evaluation_test_kwargs, + processed_row=rows_to_eval, + ) + + # push result to the output buffer + if self.output_buffer: + if isinstance(eval_res, list): + for row in eval_res: + self._post_process_result(row) + await self.output_buffer.add_result(row) + else: + 
self._post_process_result(eval_res) + await self.output_buffer.add_result(eval_res) + + if isinstance(eval_res, list): + self.results.extend(eval_res) + else: + self.results.append(eval_res) + return eval_res + # 1. Prepare Config & Row for this micro-batch current_batch_rows = [] for run_idx in task.run_indices: @@ -128,10 +178,14 @@ async def _process_task(self, task: RolloutTask): row_copy.execution_metadata.run_id = generate_id() row_copy.execution_metadata.rollout_id = generate_id() + if row_copy.execution_metadata.extra is None: + row_copy.execution_metadata.extra = {} + row_copy.execution_metadata.extra["run_index"] = run_idx # Inject Speculation History - if task.history: + if ENABLE_SPECULATION and task.history: cp = row_copy.input_metadata.completion_params + max_tokens = cp.get("max_tokens", 2048) # Ensure safe dict access if not isinstance(cp, dict): cp = {} @@ -139,129 +193,57 @@ async def _process_task(self, task: RolloutTask): extra_body = cp.get("extra_body") if extra_body is None or not isinstance(extra_body, dict): extra_body = {} - - extra_body["prediction"] = task.history + # for speculation, see + # https://docs.fireworks.ai/guides/predicted-outputs + # https://platform.openai.com/docs/guides/predicted-outputs?lang=python + extra_body["prediction"] = {"type": "content", "content": " ".join(task.history)[:max_tokens]} cp["extra_body"] = extra_body row_copy.input_metadata.completion_params = cp - current_batch_rows.append(row_copy) + current_batch_rows.append((run_idx, row_copy)) self.active_logger.log(row_copy) + # 2. 
Execute Rollout batch_results: List[EvaluationRow] = [] - if task.run_indices: - representative_run_idx = task.run_indices[0] - - async with self.rollout_sem: + if current_batch_rows: + for idx, row in current_batch_rows: async for result_row in rollout_processor_with_retry( - self.rollout_processor, current_batch_rows, task.config, representative_run_idx + self.rollout_processor, [row], task.config, idx ): batch_results.append(result_row) + # in pointwise, we start evaluation immediately + if self.mode == "pointwise": + t = asyncio.create_task(_run_eval(result_row)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) # 3. Evaluate and Collect History current_batch_history_updates = [] - - if self.mode == "groupwise": - # Collect all results from this batch - for res in batch_results: - self.groupwise_buffer[task.row_index].append(res) - - # Update history from rollout result (assuming eval doesn't change content needed for history) - last_msg = res.last_assistant_message() - if last_msg and last_msg.content: - content = last_msg.content - if isinstance(content, list): - text_parts = [p["text"] for p in content if p["type"] == "text"] - current_batch_history_updates.append("".join(text_parts)) - else: - current_batch_history_updates.append(str(content)) - else: - current_batch_history_updates.append("") - - # Check if this is the last batch for this sample - last_run_idx = task.run_indices[-1] - if last_run_idx + 1 >= self.num_runs: - # Last batch: Execute Groupwise Evaluation - full_group = self.groupwise_buffer[task.row_index] - - async def _run_group_eval(): - eval_res = await self.eval_executor(full_group) - # Handle result (could be list or single row wrapping list?) 
- # Usually groupwise returns list of scored rows - if isinstance(eval_res, list): - self.results.extend(eval_res) - if self.mini_batch_data_buffer: - # Push the whole group at once if possible, or iterate - for r in eval_res: - await self.mini_batch_data_buffer.add_result(r) - else: - self.results.append(eval_res) - if self.mini_batch_data_buffer: - await self.mini_batch_data_buffer.add_result(eval_res) - - if self.eval_sem: - async with self.eval_sem: - await _run_group_eval() - else: - await _run_group_eval() - - # Clear buffer to free memory - del self.groupwise_buffer[task.row_index] - - else: - # Pointwise: Process each result individually - async def _run_eval(): - for res in batch_results: - # Run Evaluation - eval_res = await self.eval_executor(res) - - if isinstance(eval_res, list): - # Should not happen in pointwise mode which is typically used with this scheduler - # But if it does, we process each result - self.results.extend(eval_res) - for r in eval_res: - if self.mini_batch_data_buffer: - await self.mini_batch_data_buffer.add_result(r) - - last_msg = r.last_assistant_message() - if last_msg and last_msg.content: - content = last_msg.content - if isinstance(content, list): - text_parts = [p["text"] for p in content if p["type"] == "text"] - current_batch_history_updates.append("".join(text_parts)) - else: - current_batch_history_updates.append(str(content)) - else: - current_batch_history_updates.append("") - else: - self.results.append(eval_res) - if self.mini_batch_data_buffer: - await self.mini_batch_data_buffer.add_result(eval_res) - - # Extract prediction for history - last_msg = eval_res.last_assistant_message() - if last_msg and last_msg.content: - content = last_msg.content - if isinstance(content, list): - text_parts = [p["text"] for p in content if p["type"] == "text"] - current_batch_history_updates.append("".join(text_parts)) - else: - current_batch_history_updates.append(str(content)) - else: - current_batch_history_updates.append("") # 
Empty string for failed turns - - if self.eval_sem: - async with self.eval_sem: - await _run_eval() + # Extract history from rollout results (assuming eval doesn't change content needed for history) + for res in batch_results: + last_msg = res.last_assistant_message() + if last_msg and last_msg.content: + content = last_msg.content + current_batch_history_updates.append(str(content)) else: - await _run_eval() + current_batch_history_updates.append("") + + # in groupwise, we send all rows to evaluator in one go when the whole group is complete + if self.mode == "groupwise": + self.groups_buffer[task.row_index].extend(batch_results) + if len(self.groups_buffer[task.row_index]) >= self.rollout_n: + full_group = self.groups_buffer.pop(task.row_index) + t = asyncio.create_task(_run_eval(full_group)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) # 4. Schedule Next Micro-batch (High Priority) - last_run_idx = task.run_indices[-1] + last_run_idx = task.run_indices[-1] if task.run_indices else -1 next_start = last_run_idx + 1 - if next_start < self.num_runs: - next_end = min(next_start + self.micro_batch_size, self.num_runs) + if next_start < self.rollout_n: + next_end = min(next_start + self.in_group_microbatch_size, self.rollout_n) next_indices = list(range(next_start, next_end)) new_history = task.history + current_batch_history_updates @@ -278,6 +260,40 @@ async def _run_eval(): ) self.queue.put_nowait(new_task) + def _post_process_result(self, res: EvaluationRow): + """ + Process evaluation result: update cost metrics, status, and log. 
+ """ + add_cost_metrics(res) + if res.eval_metadata is not None: + if res.rollout_status.is_error(): + res.eval_metadata.status = Status.error( + res.rollout_status.message, res.rollout_status.details + ) + elif not ( + res.eval_metadata.status and res.eval_metadata.status.code != Status.Code.RUNNING + ): + res.eval_metadata.status = Status.eval_finished() + + if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1": + try: + preview = [ + { + "role": m.role, + "len": len(m.content or "") if isinstance(m.content, str) else None, + "tool_calls": len(m.tool_calls or []) + if hasattr(m, "tool_calls") and isinstance(m.tool_calls, list) + else 0, + "tool_call_id": getattr(m, "tool_call_id", None), + "name": getattr(m, "name", None), + } + for m in res.messages + ] + print("[EP-Log] Row messages:", preview) + except Exception: + pass + self.active_logger.log(res) + async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_size: int, base_config: RolloutProcessorConfig): self.num_runs = num_runs self.micro_batch_size = micro_batch_size @@ -288,14 +304,16 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz # 2. Start Workers # If we have separate limits, we need enough workers to saturate both stages num_workers = self.max_concurrent_rollouts - if self.max_concurrent_evaluations: - num_workers += self.max_concurrent_evaluations workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)] # 3. Wait for completion await self.queue.join() + # Wait for background evaluations to finish + if self.background_tasks: + await asyncio.gather(*self.background_tasks, return_exceptions=True) + # 4. 
Cleanup for w in workers: w.cancel() @@ -314,16 +332,22 @@ async def execute_priority_rollouts( config: RolloutProcessorConfig, max_concurrent_rollouts: int, active_logger: DatasetLogger, - eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], + eval_executor: TestFunction, + max_concurrent_evaluations: int = 96, + mode: str = "pointwise", mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None, - max_concurrent_evaluations: Optional[int] = None, + evaluation_test_kwargs: Dict[str, Any] = {}, ): scheduler = PriorityRolloutScheduler( rollout_processor=rollout_processor, max_concurrent_rollouts=max_concurrent_rollouts, active_logger=active_logger, eval_executor=eval_executor, - mini_batch_data_buffer=mini_batch_data_buffer, - max_concurrent_evaluations=max_concurrent_evaluations + output_buffer=mini_batch_data_buffer, + max_concurrent_evaluations=max_concurrent_evaluations, + rollout_n=num_runs, + mode=mode, + in_group_microbatch_size=micro_batch_size, + evaluation_test_kwargs=evaluation_test_kwargs, ) return await scheduler.run(dataset, num_runs, micro_batch_size, config) diff --git a/eval_protocol/pytest/validate_signature.py b/eval_protocol/pytest/validate_signature.py index 500de649..fa4cc752 100644 --- a/eval_protocol/pytest/validate_signature.py +++ b/eval_protocol/pytest/validate_signature.py @@ -53,8 +53,6 @@ def validate_signature( # validate that the function has a return type of List[EvaluationRow] if not _is_list_of_evaluation_row(signature.return_annotation): # pyright: ignore[reportAny] raise ValueError("In groupwise mode, your eval function must return a list of EvaluationRow instances") - if completion_params is not None and len(completion_params) < 2: - raise ValueError("In groupwise mode, you must provide at least 2 completion parameters") else: # all mode: function should accept input_dataset and model if "rows" not in signature.parameters: diff --git 
a/tests/pytest/test_rollout_scheduler.py b/tests/pytest/test_rollout_scheduler.py new file mode 100644 index 00000000..6526aeb0 --- /dev/null +++ b/tests/pytest/test_rollout_scheduler.py @@ -0,0 +1,50 @@ +from eval_protocol.pytest import evaluation_test, SingleTurnRolloutProcessor +from eval_protocol.models import EvaluationRow, Message, EvaluateResult, InputMetadata +from typing import List + + +@evaluation_test( + completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], + input_rows=[ + [ + EvaluationRow( + messages=[ + Message(role="system", content=f"You are a helpful assistant, and this is row {i}"), + Message(role="user", content="What is the capital of France?"), + ], + input_metadata=InputMetadata(row_id=f"row-{i}"), + ) + for i in range(10) + ] + ], + rollout_processor=SingleTurnRolloutProcessor(), + num_runs=4, + mode="pointwise", +) +def test_rollout_scheduler(row: EvaluationRow) -> EvaluationRow: + row.evaluation_result = EvaluateResult(score=0.5, reason="Dummy evaluation result") + return row + + +@evaluation_test( + completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], + input_rows=[ + [ + EvaluationRow( + messages=[ + Message(role="system", content=f"You are a helpful assistant, and this is row {i}"), + Message(role="user", content="What is the capital of France?"), + ], + input_metadata=InputMetadata(row_id=f"row-{i}"), + ) + for i in range(10) + ] + ], + rollout_processor=SingleTurnRolloutProcessor(), + num_runs=4, + mode="groupwise", +) +def test_rollout_scheduler_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]: + for i,row in enumerate(rows): + row.evaluation_result = EvaluateResult(score=0.1 * i, reason="Dummy evaluation result") + return rows \ No newline at end of file diff --git a/tests/test_priority_scheduler.py b/tests/test_priority_scheduler.py index b5778a68..d4f5f8e1 100644 --- a/tests/test_priority_scheduler.py +++ b/tests/test_priority_scheduler.py @@ -1,10 +1,10 
@@ import pytest import asyncio import time -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import MagicMock, AsyncMock, patch from typing import List, Union -from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata +from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, EvaluateResult from eval_protocol.pytest.priority_scheduler import PriorityRolloutScheduler, execute_priority_rollouts, RolloutTask from eval_protocol.pytest.types import RolloutProcessorConfig from eval_protocol.dataset_logger.dataset_logger import DatasetLogger @@ -57,28 +57,35 @@ async def test_scheduler_basic_execution( micro_batch_size = 1 # Mock rollout processor with delay - async def delayed_rollout(rows, config, run_idx): + async def delayed_rollout(processor, rows, config, run_idx): await asyncio.sleep(0.01) for row in rows: yield row - mock_processor = MagicMock() - mock_processor.side_effect = delayed_rollout # This is wrong usage for call, rollout_processor is passed as instance - # But wait, PriorityRolloutScheduler calls rollout_processor_with_retry which calls processor.process_batch or similar? - # Looking at code: rollout_processor_with_retry(self.rollout_processor, ...) - # rollout_processor_with_retry expects the processor instance. - - # Let's look at how rollout_processor_with_retry is implemented or usage. - # Assuming rollout_processor is an object with a method or it's a callable? - # In priority_scheduler.py: rollout_processor_with_retry(self.rollout_processor, ...) - - # Let's actually mock rollout_processor_with_retry since we want to test the scheduler logic, - # not the processor retry logic. - # But we can't easily mock the import inside the module without patching. 
- pass + async def mock_eval(row): + row.evaluation_result = EvaluateResult(score=1.0, is_score_valid=True) + return row + + with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=delayed_rollout): + processor_instance = MagicMock() + + scheduler = PriorityRolloutScheduler( + rollout_processor=processor_instance, + max_concurrent_rollouts=2, + active_logger=mock_logger, + eval_executor=mock_eval, + max_concurrent_evaluations=2, + rollout_n=num_runs, + in_group_microbatch_size=micro_batch_size + ) + + results = await scheduler.run(dataset, num_runs, micro_batch_size, base_config) + + assert len(results) == 5 * num_runs + for res in results: + assert res.evaluation_result is not None + assert res.evaluation_result.score == 1.0 -# We will rely on patching 'eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry' -from unittest.mock import patch @pytest.mark.asyncio async def test_concurrency_control( @@ -118,6 +125,7 @@ async def mock_rollout_gen(processor, rows, config, run_idx): async with rollout_lock: active_rollouts -= 1 + # Use a real async function for eval to work with execute_pytest properly async def mock_eval(row): nonlocal active_evals, max_active_evals_seen async with eval_lock: @@ -132,7 +140,6 @@ async def mock_eval(row): return row with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): - mock_eval_executor.side_effect = mock_eval # Mock processor instance (can be anything since we patched the wrapper) processor_instance = MagicMock() @@ -141,8 +148,10 @@ async def mock_eval(row): rollout_processor=processor_instance, max_concurrent_rollouts=max_rollouts, active_logger=mock_logger, - eval_executor=mock_eval_executor, - max_concurrent_evaluations=max_evals + eval_executor=mock_eval, + max_concurrent_evaluations=max_evals, + rollout_n=num_runs, + in_group_microbatch_size=micro_batch_size ) await scheduler.run(dataset, num_runs, micro_batch_size, 
base_config) @@ -152,9 +161,8 @@ async def mock_eval(row): assert max_active_evals_seen <= max_evals, f"Eval concurrency exceeded: {max_active_evals_seen} > {max_evals}" # Verify everything ran - # 10 rows * 1 run = 10 rollouts called - # 10 evaluations - assert mock_eval_executor.call_count == 10 + # 10 rows * 1 run = 10 results + assert len(scheduler.results) == 10 @pytest.mark.asyncio async def test_priority_scheduling( @@ -162,25 +170,6 @@ async def test_priority_scheduling( ): """ Test that subsequent micro-batches are prioritized. - This is tricky to test deterministically with asyncio, but we can try to observe order - or ensure that a task that spawns new parts gets priority. - - We'll simulate a case where we have 2 samples, each needing 2 micro-batches. - We want to see if Sample 1 Batch 2 runs before Sample 2 Batch 1 is finished if possible, - but actually the scheduler puts Sample 1 Batch 2 with Priority 0 (High) and Sample 2 Batch 1 starts with Priority 1 (Low). - - If we limit concurrency to 1, we should see: - S1_B1 -> S1_B2 -> S2_B1 -> S2_B2 - - Wait, if concurrency is 1: - 1. Queue: [S1_B1 (Low), S2_B1 (Low)] - 2. Worker picks S1_B1. Queue: [S2_B1 (Low)] - 3. S1_B1 finishes. Puts S1_B2 (High). Queue: [S1_B2 (High), S2_B1 (Low)] - 4. Worker picks S1_B2. Queue: [S2_B1 (Low)] - 5. S1_B2 finishes. Queue: [S2_B1 (Low)] - 6. Worker picks S2_B1. ... - - So yes, strictly sequential per sample if concurrency=1. 
""" dataset = [create_mock_row(f"row-{i}") for i in range(2)] num_runs = 2 @@ -198,26 +187,24 @@ async def mock_eval(row): return row with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): - mock_eval_executor.side_effect = mock_eval processor_instance = MagicMock() scheduler = PriorityRolloutScheduler( rollout_processor=processor_instance, max_concurrent_rollouts=1, # Force serial execution to test priority active_logger=mock_logger, - eval_executor=mock_eval_executor, + eval_executor=mock_eval, + max_concurrent_evaluations=1, + rollout_n=num_runs, + in_group_microbatch_size=micro_batch_size ) await scheduler.run(dataset, num_runs, micro_batch_size, base_config) # Expected order: row-0_run_0, row-0_run_1, row-1_run_0, row-1_run_1 - # Or at least row-0_run_1 should come before row-1_run_0 finishes if parallel? - # With concurrency 1, it should be strictly: - # row-0 run 0 - # row-0 run 1 (high priority injected) - # row-1 run 0 - # row-1 run 1 - + # Note: Since row-0_run_0 finishes, it schedules row-0_run_1 with HIGH priority (0). + # row-1_run_0 is in queue with LOW priority (1). + # So row-0_run_1 should run before row-1_run_0. 
expected = [ "row-0_run_0", "row-0_run_1", @@ -237,7 +224,8 @@ async def test_worker_scaling( dataset = [create_mock_row("row-0")] max_rollouts = 5 max_evals = 3 - expected_workers = max_rollouts + max_evals + # Updated expectation: workers only scale with rollout concurrency now + expected_workers = max_rollouts worker_start_count = 0 @@ -272,7 +260,9 @@ async def schedule_dataset(self, *args): max_concurrent_rollouts=max_rollouts, active_logger=mock_logger, eval_executor=mock_eval_executor, - max_concurrent_evaluations=max_evals + max_concurrent_evaluations=max_evals, + rollout_n=1, + in_group_microbatch_size=1 ) await scheduler.run(dataset, 1, 1, base_config) @@ -303,8 +293,6 @@ async def mock_eval(rows): async def mock_rollout_gen(processor, rows, config, run_idx): for row in rows: yield row - - mock_eval_executor.side_effect = mock_eval with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): processor_instance = MagicMock() @@ -313,20 +301,22 @@ async def mock_rollout_gen(processor, rows, config, run_idx): rollout_processor=processor_instance, max_concurrent_rollouts=1, active_logger=mock_logger, - eval_executor=mock_eval_executor, - mode="groupwise" + eval_executor=mock_eval, + max_concurrent_evaluations=1, + mode="groupwise", + rollout_n=num_runs, + in_group_microbatch_size=micro_batch_size ) results = await scheduler.run(dataset, num_runs, micro_batch_size, base_config) # Verify evaluation was called EXACTLY ONCE - assert len(eval_calls) == 1, f"Expected 1 eval call, got {len(eval_calls)}" - - # Verify it was called with ALL 4 rows - evaluated_rows = eval_calls[0] - assert len(evaluated_rows) == 4, f"Expected 4 rows in group eval, got {len(evaluated_rows)}" - - # Verify results contains all 4 rows - assert len(results) == 4 - - + assert len(eval_calls) == 1, f"Expected 1 eval call, got {len(eval_calls)}" + + # Verify it was called with ALL 4 rows + evaluated_rows = eval_calls[0] + assert 
len(evaluated_rows) == 4, f"Expected 4 rows in group eval, got {len(evaluated_rows)}" + + # Verify results contains all 4 runs (returned from eval) + # Note: eval returns a list of 4 rows. scheduler.results extends this list. + assert len(results) == 4 From b556f4ea4f68cb14a614b6f68511e9891b751e92 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Fri, 5 Dec 2025 10:38:44 -0800 Subject: [PATCH 05/11] add --- eval_protocol/pytest/buffer.py | 8 ++++---- eval_protocol/pytest/evaluation_test.py | 13 ++++++------- eval_protocol/pytest/priority_scheduler.py | 19 +++++++++---------- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/eval_protocol/pytest/buffer.py b/eval_protocol/pytest/buffer.py index 51771e62..84436a2b 100644 --- a/eval_protocol/pytest/buffer.py +++ b/eval_protocol/pytest/buffer.py @@ -5,14 +5,14 @@ from eval_protocol.models import EvaluationRow -class MiniBatchDataBuffer: +class MicroBatchDataBuffer: """ Buffers evaluation results and writes them to disk in minibatches. Waits for all runs of a sample to complete before considering it ready and flush to disk. 
""" - def __init__(self, num_runs: int, minibatch_size: int, output_path_template: str): + def __init__(self, num_runs: int, batch_size: int, output_path_template: str): self.num_runs = num_runs - self.minibatch_size = minibatch_size + self.batch_size = batch_size self.output_path_template = output_path_template self.pending_samples: Dict[str, List[EvaluationRow]] = defaultdict(list) # row_id -> list[EvaluationRow] self.completed_samples_buffer: List[List[EvaluationRow]] = [] # List[List[EvaluationRow]] @@ -37,7 +37,7 @@ async def add_result(self, row: EvaluationRow): completed_rows = self.pending_samples.pop(row_id) self.completed_samples_buffer.append(completed_rows) - if len(self.completed_samples_buffer) >= self.minibatch_size: + if len(self.completed_samples_buffer) >= self.batch_size: await self._flush_unsafe() async def _flush_unsafe(self): diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 5dc5ca5e..eb3d70c7 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -70,7 +70,7 @@ from eval_protocol.log_utils.init import init_external_logging_from_env from eval_protocol.log_utils.rollout_context import rollout_logging_context from eval_protocol.utils.browser_utils import is_logs_server_running, open_browser_tab -from eval_protocol.pytest.buffer import MiniBatchDataBuffer +from eval_protocol.pytest.buffer import MicroBatchDataBuffer from ..common_utils import load_jsonl @@ -411,17 +411,16 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo ) if use_priority_scheduler: - print("Using priority scheduler") - minibatch_output_size = os.environ.get("EP_MINI_BATCH_OUTPUT_SIZE", None) + microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None) output_dir = os.environ.get("EP_OUTPUT_DIR", None) - if minibatch_output_size and output_dir: - output_buffer = MiniBatchDataBuffer(num_runs=num_runs, 
minibatch_size=int(minibatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl")) + if microbatch_output_size and output_dir: + output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl")) else: output_buffer = None + priority_results = await execute_priority_rollouts( dataset=data, num_runs=num_runs, - micro_batch_size=int(os.environ.get("EP_MICRO_BATCH_SIZE", "1")), rollout_processor=rollout_processor, config=config, max_concurrent_rollouts=max_concurrent_rollouts, @@ -429,7 +428,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo eval_executor=test_func, max_concurrent_evaluations=max_concurrent_evaluations, mode=mode, - mini_batch_data_buffer=output_buffer, + micro_batch_data_buffer=output_buffer, evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, ) diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index f97caff9..a50dd358 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -8,7 +8,7 @@ from eval_protocol.pytest.types import RolloutProcessorConfig, TestFunction from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry, add_cost_metrics -from eval_protocol.pytest.buffer import MiniBatchDataBuffer +from eval_protocol.pytest.buffer import MicroBatchDataBuffer from eval_protocol.dataset_logger.dataset_logger import DatasetLogger from eval_protocol.human_id import generate_id from eval_protocol.log_utils.rollout_context import rollout_logging_context @@ -49,10 +49,10 @@ def __init__( active_logger: DatasetLogger, max_concurrent_evaluations: int, eval_executor: TestFunction, # Callback to run evaluation - output_buffer: Optional[MiniBatchDataBuffer] = None, + output_buffer: 
Optional[MicroBatchDataBuffer] = None, rollout_n: int = 0, mode: str = "pointwise", - in_group_microbatch_size: int = 0, # for one sample, how many runs to execute at the same time + in_group_minibatch_size: int = 0, # for one sample, how many runs to execute at the same time evaluation_test_kwargs: Dict[str, Any] = {}, ): self.rollout_processor = rollout_processor @@ -77,7 +77,7 @@ def __init__( self.background_tasks = set() # run evaluations in the background asynchronously self.rollout_n = rollout_n - self.in_group_microbatch_size = in_group_microbatch_size if in_group_microbatch_size > 0 else rollout_n + self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n self.evaluation_test_kwargs = evaluation_test_kwargs async def schedule_dataset( @@ -91,7 +91,7 @@ async def schedule_dataset( for i, row in enumerate(dataset): # Calculate ranges for the first in-group minibatch batch_start = 0 - batch_end = min(self.in_group_microbatch_size, self.rollout_n) + batch_end = min(self.in_group_minibatch_size, self.rollout_n) run_indices = list(range(batch_start, batch_end)) # Initial priority: Low (1), ordered by dataset index @@ -243,7 +243,7 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): next_start = last_run_idx + 1 if next_start < self.rollout_n: - next_end = min(next_start + self.in_group_microbatch_size, self.rollout_n) + next_end = min(next_start + self.in_group_minibatch_size, self.rollout_n) next_indices = list(range(next_start, next_end)) new_history = task.history + current_batch_history_updates @@ -327,7 +327,6 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz async def execute_priority_rollouts( dataset: List[EvaluationRow], num_runs: int, - micro_batch_size: int, rollout_processor: RolloutProcessor, config: RolloutProcessorConfig, max_concurrent_rollouts: int, @@ -335,7 +334,7 @@ async def execute_priority_rollouts( eval_executor: TestFunction, 
max_concurrent_evaluations: int = 96, mode: str = "pointwise", - mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None, + micro_batch_data_buffer: Optional[MicroBatchDataBuffer] = None, evaluation_test_kwargs: Dict[str, Any] = {}, ): scheduler = PriorityRolloutScheduler( @@ -343,11 +342,11 @@ async def execute_priority_rollouts( max_concurrent_rollouts=max_concurrent_rollouts, active_logger=active_logger, eval_executor=eval_executor, - output_buffer=mini_batch_data_buffer, + output_buffer=micro_batch_data_buffer, max_concurrent_evaluations=max_concurrent_evaluations, rollout_n=num_runs, mode=mode, - in_group_microbatch_size=micro_batch_size, + in_group_minibatch_size=(num_runs // 2), evaluation_test_kwargs=evaluation_test_kwargs, ) return await scheduler.run(dataset, num_runs, micro_batch_size, config) From 5fc935cc59136715ae9d91101f0e7b1668a69970 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Fri, 5 Dec 2025 10:44:01 -0800 Subject: [PATCH 06/11] fix --- eval_protocol/pytest/priority_scheduler.py | 5 ++--- pytest.ini | 21 --------------------- tests/test_priority_scheduler.py | 20 ++++++++++---------- 3 files changed, 12 insertions(+), 34 deletions(-) delete mode 100644 pytest.ini diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index a50dd358..6dec8ead 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -294,9 +294,8 @@ def _post_process_result(self, res: EvaluationRow): pass self.active_logger.log(res) - async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_size: int, base_config: RolloutProcessorConfig): + async def run(self, dataset: List[EvaluationRow], num_runs: int, base_config: RolloutProcessorConfig): self.num_runs = num_runs - self.micro_batch_size = micro_batch_size # 1. 
Schedule initial tasks await self.schedule_dataset(dataset, base_config) @@ -349,4 +348,4 @@ async def execute_priority_rollouts( in_group_minibatch_size=(num_runs // 2), evaluation_test_kwargs=evaluation_test_kwargs, ) - return await scheduler.run(dataset, num_runs, micro_batch_size, config) + return await scheduler.run(dataset, num_runs, config) diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index b3c84ce1..00000000 --- a/pytest.ini +++ /dev/null @@ -1,21 +0,0 @@ -[pytest] -markers = - asyncio -asyncio_mode = auto -asyncio_default_fixture_loop_scope = function -testpaths = tests ./eval_protocol/quickstart -python_files = test_*.py llm_judge_*.py -python_classes = Test* -python_functions = test_* -# Configure stdout/stderr capture for debugging -addopts = -s --tb=short -q -# Alternative: disable capture completely for debugging -# addopts = -s --tb=short --capture=no -filterwarnings = - ignore::UserWarning:pydantic.* - ignore::DeprecationWarning:pydantic.* - ignore:.*PydanticSerializationUnexpectedValue.*:UserWarning - ignore:.*Support for class-based.*config.*:DeprecationWarning - ignore:.*serializer warnings.*:UserWarning - ignore:.*Pydantic.*:UserWarning - ignore:.*Pydantic.*:DeprecationWarning diff --git a/tests/test_priority_scheduler.py b/tests/test_priority_scheduler.py index d4f5f8e1..27e748eb 100644 --- a/tests/test_priority_scheduler.py +++ b/tests/test_priority_scheduler.py @@ -76,10 +76,10 @@ async def mock_eval(row): eval_executor=mock_eval, max_concurrent_evaluations=2, rollout_n=num_runs, - in_group_microbatch_size=micro_batch_size + in_group_minibatch_size=micro_batch_size ) - results = await scheduler.run(dataset, num_runs, micro_batch_size, base_config) + results = await scheduler.run(dataset, num_runs, base_config) assert len(results) == 5 * num_runs for res in results: @@ -151,10 +151,10 @@ async def mock_eval(row): eval_executor=mock_eval, max_concurrent_evaluations=max_evals, rollout_n=num_runs, - 
in_group_microbatch_size=micro_batch_size + in_group_minibatch_size=micro_batch_size ) - await scheduler.run(dataset, num_runs, micro_batch_size, base_config) + await scheduler.run(dataset, num_runs, base_config) # Verify limits were respected assert max_active_rollouts_seen <= max_rollouts, f"Rollout concurrency exceeded: {max_active_rollouts_seen} > {max_rollouts}" @@ -196,10 +196,10 @@ async def mock_eval(row): eval_executor=mock_eval, max_concurrent_evaluations=1, rollout_n=num_runs, - in_group_microbatch_size=micro_batch_size + in_group_minibatch_size=micro_batch_size ) - await scheduler.run(dataset, num_runs, micro_batch_size, base_config) + await scheduler.run(dataset, num_runs, base_config) # Expected order: row-0_run_0, row-0_run_1, row-1_run_0, row-1_run_1 # Note: Since row-0_run_0 finishes, it schedules row-0_run_1 with HIGH priority (0). @@ -262,10 +262,10 @@ async def schedule_dataset(self, *args): eval_executor=mock_eval_executor, max_concurrent_evaluations=max_evals, rollout_n=1, - in_group_microbatch_size=1 + in_group_minibatch_size=1 ) - await scheduler.run(dataset, 1, 1, base_config) + await scheduler.run(dataset, 1, base_config) assert worker_start_count == expected_workers @@ -305,10 +305,10 @@ async def mock_rollout_gen(processor, rows, config, run_idx): max_concurrent_evaluations=1, mode="groupwise", rollout_n=num_runs, - in_group_microbatch_size=micro_batch_size + in_group_minibatch_size=micro_batch_size ) - results = await scheduler.run(dataset, num_runs, micro_batch_size, base_config) + results = await scheduler.run(dataset, num_runs, base_config) # Verify evaluation was called EXACTLY ONCE assert len(eval_calls) == 1, f"Expected 1 eval call, got {len(eval_calls)}" From 9219921f682b5020def0a1054476ae2b96fe8f1d Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Fri, 5 Dec 2025 10:44:57 -0800 Subject: [PATCH 07/11] put it back --- pytest.ini | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 pytest.ini diff --git 
a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..bb315e29 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,21 @@ +[pytest] +markers = + asyncio +asyncio_mode = auto +asyncio_default_fixture_loop_scope = function +testpaths = tests ./eval_protocol/quickstart +python_files = test_*.py llm_judge_*.py +python_classes = Test* +python_functions = test_* +# Configure stdout/stderr capture for debugging +addopts = -s --tb=short -q +# Alternative: disable capture completely for debugging +# addopts = -s --tb=short --capture=no +filterwarnings = + ignore::UserWarning:pydantic.* + ignore::DeprecationWarning:pydantic.* + ignore:.*PydanticSerializationUnexpectedValue.*:UserWarning + ignore:.*Support for class-based.*config.*:DeprecationWarning + ignore:.*serializer warnings.*:UserWarning + ignore:.*Pydantic.*:UserWarning + ignore:.*Pydantic.*:DeprecationWarning \ No newline at end of file From fae3150a1b7b4202c83c5fa9e8d8c6dd8f77d59f Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Fri, 5 Dec 2025 10:45:26 -0800 Subject: [PATCH 08/11] add --- pytest.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index bb315e29..b3c84ce1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -18,4 +18,4 @@ filterwarnings = ignore:.*Support for class-based.*config.*:DeprecationWarning ignore:.*serializer warnings.*:UserWarning ignore:.*Pydantic.*:UserWarning - ignore:.*Pydantic.*:DeprecationWarning \ No newline at end of file + ignore:.*Pydantic.*:DeprecationWarning From f785514ca041a12244d1f6eb3d12457ca64550e5 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Fri, 5 Dec 2025 15:30:25 -0800 Subject: [PATCH 09/11] add postprocess --- eval_protocol/pytest/evaluation_test.py | 11 +++++++++++ tests/pytest/test_rollout_scheduler.py | 1 - 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index eb3d70c7..f6ed22f0 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ 
b/eval_protocol/pytest/evaluation_test.py @@ -438,6 +438,17 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo all_results[run_idx].append(res) processed_rows_in_run.append(res) + postprocess( + all_results, + aggregation_method, + passed_threshold, + active_logger, + mode, + completion_params, # pyright: ignore[reportArgumentType] + test_func.__name__, + num_runs, + time.perf_counter() - experiment_start_time, + ) else: async def execute_run(run_idx: int, config: RolloutProcessorConfig): diff --git a/tests/pytest/test_rollout_scheduler.py b/tests/pytest/test_rollout_scheduler.py index 6526aeb0..1a1ff7a9 100644 --- a/tests/pytest/test_rollout_scheduler.py +++ b/tests/pytest/test_rollout_scheduler.py @@ -2,7 +2,6 @@ from eval_protocol.models import EvaluationRow, Message, EvaluateResult, InputMetadata from typing import List - @evaluation_test( completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_rows=[ From 81fbc701646c60e131ddc2c0633e9c8f3b8f7bce Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 8 Dec 2025 15:38:25 -0800 Subject: [PATCH 10/11] resolve comments and fix bugs --- eval_protocol/pytest/evaluation_test.py | 53 ++++++++++++---------- eval_protocol/pytest/priority_scheduler.py | 11 ++--- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index f6ed22f0..dde9a7db 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -418,19 +418,23 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo else: output_buffer = None - priority_results = await execute_priority_rollouts( - dataset=data, - num_runs=num_runs, - rollout_processor=rollout_processor, - config=config, - max_concurrent_rollouts=max_concurrent_rollouts, - active_logger=active_logger, - eval_executor=test_func, - 
max_concurrent_evaluations=max_concurrent_evaluations, - mode=mode, - micro_batch_data_buffer=output_buffer, - evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, - ) + try: + priority_results = await execute_priority_rollouts( + dataset=data, + num_runs=num_runs, + rollout_processor=rollout_processor, + config=config, + max_concurrent_rollouts=max_concurrent_rollouts, + active_logger=active_logger, + eval_executor=test_func, + max_concurrent_evaluations=max_concurrent_evaluations, + mode=mode, + micro_batch_data_buffer=output_buffer, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + ) + finally: + if output_buffer: + await output_buffer.close() for res in priority_results: run_idx = (res.execution_metadata.extra or {}).get("run_index", 0) @@ -438,17 +442,18 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo all_results[run_idx].append(res) processed_rows_in_run.append(res) - postprocess( - all_results, - aggregation_method, - passed_threshold, - active_logger, - mode, - completion_params, # pyright: ignore[reportArgumentType] - test_func.__name__, - num_runs, - time.perf_counter() - experiment_start_time, - ) + + postprocess( + all_results, + aggregation_method, + passed_threshold, + active_logger, + mode, + completion_params, # pyright: ignore[reportArgumentType] + test_func.__name__, + num_runs, + time.perf_counter() - experiment_start_time, + ) else: async def execute_run(run_idx: int, config: RolloutProcessorConfig): diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index 6dec8ead..eaddacc5 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from collections import defaultdict from dataclasses import dataclass, field @@ -67,7 +68,6 @@ def __init__( self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue() # Concurrency 
Control - self.rollout_sem = asyncio.Semaphore(max_concurrent_rollouts) self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) # Results storage @@ -112,16 +112,13 @@ async def worker(self): Worker loop: fetch task -> execute micro-batch -> schedule next batch (if any). """ while True: - try: - # Get a task from the priority queue - task: RolloutTask = await self.queue.get() - except asyncio.QueueEmpty: - break + # Get a task from the priority queue + task: RolloutTask = await self.queue.get() try: await self._process_task(task) except Exception as e: - print(f"Error processing task for row {task.row.input_metadata.row_id}: {e}") + logging.error(f"Error processing task for row {task.row.input_metadata.row_id}: {e}", exc_info=True) finally: self.queue.task_done() From b29af9654b09d9ea9c9ad9d437255b091c27f40e Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 8 Dec 2025 22:08:34 -0800 Subject: [PATCH 11/11] fix --- eval_protocol/pytest/buffer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/eval_protocol/pytest/buffer.py b/eval_protocol/pytest/buffer.py index 84436a2b..88e2f2a5 100644 --- a/eval_protocol/pytest/buffer.py +++ b/eval_protocol/pytest/buffer.py @@ -71,6 +71,12 @@ async def close(self): Flush any remaining samples in the buffer. """ async with self.lock: + # Also flush pending (incomplete) samples to avoid data loss + if self.pending_samples: + for rows in self.pending_samples.values(): + self.completed_samples_buffer.append(rows) + self.pending_samples.clear() + if self.completed_samples_buffer: await self._flush_unsafe()