
[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute#2907

Open
jing-4369 wants to merge 4 commits into NVIDIA:main from jing-4369:fix/moe-permute-int-overflow-and-minus-one

Conversation


@jing-4369 jing-4369 commented Apr 21, 2026

Fixes #2908 — full description, repros, and DeepSeek-V3 context there.

Changes

  • permutation.cu — widen source_token, source_row, dest_row to int64_t inside moe_unpermute_kernel and moe_permute_kernel so row * num_cols stays 64-bit. Simplify moe_permute_row_map to only process the valid [0, num_out_tokens) range; launcher grid becomes num_out_tokens blocks.
  • permutation.cpp — advance sorted_row_id_ptr past the num_minus_ones sentinel prefix left by cub::DeviceRadixSort (signed ascending), and pre-fill row_id_map with -1 via torch::full so dropped slots are marked without the kernel ever dereferencing a sentinel.

No public API / dtype changes. +17 / -18 lines across the two files.

Test plan


greptile-apps Bot commented Apr 21, 2026

Greptile Summary

This PR fixes two correctness bugs in MoE token permutation: int32 overflow in pointer-offset arithmetic for large activation tensors, and incorrect handling of -1 sentinel expert IDs (used by DeepSeek-V3-style routing), resolved by switching cub::DeviceRadixSort to treat keys as uint32_t so sentinels sort to the tail rather than the head.

  • source_token, source_row, and dest_row are widened to int64_t in moe_unpermute_kernel and moe_permute_kernel so row * num_cols pointer arithmetic stays 64-bit on large activations.
  • nvte_device_radix_sort_pairs now reinterprets keys as uint32_t, placing any -1 sentinel (0xFFFFFFFF unsigned) at the tail of the sorted output where the existing idx >= num_out_tokens guard in moe_permute_row_map marks them as dropped — cleanly replacing the signed-sort + pointer-advance approach described in the issue.

Confidence Score: 3/5

The forward permute path still contains an unwidened int32 index computation in moe_permute_row_map that can silently corrupt the row_id_map for large-model configurations, undermining the very overflow fix this PR introduces.

The widening applied to moe_unpermute_kernel and moe_permute_kernel is correct, and the uint32 radix-sort sentinel fix is sound. However, moe_permute_row_map — the kernel that builds the row_id_map consumed by every subsequent kernel — still computes source_topK_id * num_rows and num_rows * topK as int32. For any model where num_rows * topK exceeds 2^31 the map writes land at wrong addresses, producing silently wrong permutation output in the forward pass.

transformer_engine/common/permutation/permutation.cu — specifically the moe_permute_row_map kernel (lines 13–35) and its launcher (line 242), where int32 index arithmetic was not widened alongside the rest of the file.

Important Files Changed

Filename: transformer_engine/common/permutation/permutation.cu
Overview: Widens source_token/source_row/dest_row to int64_t in the unpermute/permute kernels and switches the radix sort to uint32 so -1 sentinels land at the tail; moe_permute_row_map's own int32 index arithmetic is not widened and can still overflow for large num_rows * topK.

Comments Outside Diff (1)

  1. transformer_engine/common/permutation/permutation.cu, line 22-33 (link)

    P1 moe_permute_row_map index computation still overflows int32

    num_rows * topK on the guard line and source_topK_id * num_rows + source_token_id in the write are both int * int expressions. The PR widens the same pattern in moe_unpermute_kernel and moe_permute_kernel, but this kernel is left behind. For topK=128 and num_rows >= 2^24 (~16M), source_topK_id * num_rows already exceeds INT_MAX, silently wrapping to a negative value and making row_id_map[...] = -1/idx write to an out-of-bounds (or wrong) slot — corrupting the map used by every subsequent kernel in both the forward and backward paths. The same int * int product appears in the launcher on line 242 when computing blocks, which will also produce a wrong grid dimension under the same conditions.

Reviews (6): Last reviewed commit: "Switch radix sort keys to uint32_t to fi..."

Comment on lines +59 to +60
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);

P2 Negative num_minus_ones becomes enormous size_t offset

num_minus_ones is computed as int. If a caller passes num_out_tokens > num_tokens * topK (which the function does not validate), num_minus_ones is negative. The pointer advance expression:

sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);

involves int * size_t, which promotes num_minus_ones to size_t (unsigned). A value like -4 becomes SIZE_MAX - 3, advancing the pointer far out of the allocation and causing a silent OOB read. A simple clamp or assert before this line would prevent this:

TORCH_CHECK(num_out_tokens <= num_tokens * topK,
            "num_out_tokens (", num_out_tokens, ") cannot exceed num_tokens*topK (",
            num_tokens * topK, ")");

…permute

Two independent bugs in transformer_engine/common/permutation/permutation.cu
and the PyTorch extension caller reproduce on main (264da2b) and v2.13:

1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel.
   `source_token * num_cols` and `source_row * num_cols` are computed with
   int, so for long-sequence MoE workloads where num_out_tokens * num_cols
   reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset
   wraps and the kernel either reads garbage or raises
   `an illegal memory access was encountered`.
   Widening source_token, source_row and dest_row to int64_t inside the
   kernels keeps the index arithmetic in 64 bits without changing any
   public types.

2. Incorrect handling of -1 sentinels in the routing indices.
   Libraries such as DeepEP (and any expert-parallel mask that sets
   non-local (token, slot) pairs to -1) feed a routing_map that contains
   -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so
   those sentinels land at the HEAD of sorted_row_id, not the tail.
   moe_permute_row_map currently writes -1 only for idx >= num_out_tokens
   and reads the sentinel prefix as if it were a valid sorted id,
   producing bogus row_id_map writes (for instance
   `source_row / topK == 0, source_row % topK == -1`).

   The caller now advances sorted_row_id_ptr past the num_minus_ones
   prefix and pre-fills row_id_map with -1 via torch::full, so the
   kernel only processes the valid suffix and never dereferences a
   sentinel.  The launcher's grid switches from num_rows*topK blocks
   to num_out_tokens blocks to match the new valid range.

No behaviour change on happy-path routing_map (no -1, no overflow).
Reproducers:

- 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0
  on bf16 with current main; 0.0 with this patch.
- num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises
  CUDA illegal memory access at permutation.cu:252; with this patch
  it succeeds.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from a909254 to ce34774 on April 21, 2026 07:58
Comment on lines +61 to +63
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);

P1 num_tokens * topK still computed as int * int

num_tokens and topK are both int, so num_tokens * topK on line 61 is evaluated in 32-bit arithmetic before the result feeds the int64_t subtraction. The same expression appears twice in the NVTE_CHECK on lines 59–60. If num_tokens * topK wraps to a negative int (possible when, e.g., num_tokens ≥ 2^31 / topK), the NVTE_CHECK would either spuriously reject a valid num_out_tokens, or the error-message value would be wrong. Casting to int64_t before the multiplication closes this gap:

Suggested change
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);
const int64_t total_tokens = static_cast<int64_t>(num_tokens) * topK;
NVTE_CHECK(num_out_tokens <= total_tokens, "num_out_tokens (", num_out_tokens,
") must not exceed num_tokens*topK (", total_tokens, ")");
const int num_minus_ones = static_cast<int>(total_tokens - num_out_tokens);
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);


This (the Greptile comment) looks correct. Can you please cast num_tokens to int64 before the multiplication and the subtraction of num_out_tokens?


This is not a huge deal because even with topK=128, you would need > 16M tokens per rank for the int product to overflow. But better to be consistent, and casting one value on the CPU side is unlikely to slow anything down.

@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from ba39fba to cd64aa5 on April 21, 2026 08:14
Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast
num_minus_ones to size_t before the pointer advance, so a negative
num_minus_ones (from an invalid num_out_tokens) cannot silently wrap
into a huge pointer offset.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from cd64aa5 to b73a1f9 on April 21, 2026 08:22
@ptrendx ptrendx added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Apr 21, 2026

int threads = 64;
int blocks = (num_rows * topK + threads - 1) / threads;   // before
int blocks = (num_out_tokens + threads - 1) / threads;    // after

This is correct here, but it has an implied prerequisite: the host prefills the buffer with -1 and shifts the pointer by num_minus_ones (what you did in the other file). Better to make it explicit with a comment so that nobody accidentally changes that behavior and breaks the block count here. Something like:

"// row_id_map MUST be pre-initialized to -1; sorted_row_id MUST point past the sentinel prefix"

num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens,
") must not exceed num_tokens*topK (", num_tokens * topK, ")");
const int num_minus_ones = num_tokens * topK - num_out_tokens;

This is probably going to introduce a regression for the capacity-drop path. The shift assumes the dropped routes are -1 sentinels at the head of sorted_row_id (cub's signed radix sort), which is true for the EP-mask case this PR targets. But the pre-existing capacity-drop path encodes drops as a large positive expert id that sorts to the tail. For that case the head holds valid low-expert-id rows, and shifting past them drops the wrong tokens. (Just FYI: the capacity-dropping case means no -1 in the indices, with num_out_tokens < num_tokens * topK because some expert exceeded capacity.)

See in this file tests/pytorch/test_permutation.py, in pytorch_permute_index_map, we have:

sorted_indices[:num_out_tokens] (keeps the head),
so I'd expect test_permutation_index_map[..., num_out_tokens=2039, ...] to fail. We can run the te_ci to confirm it.


I think another solution to this, without computing num_tokens * topK - num_out_tokens (or counting the number of -1 on the host side), is to sort the keys as uint32_t instead of int32_t. Then -1 becomes UINT_MAX and sorts to the tail, unifying both the capacity-dropping and dropless cases under the original idx >= num_out_tokens --> drop logic. That removes the need for the prefix shift you did and for the row_id_map pre-fill. It just needs expert_id to be <= UINT_MAX, which I do not think we will reach anytime soon.


Thanks for the careful review. Acknowledging the capacity-drop regression concern and the unsigned-sort suggestion below — both make sense. Waiting on the te_ci result you triggered before I push any code change, so we have a concrete signal on what needs to move.


tdophung commented May 4, 2026

/te_ci pytorch


tdophung commented May 6, 2026

/te-ci pytorch


tdophung commented May 6, 2026

/te-ci pytorch L0

The MoE permute path was correct for the existing capacity-drop convention
(drops encoded as a large positive expert id, sorted to the tail by the
signed cub::DeviceRadixSort), but it broke for callers that mark dropped
(token, slot) pairs with -1 (expert-parallel rank masking, e.g. DeepEP).
With signed sort the -1 sentinels land at the HEAD of sorted_row_id, while
moe_permute_row_map's `idx >= num_out_tokens` branch assumes drops are at
the tail.

Reinterpret the keys as uint32_t inside nvte_device_radix_sort_pairs so
-1 (= UINT_MAX) sorts to the tail, unifying the EP-mask case with the
existing capacity-drop convention. The kernel and caller sides are
unchanged - this is a one-place fix that makes both drop conventions
land in the existing drop branch.

Also widen the loop-carried indices in moe_unpermute_kernel and
moe_permute_kernel to int64_t (`source_token`, `source_row`, `dest_row`)
to keep `row * num_cols` in 64 bits. We hit this on DeepSeek-V3 long-
context training (hidden = 7168, topK = 8): once `num_out_tokens *
num_cols` reaches 2**31 the int product wraps and the kernel either
silently corrupts rows or raises CUDA `illegal memory access`.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
@jing-4369

@tdophung Pushed the unsigned-sort approach in 4f46dc2. Net diff is now +13/-6 in permutation.cu only; permutation.cpp is unchanged from upstream, and the earlier comments around NVTE_CHECK / launcher block-count no longer apply.

Could you re-trigger /te-ci pytorch when convenient?


Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.


Development

Successfully merging this pull request may close these issues.

[Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling

3 participants