Merged
34 changes: 31 additions & 3 deletions miles/utils/replay_base.py
@@ -95,10 +95,38 @@ def _get_replay_result(top_indices, scores, topk, *args, **kwargs):
 
     padding_mask = top_indices == -1
     if padding_mask.any():
-        top_indices[padding_mask] = (
-            torch.arange(padding_mask.sum(), device=top_indices.device, dtype=top_indices.dtype)
-            % scores.shape[1]
+        # Fill -1 padding slots with experts that are NOT already in the row's
+        # existing topk picks, ranked by score (most plausible filler first).
+        #
+        # The previous implementation `arange(N) % num_experts` could yield a
+        # filler equal to an expert id already present in the same row,
+        # producing within-row duplicates. Downstream code builds routing_map
+        # via scatter(top_indices, True), where duplicates collapse, so
+        # routing_map.sum() < num_tokens * topk. The MoE token dispatcher then
+        # computes input_splits from routing_map.sum() but uses
+        # num_out_tokens = num_tokens * topk for the permuted buffer size,
+        # so sum(input_splits) < input_tensor.shape[0] and all_to_all_single
+        # raises "RuntimeError: Split sizes doesn't match total dim 0 size".
+        # See https://github.com/radixark/miles/issues/1002.
+        n_experts = scores.shape[1]
+        non_pad = ~padding_mask
+        # Build [n_tokens, n_experts] used-mask via one_hot of non-pad picks.
+        # one_hot requires int64; top_indices is int32 per Sample.rollout_routed_experts.
+        used_mask = (
+            torch.nn.functional.one_hot(top_indices.clamp(min=0).long(), num_classes=n_experts)
+            .mul(non_pad.long().unsqueeze(-1))
+            .sum(dim=1)
+            .bool()
+        )
+        # Per-row score-ranked complement: used experts get -inf so they sink to the tail.
+        masked_scores = scores.masked_fill(used_mask, float("-inf"))
+        _, sorted_free = masked_scores.sort(dim=1, descending=True)
+        # The k-th -1 slot in each row gets sorted_free[row, k].
+        pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
+        fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(
+            top_indices.dtype
         )
+        top_indices = torch.where(padding_mask, fill_values, top_indices)
 
     if return_probs:
         return scores.gather(1, top_indices), top_indices
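
The failure mode documented in the comment above is easy to reproduce in isolation. Here is a minimal sketch, with toy shapes chosen for illustration (none of these values come from the repo): the old `arange(N) % num_experts` fill hands row 0 an expert id it already picked, and the scatter-built routing_map collapses the duplicate, so its total comes up short of num_tokens * topk — the count mismatch that ultimately trips all_to_all_single.

import torch

# Toy shapes (illustrative only): 2 tokens, 4 experts, topk = 2.
num_tokens, num_experts, topk = 2, 4, 2

# Row 0 already picked expert 0 and has one -1 padding slot.
top_indices = torch.tensor([[0, -1],
                            [2, 3]], dtype=torch.int32)
padding_mask = top_indices == -1

# Old fill: arange over all pads, modulo num_experts. The single pad gets
# expert 0, duplicating the id already present in row 0.
old = top_indices.clone()
old[padding_mask] = (
    torch.arange(padding_mask.sum(), dtype=top_indices.dtype) % num_experts
)
print(old)  # tensor([[0, 0], [2, 3]], dtype=torch.int32)

# routing_map is built by scattering True at the picked indices, so the
# within-row duplicate collapses into a single True entry.
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
routing_map.scatter_(1, old.long(), True)
print(routing_map.sum().item(), num_tokens * topk)  # 3 != 4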
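
For comparison, the new fill strategy from the hunk above can be exercised as a self-contained sketch (the function name and inputs here are hypothetical, not part of the repo): each -1 slot takes the highest-scoring expert not already used in its row, so filled rows cannot contain within-row duplicates as long as topk <= n_experts.

import torch
import torch.nn.functional as F

def fill_padding_with_free_experts(top_indices, scores):
    # Mirrors the added code: replace -1 slots with the best-scoring
    # experts that the row has not picked yet.
    padding_mask = top_indices == -1
    if not padding_mask.any():
        return top_indices
    n_experts = scores.shape[1]
    non_pad = ~padding_mask
    # [n_tokens, n_experts] mask of expert ids already used per row.
    used_mask = (
        F.one_hot(top_indices.clamp(min=0).long(), num_classes=n_experts)
        .mul(non_pad.long().unsqueeze(-1))
        .sum(dim=1)
        .bool()
    )
    # Used experts sink to the tail; free experts stay score-ranked.
    masked_scores = scores.masked_fill(used_mask, float("-inf"))
    _, sorted_free = masked_scores.sort(dim=1, descending=True)
    # The k-th pad in a row takes the k-th best free expert of that row.
    pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
    fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0))
    return torch.where(padding_mask, fill_values.to(top_indices.dtype), top_indices)

top_indices = torch.tensor([[0, -1],
                            [2, 3]], dtype=torch.int32)
scores = torch.tensor([[0.90, 0.05, 0.03, 0.02],
                       [0.10, 0.20, 0.40, 0.30]])
print(fill_padding_with_free_experts(top_indices, scores))
# tensor([[0, 1], [2, 3]], dtype=torch.int32) -- row 0's pad takes expert 1,
# the best-scoring expert it had not already picked.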