
Pipeline training support + priority based rollout scheduler #358

Merged
mayinghan merged 11 commits into main from pipeline-training-support
Dec 9, 2025
Conversation

@mayinghan (Collaborator) commented Dec 5, 2025

This PR adds support for 1) a priority-based rollout scheduler and 2) a micro-batch output data buffer.

Example run script:
rm -rf test2/ && PYTHONPATH=./ EP_MICRO_BATCH_OUTPUT_SIZE=2 EP_USE_PRIORITY_SCHEDULER=1 EP_NO_UPLOAD=1 python -m pytest -sv tests/pytest/test_rollout_scheduler.py::test_rollout_scheduler --ep-output-dir ./test2

The output directory structure looks like:
(screenshot of the output directory structure, omitted)

Priority-based rollout scheduler with rewrite speculation support.

  • Rollout tasks are scheduled so that rollouts from the same group are executed at the same time.
  • A new parameter, in_group_microbatch_size, supports rolling out K samples (K < rollout_n) at a time and feeding their responses into the "prediction" field of the later mini-batches, enabling rewrite speculation to accelerate the rollout.
  • The evaluation function runs in a non-blocking way in a background task pool.
  • This can help achieve
    • a better KV cache hit rate
    • compatibility with rewrite speculation to speed up inference.
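The in_group_microbatch_size flow described above can be sketched as follows. This is a minimal illustration, not the PR's code: the rollout function, its predictions parameter, and the way hints are threaded between micro-batches are all assumptions.

```python
def rollout(prompts, predictions=None):
    # Stand-in for an inference call. With rewrite speculation, the
    # `predictions` hints would let the server verify/rewrite a prior
    # response instead of decoding every token from scratch.
    return [f"resp({p})" for p in prompts]

def rollout_group(prompt: str, rollout_n: int, k: int) -> list:
    """Roll out `rollout_n` samples for one prompt, K at a time, feeding
    each finished micro-batch's responses to the next as speculation hints."""
    responses = []
    hints = None
    for start in range(0, rollout_n, k):
        batch = [prompt] * min(k, rollout_n - start)
        out = rollout(batch, predictions=hints)
        responses.extend(out)
        hints = out  # prior responses seed speculation for the next micro-batch
    return responses

print(len(rollout_group("q1", rollout_n=4, k=2)))  # → 4
```

Because every micro-batch reuses the same prompt prefix, scheduling them back to back also keeps that prefix hot in the KV cache.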

MicroBatchDataBuffer

  • Once a full group of rollout_n samples is ready, it is pushed to the output buffer.
  • The output buffer writes to disk once the micro-batch size condition is met; the result is directly consumed by the trainer to kick off a training step.
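A minimal synchronous sketch of the buffer-then-flush behavior described above. The class name, file naming, and internals here are illustrative assumptions; the PR's actual MicroBatchDataBuffer (in eval_protocol/pytest/buffer.py) is async and driven by num_runs and output_path_template.

```python
import json
import tempfile
from collections import defaultdict
from pathlib import Path

class MicroBatchBufferSketch:
    """Hold per-group samples until a full group of `rollout_n` arrives,
    then flush `micro_batch_size` complete groups to a JSONL file."""

    def __init__(self, rollout_n: int, micro_batch_size: int, out_dir: Path):
        self.rollout_n = rollout_n
        self.micro_batch_size = micro_batch_size  # complete groups per flush
        self.out_dir = out_dir
        self.pending = defaultdict(list)          # group_id -> samples so far
        self.ready_groups = []
        self.flush_count = 0

    def add_result(self, group_id: int, sample: dict) -> None:
        self.pending[group_id].append(sample)
        if len(self.pending[group_id]) == self.rollout_n:
            # full group of rollout_n samples -> move to the output side
            self.ready_groups.append(self.pending.pop(group_id))
            if len(self.ready_groups) >= self.micro_batch_size:
                self._flush()

    def _flush(self) -> Path:
        path = self.out_dir / f"microbatch_{self.flush_count}.jsonl"
        with path.open("w") as f:
            for group in self.ready_groups:
                for sample in group:
                    f.write(json.dumps(sample) + "\n")
        self.ready_groups.clear()
        self.flush_count += 1
        return path

out_dir = Path(tempfile.mkdtemp())
buf = MicroBatchBufferSketch(rollout_n=2, micro_batch_size=1, out_dir=out_dir)
buf.add_result(0, {"group": 0, "run": 0})
buf.add_result(0, {"group": 0, "run": 1})  # completes the group -> flush
print(sorted(p.name for p in out_dir.iterdir()))  # → ['microbatch_0.jsonl']
```

The trainer can then tail the output directory and kick off a training step per flushed file.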

Note

Adds a priority-based rollout scheduler (with optional speculation) and a micro-batch output buffer, wired into evaluation_test via env flags and covered by new tests.

  • Core (pytest):
    • Priority Scheduler: New eval_protocol/pytest/priority_scheduler.py
      • PriorityRolloutScheduler with priority-queued micro-batches, per-row run batching, background eval execution, and optional rewrite speculation (via ENABLE_SPECULATION).
      • execute_priority_rollouts(...) entrypoint; respects max_concurrent_rollouts/max_concurrent_evaluations.
    • Micro-batch Buffer: New eval_protocol/pytest/buffer.py
      • MicroBatchDataBuffer buffers per-sample results until all num_runs complete; flushes JSONL batches to disk (output_path_template).
    • Integration: evaluation_test.py
      • Optional scheduler path gated by EP_USE_PRIORITY_SCHEDULER (disabled for MCPGymRolloutProcessor).
      • Optional buffering via EP_MICRO_BATCH_OUTPUT_SIZE and EP_OUTPUT_DIR (auto-close on completion).
      • Aggregates run_index from execution_metadata.extra to populate all_results; invokes postprocess accordingly.
  • Validation:
    • validate_signature.py: remove the groupwise requirement of at least 2 completion_params.
  • Tests:
    • tests/test_priority_scheduler.py: unit tests for basic execution, concurrency limits, priority ordering, worker scaling, and groupwise behavior.
    • tests/pytest/test_rollout_scheduler.py: pytest-style tests for pointwise and groupwise modes using the new scheduler.
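The env-flag wiring summarized above can be illustrated with a small helper. The flag names (EP_USE_PRIORITY_SCHEDULER, EP_MICRO_BATCH_OUTPUT_SIZE, EP_OUTPUT_DIR) come from this PR, but the parsing and defaults in this sketch are assumptions, not the code in evaluation_test.py.

```python
def scheduler_config(env: dict) -> dict:
    """Sketch of how the integration might interpret the env flags named
    in this PR. Defaults here are illustrative."""
    use_priority = env.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1"
    raw = env.get("EP_MICRO_BATCH_OUTPUT_SIZE")
    return {
        "use_priority_scheduler": use_priority,
        # buffering is only enabled when a micro-batch size is given
        "micro_batch_output_size": int(raw) if raw else None,
        "output_dir": env.get("EP_OUTPUT_DIR"),
    }

cfg = scheduler_config({"EP_USE_PRIORITY_SCHEDULER": "1",
                        "EP_MICRO_BATCH_OUTPUT_SIZE": "2"})
print(cfg)
```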

Written by Cursor Bugbot for commit b29af96.

@mayinghan mayinghan marked this pull request as ready for review December 5, 2025 07:47
        await self.output_buffer.add_result(row)
else:
    self._post_process_result(eval_res)
    await self.output_buffer.add_result(eval_res)
Cursor Bugbot:
Bug: Results skip post-processing when output buffer disabled

The _post_process_result method is only called inside the if self.output_buffer: conditional block. When the mini batch buffer feature is not configured (i.e., output_buffer is None), results are appended to self.results but _post_process_result is never invoked. This skips critical operations: add_cost_metrics() is not called, eval_metadata.status remains stuck at "RUNNING" instead of being updated to finished/error, and results are never logged via active_logger.log().

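The fix Bugbot describes amounts to hoisting post-processing out of the buffer conditional. This sketch is illustrative: member names mirror the snippet above, but the class itself is a stand-in, not the PR's scheduler.

```python
class ResultHandler:
    """Minimal model of the corrected control flow: post-processing runs
    for every result; only buffering stays conditional."""

    def __init__(self, output_buffer=None):
        self.output_buffer = output_buffer  # None when buffering is off
        self.results = []
        self.post_processed = []

    def _post_process_result(self, res):
        # stands in for add_cost_metrics(), the eval_metadata.status
        # update, and active_logger.log()
        self.post_processed.append(res)

    def handle(self, res):
        self._post_process_result(res)      # always runs now
        if self.output_buffer is not None:  # buffering remains optional
            self.output_buffer.append(res)
        else:
            self.results.append(res)

h = ResultHandler(output_buffer=None)
h.handle({"row": 1})
print(len(h.post_processed))  # → 1: no longer skipped when buffering is off
```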

full_group = self.groups_buffer.pop(task.row_index)
t = asyncio.create_task(_run_eval(full_group))
self.background_tasks.add(t)
t.add_done_callback(self.background_tasks.discard)
Cursor Bugbot:

Bug: Priority scheduler missing "all" mode evaluation handling

The _process_task method only handles mode == "pointwise" (line 216) and mode == "groupwise" (line 233) for triggering evaluations. However, EvaluationTestMode includes "all" as a valid mode. When mode="all" is used with the priority scheduler enabled, rollouts complete but evaluations are never triggered, silently skipping the evaluation step entirely.

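One way to close the gap Bugbot describes is to make "all" trigger both evaluation paths and to fail loudly on unknown modes rather than falling through silently. This dispatch function is a simplified illustration, not the _process_task implementation.

```python
def dispatch_eval(mode: str, triggered: list) -> None:
    """Record which evaluation paths a given mode should trigger.
    'all' runs both; an unrecognized mode raises instead of no-opping."""
    if mode in ("pointwise", "all"):
        triggered.append("pointwise-eval")
    if mode in ("groupwise", "all"):
        triggered.append("groupwise-eval")
    if not triggered:
        raise ValueError(f"unknown evaluation mode: {mode}")

calls = []
dispatch_eval("all", calls)
print(calls)  # → ['pointwise-eval', 'groupwise-eval']
```

If "all" is dropped from EvaluationTestMode instead (as the author suggests below), the raise still guards against future silent skips.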

@mayinghan (Collaborator Author):

We should remove the "all" mode; it's not that useful. cc @morgendave

"""
Represents a single unit of work for the worker pool.
Priority tuple structure: (status, row_index)
- status: 0 = High Priority (e.g., subsequent micro-batches of an already started sample)
Collaborator:

Later on, based on strategy, we could change some of the priorities.
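Python compares tuples lexicographically, which is what makes the (status, row_index) encoding in the docstring above work with a heap: status dominates, and row_index breaks ties. A quick demonstration (the task payloads are made up):

```python
import heapq

# (status, row_index, payload): status 0 = high priority (a follow-up
# micro-batch of an already started group), status 1 = a fresh group.
tasks = [
    (1, 2, "new-group-row2"),
    (0, 5, "started-group-row5"),
    (1, 1, "new-group-row1"),
]
heapq.heapify(tasks)
order = [heapq.heappop(tasks)[2] for _ in range(3)]
print(order)  # → ['started-group-row5', 'new-group-row1', 'new-group-row2']
```

Swapping in a different priority key (per the reviewer's point) only requires changing how the tuple is built.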


# Concurrency Control
self.rollout_sem = asyncio.Semaphore(max_concurrent_rollouts)
self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations)
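A small demonstration of why a semaphore like eval_sem matters when the evaluator has a QPS limit: with Semaphore(2), at most two evaluations overlap no matter how many are launched at once. The task count and sleep duration are illustrative.

```python
import asyncio

async def fake_eval(sem: asyncio.Semaphore, active: list, peak: list) -> None:
    async with sem:
        active[0] += 1
        peak[0] = max(peak[0], active[0])
        await asyncio.sleep(0.01)  # stands in for a judge/LLM call
        active[0] -= 1

async def main() -> int:
    sem = asyncio.Semaphore(2)
    active, peak = [0], [0]
    # launch 8 evals at once; the semaphore admits only 2 at a time
    await asyncio.gather(*(fake_eval(sem, active, peak) for _ in range(8)))
    return peak[0]

print(asyncio.run(main()))  # → 2
```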
Collaborator:

Are rollout_sem and eval_sem duplicates? rollout_sem is not used.

@mayinghan (Collaborator Author):

My bad, forgot to delete it. There is a global semaphore used in the rollout processor, so this one is not needed in the current design.

run_id = rows_to_eval[0].execution_metadata.run_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.run_id
eval_res = None

async with self.eval_sem:
Collaborator:

If eval_sem's limit is lower than the rollout concurrency, might this block and time out?

@mayinghan (Collaborator Author):

Yes, but this is needed if the user is using another environment or an LLM as the judge, where there is a QPS limit.

Collaborator:

Any retry at this level?

@mayinghan (Collaborator Author):

Not in this PR. The status quo is that we don't have a default retry around the user's evaluator function. We can discuss whether we want one with @xzrderek in a separate thread, I think.

finally:
    self.queue.task_done()

async def _process_task(self, task: RolloutTask):
Collaborator:

We can treat this as a first step, but it does not make the scheduling waste go away. Currently this does:

  1. inference call
  2. eval

Step 1 is wasteful for multi-turn rollouts; we might need some feedback to adjust concurrency. Step 2 blocks step 1 from finishing and from getting new rollouts started, which means we should optimize it first.

@mayinghan (Collaborator Author):

Synced offline: eval runs in background tasks, so it won't be blocked.

batch_results.append(result_row)
# in pointwise, we start evaluation immediately
if self.mode == "pointwise":
    t = asyncio.create_task(_run_eval(result_row))
@mayinghan (Collaborator Author):

cc @morgendave: eval is executed in a background task pool (actually there is no pool; they are just submitted as background tasks), so inference won't be blocked as before.
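The background-task idiom the author describes, which also appears in the diff earlier (create_task plus add_done_callback(discard)), can be shown in isolation. Keeping a strong reference in a set prevents the event loop from garbage-collecting an in-flight task; the eval body here is a stand-in.

```python
import asyncio

async def run_eval(results: list, i: int) -> None:
    await asyncio.sleep(0)  # yield control, as a real eval call would
    results.append(i)

async def main() -> list:
    results: list = []
    background: set = set()
    for i in range(3):
        t = asyncio.create_task(run_eval(results, i))
        background.add(t)                        # strong reference while running
        t.add_done_callback(background.discard)  # drop it once finished
    # the scheduler loop would keep pulling rollout tasks here instead of
    # awaiting each eval; we gather only to let the demo finish cleanly
    await asyncio.gather(*background)
    return sorted(results)

print(asyncio.run(main()))  # → [0, 1, 2]
```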

@mayinghan mayinghan requested a review from morgendave December 9, 2025 06:10
@mayinghan mayinghan merged commit 8219c44 into main Dec 9, 2025
15 of 16 checks passed
@mayinghan mayinghan deleted the pipeline-training-support branch December 9, 2025 23:09
