[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute#2907
jing-4369 wants to merge 4 commits into `NVIDIA:main`
Conversation
Greptile Summary: This PR fixes two correctness bugs in MoE token permutation: int32 overflow in pointer-offset arithmetic for large activation tensors, and incorrect handling of -1 sentinels in the routing indices.
Confidence Score: 3/5

The widening applied to `moe_unpermute_kernel` and `moe_permute_kernel` is correct, and the uint32 radix-sort sentinel fix is sound. However, the forward permute path still contains an unwidened int32 index computation: `moe_permute_row_map`, the kernel that builds the `row_id_map` consumed by every subsequent kernel, still computes `source_topK_id * num_rows` and `num_rows * topK` as int32. For any model where `num_rows * topK` exceeds 2^31 the map writes land at wrong addresses, producing silently wrong permutation output in the forward pass and undermining the very overflow fix this PR introduces.

Important Files Changed: `transformer_engine/common/permutation/permutation.cu`, specifically the `moe_permute_row_map` kernel (lines 13–35) and its launcher (line 242), where int32 index arithmetic was not widened alongside the rest of the file.
```cpp
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);
```
**Negative `num_minus_ones` becomes an enormous `size_t` offset**

`num_minus_ones` is computed as `int`. If a caller passes `num_out_tokens > num_tokens * topK` (which the function does not validate), `num_minus_ones` is negative. The pointer advance expression:

```cpp
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);
```

involves `int * size_t`, which promotes `num_minus_ones` to `size_t` (unsigned). A value like -4 becomes `SIZE_MAX - 3`, advancing the pointer far out of the allocation and causing a silent OOB read. A simple clamp or assert before this line would prevent this:

```cpp
TORCH_CHECK(num_out_tokens <= num_tokens * topK,
            "num_out_tokens (", num_out_tokens, ") cannot exceed num_tokens*topK (",
            num_tokens * topK, ")");
```

…permute

Two independent bugs in `transformer_engine/common/permutation/permutation.cu` and the PyTorch extension caller reproduce on main (264da2b) and v2.13:

1. int32 overflow in `moe_unpermute_kernel` and `moe_permute_kernel`. `source_token * num_cols` and `source_row * num_cols` are computed with `int`, so for long-sequence MoE workloads where `num_out_tokens * num_cols` reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset wraps and the kernel either reads garbage or raises `an illegal memory access was encountered`. Widening `source_token`, `source_row` and `dest_row` to `int64_t` inside the kernels keeps the index arithmetic in 64 bits without changing any public types.

2. Incorrect handling of -1 sentinels in the routing indices. Libraries such as DeepEP (and any expert-parallel mask that sets non-local (token, slot) pairs to -1) feed a `routing_map` that contains -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so those sentinels land at the HEAD of `sorted_row_id`, not the tail. `moe_permute_row_map` currently writes -1 only for `idx >= num_out_tokens` and reads the sentinel prefix as if it were a valid sorted id, producing bogus `row_id_map` writes (for instance `source_row / topK == 0, source_row % topK == -1`). The caller now advances `sorted_row_id_ptr` past the `num_minus_ones` prefix and pre-fills `row_id_map` with -1 via `torch::full`, so the kernel only processes the valid suffix and never dereferences a sentinel. The launcher's grid switches from `num_rows * topK` blocks to `num_out_tokens` blocks to match the new valid range.

No behaviour change on happy-path `routing_map` (no -1, no overflow).

Reproducers:
- 8-token, topK=2 `routing_map` with -1 masking: max |TE - ref| = 4.5e0 on bf16 with current main; 0.0 with this patch.
- num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises CUDA illegal memory access at permutation.cu:252; with this patch it succeeds.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
Force-pushed from a909254 to ce34774
for more information, see https://pre-commit.ci
```cpp
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
                    static_cast<size_t>(num_minus_ones) * sizeof(int);
```
**`num_tokens * topK` still computed as `int * int`**

`num_tokens` and `topK` are both `int`, so `num_tokens * topK` on line 61 is evaluated in 32-bit arithmetic before the result feeds the `int64_t` subtraction. The same expression appears twice in the `NVTE_CHECK` on lines 59–60. If `num_tokens * topK` wraps to a negative `int` (possible when, e.g., `num_tokens >= 2^31 / topK`), the `NVTE_CHECK` would either spuriously reject a valid `num_out_tokens`, or the error-message value would be wrong. Casting to `int64_t` before the multiplication closes this gap:
```diff
-  const int num_minus_ones = num_tokens * topK - num_out_tokens;
-  sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
-                      static_cast<size_t>(num_minus_ones) * sizeof(int);
+  const int64_t total_tokens = static_cast<int64_t>(num_tokens) * topK;
+  NVTE_CHECK(num_out_tokens <= total_tokens, "num_out_tokens (", num_out_tokens,
+             ") must not exceed num_tokens*topK (", total_tokens, ")");
+  const int num_minus_ones = static_cast<int>(total_tokens - num_out_tokens);
+  sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
+                      static_cast<size_t>(num_minus_ones) * sizeof(int);
```
This (the Greptile comment) looks correct. Can you please cast `num_tokens` to int64 before the multiplication and the `- num_out_tokens` subtraction?
This is not a huge deal because even with topK=128, you would need > 16M tokens per rank for the int product to overflow. But better to be consistent, and also, this casting of one value on the CPU side probably would not slow things down much.
Force-pushed from ba39fba to cd64aa5
Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast num_minus_ones to size_t before the pointer advance, so a negative num_minus_ones (from an invalid num_out_tokens) cannot silently wrap into a huge pointer offset. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
Force-pushed from cd64aa5 to b73a1f9
```diff
   int threads = 64;
-  int blocks = (num_rows * topK + threads - 1) / threads;
+  int blocks = (num_out_tokens + threads - 1) / threads;
```
This is correct here, but it has an implied prerequisite: the host prefills the buffer with -1 and shifts the pointer by `num_minus_ones` (what you did in the other file). Better to make it explicit with a comment so no regression happens when someone accidentally changes that behavior and messes up the number of blocks here. Something like:

`// row_id_map MUST be pre-initialized to -1; sorted_row_id MUST point past the sentinel prefix`
```cpp
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens,
           ") must not exceed num_tokens*topK (", num_tokens * topK, ")");
const int num_minus_ones = num_tokens * topK - num_out_tokens;
```
This is probably going to introduce a regression for the capacity-drop path. This shift assumes the dropped routes are -1 sentinels at the head of `sorted_row_id` (cub's signed radix sort), which is true for the EP-mask case this PR targets. But the pre-existing capacity-drop path encodes drops as a large positive expert id that sorts to the tail. For that case, the head holds valid low-expert-id rows, and shifting past them drops the wrong tokens. (Just FYI, the capacity-dropping case means no -1 in the indices, with `num_out_tokens < num_tokens * topK` because some expert exceeded capacity.)
See in this file tests/pytorch/test_permutation.py: in `pytorch_permute_index_map` we have `sorted_indices[:num_out_tokens]` (keeps the head), so I'd expect `test_permutation_index_map[..., num_out_tokens=2039, ...]` to fail. We can run the te_ci to confirm it.
I think another solution, without computing `num_tokens * topK - num_out_tokens` (or counting the number of -1 entries on the host side), is to sort the keys as `uint32_t` instead of `int32_t`. Then -1 becomes `UINT_MAX` and sorts to the tail, unifying both the capacity-dropping and dropless cases under the original `idx >= num_out_tokens` drop logic. That removes the need for the prefix shift you did and for the `row_id_map` pre-fill. This just needs `expert_id` to be <= `UINT_MAX`, which I do not think we will reach anytime soon.
Thanks for the careful review. Acknowledging the capacity-drop regression concern and the unsigned-sort suggestion above; both make sense. Waiting on the te_ci result you triggered before I push any code change, so we have a concrete signal on what needs to move.
Here is the CI pipeline: https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/50478896
It failed in the expected tests
/te_ci pytorch

/te-ci pytorch

/te-ci pytorch L0
The MoE permute path was correct for the existing capacity-drop convention (drops encoded as a large positive expert id, sorted to the tail by the signed `cub::DeviceRadixSort`), but it broke for callers that mark dropped (token, slot) pairs with -1 (expert-parallel rank masking, e.g. DeepEP). With signed sort the -1 sentinels land at the HEAD of `sorted_row_id`, while `moe_permute_row_map`'s `idx >= num_out_tokens` branch assumes drops are at the tail.

Reinterpret the keys as `uint32_t` inside `nvte_device_radix_sort_pairs` so -1 (= `UINT_MAX`) sorts to the tail, unifying the EP-mask case with the existing capacity-drop convention. The kernel and caller sides are unchanged; this is a one-place fix that makes both drop conventions land in the existing drop branch.

Also widen the loop-carried indices in `moe_unpermute_kernel` and `moe_permute_kernel` to `int64_t` (`source_token`, `source_row`, `dest_row`) to keep `row * num_cols` in 64 bits. We hit this on DeepSeek-V3 long-context training (hidden = 7168, topK = 8): once `num_out_tokens * num_cols` reaches 2**31 the int product wraps and the kernel either silently corrupts rows or raises CUDA `illegal memory access`.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
Fixes #2908 (full description, repros, and DeepSeek-V3 context there).

Changes

- `permutation.cu`: widen `source_token`, `source_row`, `dest_row` to `int64_t` inside `moe_unpermute_kernel` and `moe_permute_kernel` so `row * num_cols` stays 64-bit. Simplify `moe_permute_row_map` to only process the valid `[0, num_out_tokens)` range; the launcher grid becomes `num_out_tokens` blocks.
- `permutation.cpp`: advance `sorted_row_id_ptr` past the `num_minus_ones` sentinel prefix left by `cub::DeviceRadixSort` (signed ascending), and pre-fill `row_id_map` with -1 via `torch::full` so dropped slots are marked without the kernel ever dereferencing a sentinel.
- No public API / dtype changes.
- +17 / -18 lines across the two files.

Test plan

- Happy-path `routing_map` (no -1, offsets within int32): unchanged.
- -1-sentinel repro from [Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling #2908: max |TE - ref| = 0.0 on bf16 (was 4.56e0).
- int32-boundary repro from #2908: no longer raises `illegal memory access`; matches reference.
- `tests/pytorch/test_permutation.py` via CI.