From acb6edeb12544dce26f51116627bf89969745caf Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Mon, 27 Apr 2026 09:46:36 +0000
Subject: [PATCH 1/2] state:d74d25a62082d55ef5b2461ec07bd5fbe0f69fa8|fix-wandb-step-metrics=233f489f32c0806d4a7fccf9ea42d668a734ee28,fix/allow-pd-worker-type-on-miles-router=f0c9d3cc1f9f9a9ed98723e9462f8e1a3465a428,fix/check-reward-nonzero-std-none-guard=8998a5f9231c36490be39bc7d9cd957444a94f80,fix/guard-round-none-in-zero-std-metrics=7b7efa922b47f3cb2b3d1e1b6c549ae6d7aa640e,fix/metrics-interpretability=6ddf57a8d341c42735ee62d97a7bf2693eba5ba0,fix/propagate-pythonpath-to-ray-remote-actors=779839cb2569f411a373dac9a229b469aa0aa991,fix/rollback-error-recovery=dd188aaddb887e5c1de066277e7c9f10eec297b0,fix/rollback-reseed-empty-assistant=f5da99e391dcc26327d4abb2aea2f2572863c690,fix/session-auto-create=29a0dcad45e36985c7af6f7c1dcf653e0eace4f2,fix/session-server-strip-stale-content-length-clean=9a0ef97613294d854b60e98e2202f58a7734936f,fix/tito-assistant-append-role=1f58c11e56fd1ccb0bcf503c5dc63bc914c4ed31,fix/tito-v1-model-info-stub-v2=9ee914087ca5cc85db157105c90298085d112953,fix/truncate-routed-experts=25645357a7d51daa307a9fa25081095bc3cfb2a1,fix/wandb-shared-mode-online-timeout=28cb49c4a4245cc5bb6a00fb2e976bddfc7a35fd,fix/move-base-port-above-mooncake-rpc=70ea80f905b39a45cfb22faece608fa35f505e11,fix/session-server-strip-stale-content-length-clean=9a0ef97613294d854b60e98e2202f58a7734936f,fix/tito-allow-assistant-append=c15c70487c93a69248f327550399f14890ecf13c,fix/tito-assistant-append-role=1f58c11e56fd1ccb0bcf503c5dc63bc914c4ed31

From 3c88d7d078df0efc8001be20b223d455fbc02e74 Mon Sep 17 00:00:00 2001
From: David Bellamy <12414531+DavidBellamy@users.noreply.github.com>
Date: Tue, 28 Apr 2026 01:01:21 +0000
Subject: [PATCH 2/2] fix(replay): draw -1 padding fillers from per-row complement set

The previous routing-replay padding-replacement code:

    top_indices[padding_mask] = (
        torch.arange(padding_mask.sum(), ...) % scores.shape[1]
    )

walks a flat arange mod num_experts to fill -1 slots, ignoring the row
structure. For any row with one or more -1s, the cyclic filler can land on
an expert id that is already present in that same row's existing topk
picks, producing within-row duplicates.

Downstream, the router converts top_indices into a [num_tokens, num_experts]
routing_map via one-hot scatter, where duplicates within a row silently
collapse. As a result, routing_map.sum() < num_tokens * topk. The
MoEAlltoAllTokenDispatcher then computes input_splits from
routing_map.sum(dim=0) but uses num_out_tokens = num_tokens * topk for the
permuted buffer, so sum(input_splits) < permuted_tokens.shape[0]. The
subsequent all_to_all_single call raises:

    RuntimeError: Split sizes doesn't match total dim 0 size

This bug is intermittent: it depends on (a) which rows have any -1s, which
is a function of rollout-engine truncation/abort luck, and (b) the
topk / num_experts ratio (collisions are likelier as topk approaches
num_experts).

This change replaces the cyclic filler with a per-row complement-set draw:
for each row, pick the highest-scoring experts NOT already used in that
row, deterministically distinct, in score-rank order. By construction, no
within-row duplicate is ever produced.

Closes radixark/miles#1002.
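
For illustration only (a minimal standalone sketch, not part of the patch;
the toy shapes and scores are made up): with num_experts=4, topk=3, and one
row whose picks are [2, 0, -1], the old cyclic filler collides with an
expert already in the row, while the complement draw cannot:

    import torch

    scores = torch.tensor([[0.10, 0.40, 0.30, 0.20]])  # [num_tokens=1, num_experts=4]
    top_indices = torch.tensor([[2, 0, -1]])            # topk=3; third slot padded with -1
    padding_mask = top_indices == -1

    # Old filler: flat arange mod num_experts assigns expert 0 to the -1 slot,
    # but expert 0 is already in this row -> within-row duplicate.
    old_fill = torch.arange(int(padding_mask.sum())) % scores.shape[1]  # tensor([0])

    # Complement draw: only experts {1, 3} are free in this row; the higher
    # scoring one (expert 1) is chosen, so no duplicate is possible.
    used = torch.zeros_like(scores, dtype=torch.bool)
    used[0, [2, 0]] = True
    free_ranked = scores.masked_fill(used, float("-inf")).argsort(dim=1, descending=True)
    new_fill = free_ranked[0, 0]  # tensor(1)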
---
 miles/utils/replay_base.py | 34 +++++++++++++++++++++++++++++++---
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/miles/utils/replay_base.py b/miles/utils/replay_base.py
index 100968a395..8e19b1ba67 100644
--- a/miles/utils/replay_base.py
+++ b/miles/utils/replay_base.py
@@ -95,10 +95,38 @@ def _get_replay_result(top_indices, scores, topk, *args, **kwargs):
     padding_mask = top_indices == -1
     if padding_mask.any():
-        top_indices[padding_mask] = (
-            torch.arange(padding_mask.sum(), device=top_indices.device, dtype=top_indices.dtype)
-            % scores.shape[1]
+        # Fill -1 padding slots with experts that are NOT already in the row's
+        # existing topk picks, ranked by score (most plausible filler first).
+        #
+        # The previous implementation `arange(N) % num_experts` could yield a
+        # filler equal to an expert id already present in the same row,
+        # producing within-row duplicates. Downstream code builds routing_map
+        # via scatter(top_indices, True), where duplicates collapse, so
+        # routing_map.sum() < num_tokens * topk. The MoE token dispatcher then
+        # computes input_splits from routing_map.sum() but uses
+        # num_out_tokens = num_tokens * topk for the permuted buffer size,
+        # so sum(input_splits) < input_tensor.shape[0] and all_to_all_single
+        # raises "RuntimeError: Split sizes doesn't match total dim 0 size".
+        # See https://github.com/radixark/miles/issues/1002.
+        n_experts = scores.shape[1]
+        non_pad = ~padding_mask
+        # Build [n_tokens, n_experts] used-mask via one_hot of non-pad picks.
+        # one_hot requires int64; top_indices is int32 per Sample.rollout_routed_experts.
+        used_mask = (
+            torch.nn.functional.one_hot(top_indices.clamp(min=0).long(), num_classes=n_experts)
+            .mul(non_pad.long().unsqueeze(-1))
+            .sum(dim=1)
+            .bool()
         )
+        # Per-row score-ranked complement: used experts get -inf so they sink to the tail.
+        masked_scores = scores.masked_fill(used_mask, float("-inf"))
+        _, sorted_free = masked_scores.sort(dim=1, descending=True)
+        # The k-th -1 slot in each row gets sorted_free[row, k].
+        pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
+        fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(
+            top_indices.dtype
+        )
+        top_indices = torch.where(padding_mask, fill_values, top_indices)
 
     if return_probs:
         return scores.gather(1, top_indices), top_indices
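
Illustration only (not part of the patch): a standalone sanity check of the
complement fill. It repeats the new fill steps on random toy inputs (the
shapes and the ~30% padding rate are arbitrary) and asserts that every row
ends up with topk distinct experts, i.e. the downstream one-hot routing map
sums back to num_tokens * topk.

    import torch

    torch.manual_seed(0)
    num_tokens, num_experts, topk = 64, 8, 6

    scores = torch.rand(num_tokens, num_experts)
    # Start from valid distinct picks, then knock ~30% of slots back to -1
    # to mimic rollout-side padding.
    top_indices = scores.topk(topk, dim=1).indices.to(torch.int32)
    top_indices[torch.rand(num_tokens, topk) < 0.3] = -1

    # Same steps as the patch body above.
    padding_mask = top_indices == -1
    non_pad = ~padding_mask
    used_mask = (
        torch.nn.functional.one_hot(top_indices.clamp(min=0).long(), num_classes=num_experts)
        .mul(non_pad.long().unsqueeze(-1))
        .sum(dim=1)
        .bool()
    )
    masked_scores = scores.masked_fill(used_mask, float("-inf"))
    _, sorted_free = masked_scores.sort(dim=1, descending=True)
    pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
    fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(top_indices.dtype)
    top_indices = torch.where(padding_mask, fill_values, top_indices)

    # Duplicates within a row would collapse in the one-hot scatter and make
    # the sum fall short of num_tokens * topk.
    routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.int64)
    routing_map.scatter_(1, top_indices.long(), 1)
    assert routing_map.sum().item() == num_tokens * topk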