class MicroBatchDataBuffer:
    """
    Buffers evaluation results and writes them to disk in micro-batches.

    Rows are grouped by ``row.input_metadata.row_id``. A sample is complete once
    ``num_runs`` rows have been added for its ``row_id``; completed samples are
    accumulated and, once ``batch_size`` of them are available, written to disk
    as JSON lines (one ``row.model_dump_json()`` per line).

    Args:
        num_runs: Number of rows expected per ``row_id`` before the sample is
            considered complete.
        batch_size: Number of completed samples that triggers a flush.
        output_path_template: Destination path. If it contains ``{index}``, each
            flush writes a new file with the running batch index substituted;
            otherwise every flush appends to the same file.
    """

    def __init__(self, num_runs: int, batch_size: int, output_path_template: str):
        self.num_runs = num_runs
        self.batch_size = batch_size
        self.output_path_template = output_path_template
        # row_id -> rows received so far for that sample
        self.pending_samples: Dict[str, List["EvaluationRow"]] = defaultdict(list)
        # completed samples (one inner list per sample) waiting for the next flush
        self.completed_samples_buffer: List[List["EvaluationRow"]] = []
        self.batch_index = 0
        self.lock = asyncio.Lock()

    async def add_result(self, row: "EvaluationRow"):
        """
        Add a single evaluation result.

        Coroutine-safe: all state changes happen while holding ``self.lock``.
        Rows without a ``row_id`` cannot be grouped and are dropped with a warning.
        """
        async with self.lock:
            row_id = row.input_metadata.row_id
            if not row_id:
                # Should not happen in a valid EP workflow: a unique row_id is
                # required to group runs. Make the drop visible instead of silent.
                import logging

                logging.getLogger(__name__).warning(
                    "MicroBatchDataBuffer dropped a result without row_id"
                )
                return

            self.pending_samples[row_id].append(row)

            if len(self.pending_samples[row_id]) >= self.num_runs:
                # Sample completed (all runs finished)
                completed_rows = self.pending_samples.pop(row_id)
                self.completed_samples_buffer.append(completed_rows)

                if len(self.completed_samples_buffer) >= self.batch_size:
                    await self._flush_unsafe()

    async def _flush_unsafe(self):
        """
        Write all buffered completed samples to disk and advance the batch index.

        Not coroutine-safe on its own: assumes ``self.lock`` is held by the caller.
        """
        if not self.completed_samples_buffer:
            return

        if "{index}" in self.output_path_template:
            output_path = self.output_path_template.format(index=self.batch_index)
            mode = "w"
        else:
            output_path = self.output_path_template
            mode = "a"  # Append if no index placeholder

        # Serialize on the event loop (touches model objects), then hand the
        # blocking disk write to a worker thread so other coroutines keep running.
        # The lock stays held meanwhile, so the buffer cannot change under us.
        lines = [
            row.model_dump_json() + "\n"
            for sample_rows in self.completed_samples_buffer
            for row in sample_rows
        ]
        await asyncio.to_thread(self._write_lines, output_path, mode, lines)

        # Only reset state after the write succeeded, so a failed flush loses nothing.
        self.completed_samples_buffer = []
        self.batch_index += 1

    @staticmethod
    def _write_lines(output_path: str, mode: str, lines: List[str]) -> None:
        """Create the parent directory if needed and write ``lines`` as UTF-8."""
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        # Explicit UTF-8: the JSON may contain non-ASCII text and the platform
        # default encoding is not guaranteed to be able to represent it.
        with open(output_path, mode, encoding="utf-8") as f:
            f.writelines(lines)

    async def close(self):
        """
        Flush any remaining samples in the buffer.

        Incomplete samples (fewer than ``num_runs`` rows) are flushed as well so
        that no collected result is lost.
        """
        async with self.lock:
            # Also flush pending (incomplete) samples to avoid data loss
            if self.pending_samples:
                self.completed_samples_buffer.extend(self.pending_samples.values())
                self.pending_samples.clear()

            if self.completed_samples_buffer:
                await self._flush_unsafe()
- run_id = generate_id() - fresh_dataset = [r.model_copy(deep=True) for r in data] - - # apply new run_id to fresh_dataset - for row in fresh_dataset: - row.execution_metadata.run_id = run_id + use_priority_scheduler = ( + ( + os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1" + and not isinstance(rollout_processor, MCPGymRolloutProcessor) + ) + ) - # generate new rollout_id for each row - for row in fresh_dataset: - row.execution_metadata.rollout_id = generate_id() + if use_priority_scheduler: + microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None) + output_dir = os.environ.get("EP_OUTPUT_DIR", None) + if microbatch_output_size and output_dir: + output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl")) + else: + output_buffer = None + + try: + priority_results = await execute_priority_rollouts( + dataset=data, + num_runs=num_runs, + rollout_processor=rollout_processor, + config=config, + max_concurrent_rollouts=max_concurrent_rollouts, + active_logger=active_logger, + eval_executor=test_func, + max_concurrent_evaluations=max_concurrent_evaluations, + mode=mode, + micro_batch_data_buffer=output_buffer, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + ) + finally: + if output_buffer: + await output_buffer.close() + + for res in priority_results: + run_idx = (res.execution_metadata.extra or {}).get("run_index", 0) + if run_idx < len(all_results): + all_results[run_idx].append(res) + + processed_rows_in_run.append(res) - # log the fresh_dataset - for row in fresh_dataset: - active_logger.log(row) - processed_rows_in_run.append(row) + postprocess( + all_results, + aggregation_method, + passed_threshold, + active_logger, + mode, + completion_params, # pyright: ignore[reportArgumentType] + test_func.__name__, + num_runs, + time.perf_counter() - experiment_start_time, + ) - # prepare parallel eval helper function - 
semaphore = asyncio.Semaphore(max_concurrent_evaluations) + else: + async def execute_run(run_idx: int, config: RolloutProcessorConfig): + nonlocal all_results + + # Regenerate outputs each run by deep-copying the pristine dataset + # so model responses are not reused across runs. + run_id = generate_id() + fresh_dataset = [r.model_copy(deep=True) for r in data] + + # apply new run_id to fresh_dataset + for row in fresh_dataset: + row.execution_metadata.run_id = run_id + + # generate new rollout_id for each row + for row in fresh_dataset: + row.execution_metadata.rollout_id = generate_id() + + # log the fresh_dataset + for row in fresh_dataset: + active_logger.log(row) + processed_rows_in_run.append(row) + + # prepare parallel eval helper function + semaphore = asyncio.Semaphore(max_concurrent_evaluations) + + async def _execute_pointwise_eval_with_semaphore( + row: EvaluationRow, + ) -> EvaluationRow: + async with semaphore: + evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} + async with rollout_logging_context( + row.execution_metadata.rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + ): + result = await execute_pytest_with_exception_handling( + test_func=test_func, + evaluation_test_kwargs=evaluation_test_kwargs, + processed_row=row, + ) + if not isinstance(result, EvaluationRow): + raise ValueError( + f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." 
+ ) + return result - async def _execute_pointwise_eval_with_semaphore( - row: EvaluationRow, - ) -> EvaluationRow: - async with semaphore: - evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} - async with rollout_logging_context( - row.execution_metadata.rollout_id or "", - experiment_id=experiment_id, - run_id=run_id, + async def _execute_groupwise_eval_with_semaphore( + rows: list[EvaluationRow], + ) -> list[EvaluationRow]: + async with semaphore: + evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} + primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None + group_rollout_ids = [ + r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id + ] + async with rollout_logging_context( + primary_rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + rollout_ids=group_rollout_ids or None, + ): + results = await execute_pytest_with_exception_handling( + test_func=test_func, + evaluation_test_kwargs=evaluation_test_kwargs, + processed_dataset=rows, + ) + if not isinstance(results, list): + raise ValueError( + f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." 
+ ) + return results + + if mode == "pointwise": + # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution + pointwise_tasks: list[asyncio.Task[EvaluationRow]] = [] + # Use wrapper that handles retry logic internally + async for row in rollout_processor_with_retry( + rollout_processor, fresh_dataset, config, run_idx ): - result = await execute_pytest_with_exception_handling( - test_func=test_func, - evaluation_test_kwargs=evaluation_test_kwargs, - processed_row=row, + pointwise_tasks.append( + asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row)) ) - if not isinstance(result, EvaluationRow): - raise ValueError( - f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." + + # Run evaluation tasks with progress bar + results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx) + + all_results[run_idx] = results + elif mode == "groupwise": + # rollout all the completion_params for the same row at once, and then send the output to the test_func + row_groups = defaultdict(list) # key: row_id, value: list of rollout_result + tasks: list[asyncio.Task[list[EvaluationRow]]] = [] + # completion_groups = [] + for idx, cp in enumerate(original_completion_params): + config = RolloutProcessorConfig( + completion_params=cp if cp is not None else {}, + mcp_config_path=mcp_config_path or "", + server_script_path=server_script_path, + steps=steps, + logger=active_logger, + semaphore=shared_semaphore, + kwargs=rollout_processor_kwargs or {}, ) - return result - - async def _execute_groupwise_eval_with_semaphore( - rows: list[EvaluationRow], - ) -> list[EvaluationRow]: - async with semaphore: - evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} - primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None + lst = [] + + async def _collect_result(config, lst): + 
result = [] + async for row in rollout_processor_with_retry( + rollout_processor, lst, config, run_idx + ): # pyright: ignore[reportUnknownArgumentType] + result.append(row) + return result + + for ori_row in fresh_dataset: + copied_row = ori_row.model_copy(deep=True) + # overwrite the rollout_id to the index of the completion_params + copied_row.execution_metadata.rollout_id = ( + str(ori_row.execution_metadata.rollout_id) + "_" + str(idx) + ) + copied_row.input_metadata.completion_params = cp if cp is not None else {} + lst.append(copied_row) + tasks.append(asyncio.create_task(_collect_result(config, lst))) + rollout_results = await asyncio.gather(*tasks) + for result in rollout_results: + for row in result: + row_groups[row.input_metadata.row_id].append(row) + tasks = [] + for _, rows in row_groups.items(): + tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) + results = [] + for task in tasks: + res = await task + results.extend(res) + all_results[run_idx] = results + else: + # Batch mode: collect all results first, then evaluate (no pipelining) + input_dataset = [] + async for row in rollout_processor_with_retry( + rollout_processor, fresh_dataset, config, run_idx + ): + input_dataset.append(row) + # NOTE: we will still evaluate errored rows (give users control over this) + # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func + primary_rollout_id = ( + input_dataset[0].execution_metadata.rollout_id if input_dataset else None + ) group_rollout_ids = [ - r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id + r.execution_metadata.rollout_id + for r in input_dataset + if r.execution_metadata.rollout_id ] async with rollout_logging_context( primary_rollout_id or "", @@ -463,205 +613,113 @@ async def _execute_groupwise_eval_with_semaphore( ): results = await execute_pytest_with_exception_handling( test_func=test_func, - evaluation_test_kwargs=evaluation_test_kwargs, 
- processed_dataset=rows, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + processed_dataset=input_dataset, ) - if not isinstance(results, list): + if ( + results is None + or not isinstance(results, list) + or not all(isinstance(r, EvaluationRow) for r in results) + ): raise ValueError( f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." ) - return results - - if mode == "pointwise": - # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution - pointwise_tasks: list[asyncio.Task[EvaluationRow]] = [] - # Use wrapper that handles retry logic internally - async for row in rollout_processor_with_retry( - rollout_processor, fresh_dataset, config, run_idx - ): - pointwise_tasks.append( - asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row)) - ) - - # Run evaluation tasks with progress bar - results = await run_tasks_with_eval_progress(pointwise_tasks, run_idx) - - all_results[run_idx] = results - elif mode == "groupwise": - # rollout all the completion_params for the same row at once, and then send the output to the test_func - row_groups = defaultdict(list) # key: row_id, value: list of rollout_result - tasks: list[asyncio.Task[list[EvaluationRow]]] = [] - # completion_groups = [] - for idx, cp in enumerate(original_completion_params): - config = RolloutProcessorConfig( - completion_params=cp if cp is not None else {}, - mcp_config_path=mcp_config_path or "", - server_script_path=server_script_path, - steps=steps, - logger=active_logger, - semaphore=shared_semaphore, - kwargs=rollout_processor_kwargs or {}, - ) - lst = [] - - async def _collect_result(config, lst): - result = [] - async for row in rollout_processor_with_retry( - rollout_processor, lst, config, run_idx - ): # pyright: ignore[reportUnknownArgumentType] - result.append(row) - return 
result - - for ori_row in fresh_dataset: - copied_row = ori_row.model_copy(deep=True) - # overwrite the rollout_id to the index of the completion_params - copied_row.execution_metadata.rollout_id = ( - str(ori_row.execution_metadata.rollout_id) + "_" + str(idx) + if not results: + raise ValueError( + f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test." ) - copied_row.input_metadata.completion_params = cp if cp is not None else {} - lst.append(copied_row) - tasks.append(asyncio.create_task(_collect_result(config, lst))) - rollout_results = await asyncio.gather(*tasks) - for result in rollout_results: - for row in result: - row_groups[row.input_metadata.row_id].append(row) - tasks = [] - for _, rows in row_groups.items(): - tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) - results = [] - for task in tasks: - res = await task - results.extend(res) - all_results[run_idx] = results + all_results[run_idx] = results + + for r in results: + add_cost_metrics(r) + if r.eval_metadata is not None: + if r.rollout_status.is_error(): + r.eval_metadata.status = Status.error( + r.rollout_status.message, r.rollout_status.details + ) + elif not ( + r.eval_metadata.status and r.eval_metadata.status.code != Status.Code.RUNNING + ): + # if the eval_metadata status code has not been set to something else, consider it as finished + r.eval_metadata.status = Status.eval_finished() + # Optional debug print for assistant/tool sequence + if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1": + try: + preview = [ + { + "role": m.role, + "len": len(m.content or "") if isinstance(m.content, str) else None, + "tool_calls": len(m.tool_calls or []) + if hasattr(m, "tool_calls") and isinstance(m.tool_calls, list) + else 0, + "tool_call_id": getattr(m, "tool_call_id", None), + "name": getattr(m, "name", None), + } + for m in r.messages 
+ ] + print("[EP-Log] Row messages:", preview) + except Exception: + pass + active_logger.log(r) + + # if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs + # else, we execute runs in parallel + if isinstance(rollout_processor, MCPGymRolloutProcessor): + # For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts + for run_idx in range(num_runs): + task = asyncio.create_task(execute_run(run_idx, config)) + await task else: - # Batch mode: collect all results first, then evaluate (no pipelining) - input_dataset = [] - async for row in rollout_processor_with_retry( - rollout_processor, fresh_dataset, config, run_idx - ): - input_dataset.append(row) - # NOTE: we will still evaluate errored rows (give users control over this) - # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func - primary_rollout_id = ( - input_dataset[0].execution_metadata.rollout_id if input_dataset else None - ) - group_rollout_ids = [ - r.execution_metadata.rollout_id - for r in input_dataset - if r.execution_metadata.rollout_id + # For other processors, create all tasks at once and run in parallel + # Concurrency is now controlled by the shared semaphore in each rollout processor + await run_tasks_with_run_progress(execute_run, num_runs, config) + + experiment_duration_seconds = time.perf_counter() - experiment_start_time + + # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them + # rollout_id is used to differentiate the result from different completion_params + if mode == "groupwise": + results_by_group = [ + [[] for _ in range(num_runs)] for _ in range(len(original_completion_params)) ] - async with rollout_logging_context( - primary_rollout_id or "", - experiment_id=experiment_id, - run_id=run_id, - rollout_ids=group_rollout_ids or None, - ): - results = await 
execute_pytest_with_exception_handling( - test_func=test_func, - evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, - processed_dataset=input_dataset, - ) - if ( - results is None - or not isinstance(results, list) - or not all(isinstance(r, EvaluationRow) for r in results) - ): - raise ValueError( - f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." + for i_run, result in enumerate(all_results): + for r in result: + completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1]) # pyright: ignore[reportOptionalMemberAccess] + results_by_group[completion_param_idx][i_run].append(r) + for rollout_id, result in enumerate(results_by_group): + postprocess( + result, + aggregation_method, + passed_threshold, + active_logger, + mode, + original_completion_params[rollout_id], # pyright: ignore[reportArgumentType] + test_func.__name__, + num_runs, + experiment_duration_seconds, ) - if not results: - raise ValueError( - f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test." 
- ) - all_results[run_idx] = results - - for r in results: - add_cost_metrics(r) - if r.eval_metadata is not None: - if r.rollout_status.is_error(): - r.eval_metadata.status = Status.error( - r.rollout_status.message, r.rollout_status.details - ) - elif not ( - r.eval_metadata.status and r.eval_metadata.status.code != Status.Code.RUNNING - ): - # if the eval_metadata status code has not been set to something else, consider it as finished - r.eval_metadata.status = Status.eval_finished() - # Optional debug print for assistant/tool sequence - if os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1": - try: - preview = [ - { - "role": m.role, - "len": len(m.content or "") if isinstance(m.content, str) else None, - "tool_calls": len(m.tool_calls or []) - if hasattr(m, "tool_calls") and isinstance(m.tool_calls, list) - else 0, - "tool_call_id": getattr(m, "tool_call_id", None), - "name": getattr(m, "name", None), - } - for m in r.messages - ] - print("[EP-Log] Row messages:", preview) - except Exception: - pass - active_logger.log(r) - - # if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs - # else, we execute runs in parallel - if isinstance(rollout_processor, MCPGymRolloutProcessor): - # For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts - for run_idx in range(num_runs): - task = asyncio.create_task(execute_run(run_idx, config)) - await task - else: - # For other processors, create all tasks at once and run in parallel - # Concurrency is now controlled by the shared semaphore in each rollout processor - await run_tasks_with_run_progress(execute_run, num_runs, config) - - experiment_duration_seconds = time.perf_counter() - experiment_start_time - - if not all(r.evaluation_result is not None for run_results in all_results for r in run_results): - raise AssertionError( - "Some EvaluationRow instances are missing evaluation_result. 
" - "Your @evaluation_test function must set `row.evaluation_result`" - ) - - # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them - # rollout_id is used to differentiate the result from different completion_params - if mode == "groupwise": - results_by_group = [ - [[] for _ in range(num_runs)] for _ in range(len(original_completion_params)) - ] - for i_run, result in enumerate(all_results): - for r in result: - completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1]) # pyright: ignore[reportOptionalMemberAccess] - results_by_group[completion_param_idx][i_run].append(r) - for rollout_id, result in enumerate(results_by_group): + else: postprocess( - result, + all_results, aggregation_method, passed_threshold, active_logger, mode, - original_completion_params[rollout_id], # pyright: ignore[reportArgumentType] + completion_params, # pyright: ignore[reportArgumentType] test_func.__name__, num_runs, experiment_duration_seconds, ) - else: - postprocess( - all_results, - aggregation_method, - passed_threshold, - active_logger, - mode, - completion_params, # pyright: ignore[reportArgumentType] - test_func.__name__, - num_runs, - experiment_duration_seconds, + + + + if not all(r.evaluation_result is not None for run_results in all_results for r in run_results): + raise AssertionError( + "Some EvaluationRow instances are missing evaluation_result. 
@dataclass(order=True)
class RolloutTask:
    """
    One schedulable unit of rollout work for the worker pool.

    Instances are ordered by ``priority`` alone; every other field is excluded
    from comparison so the priority queue never compares payload objects.

    ``priority`` is a ``(status, row_index)`` tuple:
    - status: 0 = High Priority (e.g., subsequent micro-batches of an already started sample)
              1 = Low Priority (e.g., starting a new sample)
    - row_index: keeps the dataset order for initial scheduling
    """

    # Sort key; the only field that takes part in comparisons.
    priority: tuple[int, int]

    # --- Payload (compare=False on all of the following) ---
    # Source row that is copied for each run in this task.
    row: EvaluationRow = field(compare=False)
    # Run indices to execute in this task.
    run_indices: List[int] = field(compare=False)
    # Rollout configuration used for these runs.
    config: RolloutProcessorConfig = field(compare=False)
    # Index of the sample this task belongs to.
    row_index: int = field(compare=False)

    # Outputs of earlier micro-batches of the same sample, used for speculation.
    history: List[str] = field(default_factory=list, compare=False)
self.results: List[EvaluationRow] = [] # for backward compatibility reason, we save all results here to return + self.groups_buffer: Dict[int, List[EvaluationRow]] = defaultdict(list) # buffer for group results. only flush to output buffer when a whole group is ready + + self.background_tasks = set() # run evaluations in the background asynchronously + + self.rollout_n = rollout_n + self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n + self.evaluation_test_kwargs = evaluation_test_kwargs + + async def schedule_dataset( + self, + dataset: List[EvaluationRow], + base_config: RolloutProcessorConfig, + ): + """ + Populates the queue with initial tasks (the first micro-batch for each sample). + """ + for i, row in enumerate(dataset): + # Calculate ranges for the first in-group minibatch + batch_start = 0 + batch_end = min(self.in_group_minibatch_size, self.rollout_n) + run_indices = list(range(batch_start, batch_end)) + + # Initial priority: Low (1), ordered by dataset index + priority = (1, i) + + task = RolloutTask( + priority=priority, + row=row, + run_indices=run_indices, + config=base_config, + row_index=i, + history=[] # Initial batch has no history + ) + self.queue.put_nowait(task) + + async def worker(self): + """ + Worker loop: fetch task -> execute micro-batch -> schedule next batch (if any). + """ + while True: + # Get a task from the priority queue + task: RolloutTask = await self.queue.get() + + try: + await self._process_task(task) + except Exception as e: + logging.error(f"Error processing task for row {task.row.input_metadata.row_id}: {e}", exc_info=True) + finally: + self.queue.task_done() + + async def _process_task(self, task: RolloutTask): + """ + Executes a single micro-batch task. 
+ """ + async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): + """Background evaluation task.""" + rollout_id = rows_to_eval[0].execution_metadata.rollout_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.rollout_id + experiment_id = rows_to_eval[0].execution_metadata.experiment_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.experiment_id + run_id = rows_to_eval[0].execution_metadata.run_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.run_id + eval_res = None + + async with self.eval_sem: + async with rollout_logging_context( + rollout_id or "", + experiment_id=experiment_id, + run_id=run_id, + ): + if isinstance(rows_to_eval, list): + eval_res = await execute_pytest_with_exception_handling( + test_func=self.eval_executor, + evaluation_test_kwargs=self.evaluation_test_kwargs, + processed_dataset=rows_to_eval, + ) + else: + eval_res = await execute_pytest_with_exception_handling( + test_func=self.eval_executor, + evaluation_test_kwargs=self.evaluation_test_kwargs, + processed_row=rows_to_eval, + ) + + # push result to the output buffer + if self.output_buffer: + if isinstance(eval_res, list): + for row in eval_res: + self._post_process_result(row) + await self.output_buffer.add_result(row) + else: + self._post_process_result(eval_res) + await self.output_buffer.add_result(eval_res) + + if isinstance(eval_res, list): + self.results.extend(eval_res) + else: + self.results.append(eval_res) + return eval_res + + # 1. 
Prepare Config & Row for this micro-batch + current_batch_rows = [] + for run_idx in task.run_indices: + row_copy = task.row.model_copy(deep=True) + + row_copy.execution_metadata.run_id = generate_id() + row_copy.execution_metadata.rollout_id = generate_id() + if row_copy.execution_metadata.extra is None: + row_copy.execution_metadata.extra = {} + row_copy.execution_metadata.extra["run_index"] = run_idx + + # Inject Speculation History + if ENABLE_SPECULATION and task.history: + cp = row_copy.input_metadata.completion_params + max_tokens = cp.get("max_tokens", 2048) + # Ensure safe dict access + if not isinstance(cp, dict): + cp = {} + # Need to check and initialize nested dicts + extra_body = cp.get("extra_body") + if extra_body is None or not isinstance(extra_body, dict): + extra_body = {} + # for speculation, see + # https://docs.fireworks.ai/guides/predicted-outputs + # https://platform.openai.com/docs/guides/predicted-outputs?lang=python + extra_body["prediction"] = {"type": "content", "content": " ".join(task.history)[:max_tokens]} + cp["extra_body"] = extra_body + row_copy.input_metadata.completion_params = cp + + current_batch_rows.append((run_idx, row_copy)) + self.active_logger.log(row_copy) + + + # 2. Execute Rollout + batch_results: List[EvaluationRow] = [] + if current_batch_rows: + for idx, row in current_batch_rows: + async for result_row in rollout_processor_with_retry( + self.rollout_processor, [row], task.config, idx + ): + batch_results.append(result_row) + # in pointwise, we start evaluation immediately + if self.mode == "pointwise": + t = asyncio.create_task(_run_eval(result_row)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) + + # 3. 
def _post_process_result(self, res: EvaluationRow):
    """
    Finalize one evaluated row: add cost metrics, settle its eval status, and log it.

    Args:
        res: The evaluated row. It is mutated in place (``eval_metadata.status``
            may be overwritten) and then handed to ``self.active_logger.log``.

    Status rules, applied only when ``res.eval_metadata`` is present:
        * a rollout error is copied into the eval status via ``Status.error``;
        * otherwise a missing or still-RUNNING eval status becomes
          ``Status.eval_finished()``;
        * any other status that is already set is left untouched.

    When the ``EP_DEBUG_SERIALIZATION`` environment variable is ``"1"``, a compact
    per-message summary of the row is printed before logging. Any failure while
    building that summary is ignored so debugging output can never break a run.
    """
    add_cost_metrics(res)

    metadata = res.eval_metadata
    if metadata is not None:
        rollout_status = res.rollout_status
        if rollout_status.is_error():
            metadata.status = Status.error(rollout_status.message, rollout_status.details)
        elif not metadata.status or metadata.status.code == Status.Code.RUNNING:
            # Nothing moved the status past RUNNING, so mark the eval as finished.
            metadata.status = Status.eval_finished()

    debug_enabled = os.getenv("EP_DEBUG_SERIALIZATION", "0").strip() == "1"
    if debug_enabled:
        try:
            preview = []
            for message in res.messages:
                role = message.role
                content = message.content
                tool_calls = getattr(message, "tool_calls", None)
                preview.append(
                    {
                        "role": role,
                        # Only plain-string content has a meaningful length here.
                        "len": len(content) if isinstance(content, str) else None,
                        "tool_calls": len(tool_calls) if isinstance(tool_calls, list) else 0,
                        "tool_call_id": getattr(message, "tool_call_id", None),
                        "name": getattr(message, "name", None),
                    }
                )
            print("[EP-Log] Row messages:", preview)
        except Exception:
            # Best-effort debug output only.
            pass

    self.active_logger.log(res)
async def execute_priority_rollouts(
    dataset: List[EvaluationRow],
    num_runs: int,
    rollout_processor: RolloutProcessor,
    config: RolloutProcessorConfig,
    max_concurrent_rollouts: int,
    active_logger: DatasetLogger,
    eval_executor: TestFunction,
    max_concurrent_evaluations: int = 96,
    mode: str = "pointwise",
    micro_batch_data_buffer: Optional[MicroBatchDataBuffer] = None,
    evaluation_test_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Build a ``PriorityRolloutScheduler`` and run it over ``dataset``.

    Args:
        dataset: Rows to roll out and evaluate.
        num_runs: Number of runs per row; forwarded both as the scheduler's
            ``rollout_n`` and as the ``num_runs`` argument of ``scheduler.run``.
        rollout_processor: Processor used to produce rollouts.
        config: Base rollout configuration passed to ``scheduler.run``.
        max_concurrent_rollouts: Forwarded to the scheduler.
        active_logger: Logger forwarded to the scheduler.
        eval_executor: Evaluation function forwarded to the scheduler.
        max_concurrent_evaluations: Forwarded to the scheduler.
        mode: Evaluation mode forwarded to the scheduler (the scheduler code
            branches on ``"pointwise"`` and ``"groupwise"``).
        micro_batch_data_buffer: Optional buffer forwarded as the scheduler's
            ``output_buffer``.
        evaluation_test_kwargs: Extra keyword arguments forwarded to the
            scheduler. ``None`` (the default) means an empty dict.

    Returns:
        Whatever ``scheduler.run`` returns (its collected results).
    """
    # A fresh dict per call: a ``{}`` default in the signature would be a single
    # object shared (and potentially mutated) across every invocation.
    if evaluation_test_kwargs is None:
        evaluation_test_kwargs = {}

    # Half of the runs per micro-batch, but never zero: with ``num_runs == 1`` the
    # plain ``num_runs // 2`` is 0, and the scheduler's follow-up batching
    # (``next_start + in_group_minibatch_size``) would then produce empty batches.
    in_group_minibatch_size = max(1, num_runs // 2)

    scheduler = PriorityRolloutScheduler(
        rollout_processor=rollout_processor,
        max_concurrent_rollouts=max_concurrent_rollouts,
        active_logger=active_logger,
        eval_executor=eval_executor,
        output_buffer=micro_batch_data_buffer,
        max_concurrent_evaluations=max_concurrent_evaluations,
        rollout_n=num_runs,
        mode=mode,
        in_group_minibatch_size=in_group_minibatch_size,
        evaluation_test_kwargs=evaluation_test_kwargs,
    )
    return await scheduler.run(dataset, num_runs, config)
@evaluation_test(
    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
    input_rows=[
        [
            EvaluationRow(
                messages=[
                    Message(role="system", content=f"You are a helpful assistant, and this is row {i}"),
                    Message(role="user", content="What is the capital of France?"),
                ],
                input_metadata=InputMetadata(row_id=f"row-{i}"),
            )
            for i in range(10)
        ]
    ],
    rollout_processor=SingleTurnRolloutProcessor(),
    num_runs=4,
    mode="groupwise",
)
def test_rollout_scheduler_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """
    Dummy groupwise evaluator used to exercise the rollout scheduler.

    Each row receives a score of ``0.1 * position``, where ``position`` is the
    row's index within the ``rows`` list it was called with, so rows in one
    call get distinct scores. The same list is returned after being mutated
    in place.
    """
    for position, row in enumerate(rows):
        score = 0.1 * position
        row.evaluation_result = EvaluateResult(score=score, reason="Dummy evaluation result")
    return rows
def create_mock_row(row_id: str = "test-row") -> EvaluationRow:
    """
    Build a minimal ``EvaluationRow`` for scheduler tests.

    Args:
        row_id: Identifier stored in the row's ``input_metadata.row_id``.

    Returns:
        A row whose input metadata carries ``row_id`` and a fixed
        ``{"model": "test-model"}`` completion-params dict, paired with a
        default-constructed ``ExecutionMetadata``.
    """
    input_metadata = InputMetadata(row_id=row_id, completion_params={"model": "test-model"})
    execution_metadata = ExecutionMetadata()
    return EvaluationRow(input_metadata=input_metadata, execution_metadata=execution_metadata)
@pytest.mark.asyncio
async def test_concurrency_control(
    mock_logger, mock_eval_executor, base_config
):
    """
    Check that rollouts and evaluations never exceed their configured concurrency caps.

    Ten single-run rows are pushed through a scheduler limited to 4 concurrent
    rollouts and 2 concurrent evaluations. The patched rollout wrapper and the
    eval function each track how many calls are in flight and record the peak;
    the peaks must stay within the limits and all ten results must be collected.
    """
    max_rollouts = 4
    max_evals = 2
    num_runs = 1
    micro_batch_size = 1
    dataset = [create_mock_row(f"row-{i}") for i in range(10)]

    # In-flight count and high-water mark, one pair per stage.
    rollout_stats = {"active": 0, "peak": 0}
    eval_stats = {"active": 0, "peak": 0}
    rollout_lock = asyncio.Lock()
    eval_lock = asyncio.Lock()

    async def mock_rollout_gen(processor, rows, config, run_idx):
        async with rollout_lock:
            rollout_stats["active"] += 1
            rollout_stats["peak"] = max(rollout_stats["peak"], rollout_stats["active"])

        # Simulate slow rollout so several rollouts can overlap.
        await asyncio.sleep(0.05)

        for row in rows:
            yield row

        async with rollout_lock:
            rollout_stats["active"] -= 1

    # A real coroutine function (not a mock) is passed as the eval executor.
    async def mock_eval(row):
        async with eval_lock:
            eval_stats["active"] += 1
            eval_stats["peak"] = max(eval_stats["peak"], eval_stats["active"])

        # Simulate evaluation work so several evaluations can overlap.
        await asyncio.sleep(0.05)

        async with eval_lock:
            eval_stats["active"] -= 1
        return row

    with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen):
        # The processor itself can be anything because the wrapper that would
        # call it is patched out.
        scheduler = PriorityRolloutScheduler(
            rollout_processor=MagicMock(),
            max_concurrent_rollouts=max_rollouts,
            active_logger=mock_logger,
            eval_executor=mock_eval,
            max_concurrent_evaluations=max_evals,
            rollout_n=num_runs,
            in_group_minibatch_size=micro_batch_size
        )

        await scheduler.run(dataset, num_runs, base_config)

    max_active_rollouts_seen = rollout_stats["peak"]
    max_active_evals_seen = eval_stats["peak"]

    # Limits were respected.
    assert max_active_rollouts_seen <= max_rollouts, f"Rollout concurrency exceeded: {max_active_rollouts_seen} > {max_rollouts}"
    assert max_active_evals_seen <= max_evals, f"Eval concurrency exceeded: {max_active_evals_seen} > {max_evals}"

    # Everything ran: 10 rows * 1 run = 10 results.
    assert len(scheduler.results) == 10
@pytest.mark.asyncio
async def test_worker_scaling(
    mock_logger, mock_eval_executor, base_config
):
    """
    Check that ``run`` starts exactly ``max_concurrent_rollouts`` workers.

    The evaluation limit is set to a different value (3 vs. 5) so the assertion
    only passes if the worker count follows the rollout limit alone. The
    scheduler subclass replaces ``worker`` with a stub that records that it
    started, consumes a single queued task and exits, and replaces
    ``schedule_dataset`` so that exactly one task per expected worker is queued.
    """
    max_rollouts = 5
    max_evals = 3
    # Workers scale with rollout concurrency only.
    expected_workers = max_rollouts
    dataset = [create_mock_row("row-0")]

    started = []  # one entry appended per worker() invocation

    class InstrumentedScheduler(PriorityRolloutScheduler):
        async def worker(self):
            started.append(True)
            try:
                await self.queue.get()
                self.queue.task_done()
            except (asyncio.CancelledError, Exception):
                pass

        async def schedule_dataset(self, *args):
            # One task per expected worker, so every worker can grab one.
            for slot in range(expected_workers):
                await self.queue.put(
                    RolloutTask(
                        priority=(1, slot),
                        row=dataset[0],
                        run_indices=[],
                        config=base_config,
                        row_index=0,
                        history=[]
                    )
                )

    scheduler = InstrumentedScheduler(
        rollout_processor=MagicMock(),
        max_concurrent_rollouts=max_rollouts,
        active_logger=mock_logger,
        eval_executor=mock_eval_executor,
        max_concurrent_evaluations=max_evals,
        rollout_n=1,
        in_group_minibatch_size=1
    )

    await scheduler.run(dataset, 1, base_config)

    worker_start_count = len(started)
    assert worker_start_count == expected_workers
+ # Batch 2 (Runs 2,3): Should buffer, update history, AND call eval with all 4 runs. + + eval_calls = [] + + async def mock_eval(rows): + eval_calls.append(rows) + return rows # Pass through + + async def mock_rollout_gen(processor, rows, config, run_idx): + for row in rows: + yield row + + with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen): + processor_instance = MagicMock() + + scheduler = PriorityRolloutScheduler( + rollout_processor=processor_instance, + max_concurrent_rollouts=1, + active_logger=mock_logger, + eval_executor=mock_eval, + max_concurrent_evaluations=1, + mode="groupwise", + rollout_n=num_runs, + in_group_minibatch_size=micro_batch_size + ) + + results = await scheduler.run(dataset, num_runs, base_config) + + # Verify evaluation was called EXACTLY ONCE + assert len(eval_calls) == 1, f"Expected 1 eval call, got {len(eval_calls)}" + + # Verify it was called with ALL 4 rows + evaluated_rows = eval_calls[0] + assert len(evaluated_rows) == 4, f"Expected 4 rows in group eval, got {len(evaluated_rows)}" + + # Verify results contains all 4 runs (returned from eval) + # Note: eval returns a list of 4 rows. scheduler.results extends this list. + assert len(results) == 4