diff --git a/miles/utils/replay_base.py b/miles/utils/replay_base.py
index 100968a395..8e19b1ba67 100644
--- a/miles/utils/replay_base.py
+++ b/miles/utils/replay_base.py
@@ -95,10 +95,38 @@ def _get_replay_result(top_indices, scores, topk, *args, **kwargs):
     padding_mask = top_indices == -1
 
     if padding_mask.any():
-        top_indices[padding_mask] = (
-            torch.arange(padding_mask.sum(), device=top_indices.device, dtype=top_indices.dtype)
-            % scores.shape[1]
+        # Fill -1 padding slots with experts that are NOT already in the row's
+        # existing topk picks, ranked by score (most plausible filler first).
+        #
+        # The previous implementation `arange(N) % num_experts` could yield a
+        # filler equal to an expert id already present in the same row,
+        # producing within-row duplicates. Downstream code builds routing_map
+        # via scatter(top_indices, True), where duplicates collapse, so
+        # routing_map.sum() < num_tokens * topk. The MoE token dispatcher then
+        # computes input_splits from routing_map.sum() but uses
+        # num_out_tokens = num_tokens * topk for the permuted buffer size,
+        # so sum(input_splits) < input_tensor.shape[0] and all_to_all_single
+        # raises "RuntimeError: Split sizes doesn't match total dim 0 size".
+        # See https://github.com/radixark/miles/issues/1002.
+        n_experts = scores.shape[1]
+        non_pad = ~padding_mask
+        # Build [n_tokens, n_experts] used-mask via one_hot of non-pad picks.
+        # one_hot requires int64; top_indices is int32 per Sample.rollout_routed_experts.
+        used_mask = (
+            torch.nn.functional.one_hot(top_indices.clamp(min=0).long(), num_classes=n_experts)
+            .mul(non_pad.long().unsqueeze(-1))
+            .sum(dim=1)
+            .bool()
         )
+        # Per-row score-ranked complement: used experts get -inf so they sink to the tail.
+        masked_scores = scores.masked_fill(used_mask, float("-inf"))
+        _, sorted_free = masked_scores.sort(dim=1, descending=True)
+        # The k-th -1 slot in each row gets sorted_free[row, k].
+        pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
+        fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(
+            top_indices.dtype
+        )
+        top_indices = torch.where(padding_mask, fill_values, top_indices)
 
     if return_probs:
         return scores.gather(1, top_indices), top_indices