
feat: implement moe fused topk layer. #983

Open
XuZhang99 wants to merge 1 commit into jd-opensource:main from XuZhang99:feature/topk

Conversation

@XuZhang99
Collaborator

No description provided.

@gemini-code-assist bot left a comment

Code Review

This pull request implements a fused top-k layer for Mixture of Experts, including CUDA kernels for both softmax and sigmoid scoring functions. The review identified a critical security issue: several CUDA kernels use 32-bit signed integers (int) for global memory offset calculations and for the token count. When processing very large token batches, these values can overflow, producing out-of-bounds memory accesses on the GPU and potentially crashes, information leakage, or memory corruption. The review also found issues that will prevent compilation or cause incorrect behavior, such as a wrong include path and a bug where a parameter is ignored in the softmax path. Further improvements are needed for const correctness, type safety, and coding style, such as missing newlines at the end of files.


// We finally start setting up the read pointers for each thread. First, each
// thread jumps to the start of the row it will read.
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;


security-high

Pointer arithmetic using a 32-bit signed integer for the offset (thread_row * ELTS_PER_ROW) can overflow, causing thread_row_ptr to point to an invalid memory location. This should be calculated using 64-bit integers.

  const T* thread_row_ptr = input + static_cast<int64_t>(thread_row) * ELTS_PER_ROW;


const int num_experts = static_cast<int>(gating_output.size(-1));
const int num_tokens = static_cast<int>(gating_output.size(0));
const int topk = static_cast<int>(topk_weights.size(-1));


security-high

The number of tokens is explicitly cast to a 32-bit signed integer. In large-scale inference scenarios or high-throughput batch processing, the number of tokens (gating_output.size(0)) can exceed INT_MAX (2,147,483,647). This will cause an integer overflow, leading to incorrect behavior and potential out-of-bounds access in the kernels that receive this value.

__shared__ float normalizing_factor;
__shared__ float float_max;

const int thread_row_offset = blockIdx.x * num_cols;


security-high

The calculation of thread_row_offset is vulnerable to integer overflow if the product of the token index and the number of columns exceeds INT_MAX. This will lead to out-of-bounds memory access during the softmax computation.

  const int64_t thread_row_offset = static_cast<int64_t>(blockIdx.x) * num_cols;

float* output,
const int num_cols,
const float* correction_bias) {
const int thread_row_offset = blockIdx.x * num_cols;


security-high

The calculation of thread_row_offset using 32-bit signed integers can overflow if blockIdx.x * num_cols exceeds 2,147,483,647. Since blockIdx.x represents the token index and num_cols is the number of experts, this overflow is plausible in high-throughput scenarios. An overflow here will result in out-of-bounds reads and writes on the GPU.

  const int64_t thread_row_offset = static_cast<int64_t>(blockIdx.x) * num_cols;

@XuZhang99 force-pushed the feature/topk branch 3 times, most recently from 880e028 to 48e0763 (March 4, 2026 05:43)