Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions megatron/rl/agent/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,22 @@ class RolloutRequest(Request):


class GroupedRolloutRequest(Request):
"""Request to agent to generate grouped Rollouts."""
"""Request to agent to generate grouped Rollouts.

Attributes:
num_groups: Number of rollout groups to generate per batch.
rollouts_per_group: Number of rollouts within each group.
inference_interface: Interface used for model inference during rollout generation.
validation: Whether this is a validation (not training) request.
filter_groups_with_same_reward: Drop groups where all rollouts have identical rewards.
streaming: If True, generate groups indefinitely until the consumer stops iterating.
If False, generate exactly `num_groups` groups and stop.
enforce_order: If True, yield groups in staleness-preserving order for forced-lag
training. In the steady state, complete batches are yielded in strict sequential order
so that the consumer always trains on the oldest available rollouts first.
        During warmup (the first `num_workers` batches), all batches are equally stale,
so this parameter is ignored and groups are yielded immediately as they complete.
"""

num_groups: int
rollouts_per_group: int
Expand Down Expand Up @@ -271,24 +286,36 @@ async def shutdown_queue_when_done():
shutdown_task = asyncio.create_task(shutdown_queue_when_done())

try:
next_batch_id = 0
# Forced lag involves strict ordering at steady-state.
# However, the initial conditions do not require (and are harmed by) strict ordering.
warmup_groups_until_release = groups_per_worker
next_batch_id = num_workers
pending: dict[int, GroupedRollouts] = {}
while True:
try:
group = await grouped_rollouts.get()
except asyncio_QueueShutDown:
break
if request.enforce_order:
# Accumulate groups and enforce submission order across batches.
pending.setdefault(group.batch_id, []).append(group)
while (l := len(pending.get(next_batch_id, []))) >= groups_per_worker:
assert l == groups_per_worker
batch = pending.pop(next_batch_id)
batch.sort(key=lambda g: g.index_in_batch)
next_batch_id += 1
for g in batch:
yield g
submission_gate.release()
if group.batch_id < num_workers:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this change the semantics of enforce_order? So, now enforce_order preserves the staleness order, not the submission order?

Regardless of the answer, we need to have a docstring for GroupedRolloutRequest to describe what this actually means, right now it is unclear.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regardless of the answer, we need to have a docstring for GroupedRolloutRequest to describe what this actually means, right now it is unclear.

Done! I believe my docstring also reflects the discussion below; let me know if not.

Doesn't this change the semantics of enforce_order? So, now enforce_order preserves the staleness order, not the submission order?

It changes the semantics of enforce_order, so that enforce_order is completely ignored during warm-up (the first num_workers batches).

The goal of enforce_order is to guarantee that each training batch is composed of rollouts of a specific, pre-determined, staleness. During "warmup", all rollouts consumed by training were started on iteration 0; this is just how warmup needs to work. Because all rollouts have equal staleness, it makes no sense to preserve their submission order; we get no benefit out of it, only drawbacks.

So this is an optimization at warmup, to make it smoother. And because forced lag really struggles to smooth out its per-step behavior after warmup (unlike the smooth behavior you are used to seeing from unforced lag, where it's only the first few steps that are bumpy, with forced lag the bumpiness of the first few steps causes the entire run to show a saw-tooth pattern on throughput numbers), this optimization makes the entire run slightly smoother.

It is just a free win. The only drawback is that other RL projects that exist in the open-source community are not taking this free win, so our forced lag simulations will be a steel-man of how forced lag is actually implemented by others.

# Warmup: initial batches all have equal staleness;
# yield immediately without waiting for batch completion.
yield group
warmup_groups_until_release -= 1
if warmup_groups_until_release == 0:
submission_gate.release()
warmup_groups_until_release = groups_per_worker
else:
# Steady state: accumulate and enforce strict batch order.
pending.setdefault(group.batch_id, []).append(group)
while (l := len(pending.get(next_batch_id, []))) >= groups_per_worker:
assert l == groups_per_worker
batch = pending.pop(next_batch_id)
batch.sort(key=lambda g: g.index_in_batch)
next_batch_id += 1
for g in batch:
yield g
submission_gate.release()
else:
# Yield groups as soon as they're completed.
yield group
Expand Down
24 changes: 22 additions & 2 deletions tests/unit_tests/rl/test_grouped_rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def group_rollout(self, request):
idx = self._call_count
self._call_count += 1
if idx < self.num_slow_calls:
await asyncio.sleep(0.03)
await asyncio.Event().wait() # Block forever; cancelled when test completes
return [
Rollout(
trajectory=[f"t{idx}"],
Expand All @@ -54,7 +54,8 @@ class TestGroupedRollouts:
[
pytest.param(0, False, 8, 8, None, id="non_batched"),
pytest.param(0, False, 4, 4, None, id="non_streaming_fewer_than_parallel"),
pytest.param(4, True, 2, 8, [0, 0, 1, 1, 2, 2, 3, 3], id="batched_submission_order"),
pytest.param(4, True, 2, 8, None, id="streaming_batched"),
pytest.param(0, True, 2, 16, None, id="streaming_steady_state_order"),
pytest.param(0, True, 1, 10, None, id="streaming"),
],
)
Expand All @@ -78,6 +79,25 @@ async def test_get_grouped_rollouts(
assert len(groups) == expected_count
if expected_batch_ids is not None:
assert [g.batch_id for g in groups] == expected_batch_ids
if num_slow_calls > 0 and streaming:
# Warmup should not block on slow batches.
batch_ids = [g.batch_id for g in groups]
num_slow_batches = num_slow_calls // num_groups
slow_batches = set(range(num_slow_batches))
assert (
batch_ids[0] not in slow_batches
), f"Expected first group from a fast batch, got batch_id={batch_ids[0]}"
if streaming and num_groups > 1:
# Verify steady-state batches arrive in sequential order.
num_workers = gen.parallel_generation_tasks // num_groups
steady = [g for g in groups if g.batch_id >= num_workers]
if steady:
batch_order = [steady[0].batch_id]
for g in steady[1:]:
if g.batch_id != batch_order[-1]:
batch_order.append(g.batch_id)
expected = list(range(num_workers, num_workers + len(batch_order)))
assert batch_order == expected, f"Steady-state batches out of order: {batch_order}"

@pytest.mark.asyncio
async def test_weighted_multi_task(self):
Expand Down
Loading