
fix(replay): draw -1 padding fillers from per-row complement set #11

Merged
DavidBellamy merged 2 commits into deploy from fix/r3-padding-cycle-collision on Apr 28, 2026

Conversation

@DavidBellamy (Collaborator)

Problem

miles/utils/replay_base.py::BaseReplayManager.get_topk_fn._get_replay_result fills the -1 slots in top_indices (token routings truncated or aborted by the rollout engine) by walking a flat arange modulo num_experts:

padding_mask = top_indices == -1
if padding_mask.any():
    top_indices[padding_mask] = (
        torch.arange(padding_mask.sum(), device=top_indices.device, dtype=top_indices.dtype)
        % scores.shape[1]
    )

This ignores the [num_tokens, topk] row structure. Within a single token's row, the cyclic filler can land on an expert id already present in that row's existing topk picks, producing a within-row duplicate.

Downstream, the router builds routing_map via scatter(top_indices, True) into a [num_tokens, num_experts] boolean matrix. Duplicates within a row collapse silently, so routing_map.sum() < num_tokens * topk.

The MoE token dispatcher then computes input_splits = num_local_tokens_per_expert.reshape(ep_size, num_local_experts).sum(axis=1), i.e. sum(input_splits) = routing_map.sum() < num_tokens * topk. But num_out_tokens for the permuted buffer is routing_map.size(0) * topk (static, dropless path). At dispatch time, all_to_all_single is called with permuted_tokens.shape[0] = num_tokens * topk and input_splits summing to less, so it raises:

RuntimeError: Split sizes doesn't match total dim 0 size

This is the canonical R3 crash described in radixark#1002 — and reproduced repeatedly by Phase-3 RL360 jobs.
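
For concreteness, a minimal self-contained sketch (hypothetical tiny shapes, not the actual router/dispatcher code) of the cyclic filler producing a within-row duplicate and the resulting routing_map shortfall:

import torch

# Hypothetical tiny shapes: 3 tokens, topk=2, num_experts=4.
num_tokens, topk, num_experts = 3, 2, 4
top_indices = torch.tensor([[0, -1],   # this row already routes to expert 0
                            [2, 3],
                            [3, 1]])

# Old cyclic filler: flat arange over all -1 slots, mod num_experts.
padding_mask = top_indices == -1
top_indices[padding_mask] = (
    torch.arange(int(padding_mask.sum()), dtype=top_indices.dtype) % num_experts
)
print(top_indices[0])  # tensor([0, 0]) -- within-row duplicate

# Router-style boolean routing_map; the duplicate collapses into a single True.
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.int64)
routing_map.scatter_(1, top_indices, 1)
routing_map = routing_map.bool()
print(routing_map.sum().item(), num_tokens * topk)  # 5 vs 6

# input_splits derived from routing_map sum to 5, while the permuted buffer is
# sized num_tokens * topk = 6: exactly the all_to_all_single size mismatch.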

Why it shows up intermittently

The probability of collision per row depends on:

  1. The truncation pattern at the rollout engine (which rows have any -1s, and how many).
  2. The topk / num_experts ratio. For GLM-4.7-Flash (topk=4, num_experts=64) collisions are relatively rare; for higher topk:experts ratios (or edge configs) they are common.

This explains why the same code path can run many MoE forwards cleanly on one config and crash on another.
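
For a rough sense of scale, a hypothetical Monte Carlo estimate (not from the codebase) of the per-row collision chance when a single -1 slot is filled with a value unrelated to that row's surviving picks, which is effectively what the flat cyclic filler does:

import torch

# Hypothetical estimate: probability that a filler value independent of a row's
# existing picks collides with one of them, for one -1 slot per affected row.
torch.manual_seed(0)
num_experts, topk, trials = 64, 4, 50_000
existing = torch.stack(
    [torch.randperm(num_experts)[: topk - 1] for _ in range(trials)]
)  # topk-1 distinct surviving picks per simulated row
filler = torch.randint(num_experts, (trials, 1))
collision_rate = (filler == existing).any(dim=1).float().mean().item()
print(collision_rate)  # about (topk - 1) / num_experts = 3/64, i.e. ~4.7% per affected row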

Fix

Replace the cyclic filler with a per-row complement-set draw: for each row, fill its -1 slots with the highest-scoring experts NOT already used in that row's existing topk picks. By construction, no within-row duplicate is ever produced.

n_experts = scores.shape[1]
non_pad = ~padding_mask
# [n_tokens, n_experts] used-mask, OR-aggregated from non-pad picks
used_mask = (
    F.one_hot(top_indices.clamp(min=0).long(), num_classes=n_experts)
    .mul(non_pad.long().unsqueeze(-1))
    .sum(dim=1)
    .bool()
)
# rank free experts per row by score; used experts sink to the tail via -inf
masked_scores = scores.masked_fill(used_mask, float("-inf"))
_, sorted_free = masked_scores.sort(dim=1, descending=True)
# k-th -1 in each row gets sorted_free[row, k]
pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(top_indices.dtype)
top_indices = torch.where(padding_mask, fill_values, top_indices)

Notes:

  • top_indices is int32 per Sample.rollout_routed_experts; F.one_hot requires int64, hence .long().
  • pad_cumsum.clamp(min=0) is safe — for non-pad positions torch.where masks the gathered value.
  • Score-rank ordering picks "most plausible" fillers (the experts the model was next-most-likely to route to anyway), keeping replay close to a counterfactual on-policy roll.
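
As a quick sanity check, the same complement-set logic run on hypothetical tiny tensors (2 tokens, topk=3, 4 experts); every padded row ends up with distinct experts:

import torch
import torch.nn.functional as F

# Hypothetical inputs: row 0 has one -1 slot, row 1 has two.
top_indices = torch.tensor([[2, -1, 0],
                            [3, -1, -1]], dtype=torch.int32)
scores = torch.tensor([[0.1, 0.4, 0.3, 0.2],
                       [0.5, 0.2, 0.2, 0.1]])
padding_mask = top_indices == -1

n_experts = scores.shape[1]
non_pad = ~padding_mask
used_mask = (
    F.one_hot(top_indices.clamp(min=0).long(), num_classes=n_experts)
    .mul(non_pad.long().unsqueeze(-1))
    .sum(dim=1)
    .bool()
)
masked_scores = scores.masked_fill(used_mask, float("-inf"))
_, sorted_free = masked_scores.sort(dim=1, descending=True)
pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(top_indices.dtype)
top_indices = torch.where(padding_mask, fill_values, top_indices)

print(top_indices)  # e.g. [[2, 1, 0], [3, 0, 1]] -- fillers drawn from each row's complement
# Every row holds topk distinct experts, so the scattered routing_map keeps
# exactly num_tokens * topk True entries.
assert all(len(set(row.tolist())) == top_indices.shape[1] for row in top_indices)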

Validation

End-to-end iter-scale RL run on M2 (job 1581417, GLM-4.7-Flash, topk=4, num_experts=64, expert-model-parallel-size=8, real terminus-2 agent, R3 on):

| Signal | Result |
| --- | --- |
| dup_rows_after_pad (any > 0?) | 0 across 620 R3_DIAG prints |
| Dispatcher gap = num_tokens*topk − routing_map.sum() | 0 across all dispatcher prints |
| "Split sizes doesn't match total dim 0 size" occurrences | 0 |
| Timer train_step cycles completed | 6 in 1h 6min |
| Step-record log lines | 144 |

The job eventually exited from an unrelated NCCL ALLREDUCE timeout in log_rollout_data → gather_log_data (wandb gather), well after R3 had repeatedly run cleanly through the dispatch path.

The diagnostic that confirmed root cause is preserved as an R3_DIAG env-gated print in the same function on a separate working branch; it was not included here to keep this PR scoped to the fix.

Closes radixark#1002.

github-actions Bot and others added 2 commits April 28, 2026 16:31
The previous routing-replay padding-replacement code:

    top_indices[padding_mask] = (
        torch.arange(padding_mask.sum(), ...) % scores.shape[1]
    )

walks a flat arange mod num_experts to fill -1 slots, ignoring the row
structure. For any row with one or more -1s, the cyclic filler can land
on an expert id that is already present in that same row's existing
topk picks, producing within-row duplicates.

Downstream the router converts top_indices into a [num_tokens, num_experts]
routing_map via one-hot scatter, where duplicates within a row silently
collapse. As a result routing_map.sum() < num_tokens * topk. The
MoEAlltoAllTokenDispatcher then computes input_splits from
routing_map.sum(dim=0) but uses num_out_tokens = num_tokens * topk for
the permuted buffer, so sum(input_splits) < permuted_tokens.shape[0].
The subsequent all_to_all_single call raises:

    RuntimeError: Split sizes doesn't match total dim 0 size

This bug is intermittent: it depends on (a) which rows have any -1s,
which is a function of rollout-engine truncation/abort luck, and (b)
the topk / num_experts ratio (collisions are likelier when topk
approaches num_experts).

This change replaces the cyclic filler with a per-row complement-set
draw: for each row, fill its -1 slots with the highest-scoring experts
NOT already used in that row, chosen deterministically in score-rank
order. By construction, no within-row duplicate is ever produced.

Closes radixark#1002.
@DavidBellamy force-pushed the fix/r3-padding-cycle-collision branch from 5318ac5 to 3c88d7d on April 28, 2026 16:32
@DavidBellamy merged commit 87c9977 into deploy on Apr 28, 2026
11 of 14 checks passed