Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions transformer_engine/common/permutation/permutation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id
const int tid = threadIdx.x;
const int idx = bid * blockDim.x + tid;

if (idx >= num_rows * topK) return;
if (idx >= static_cast<int64_t>(num_rows) * topK) return;

int source_row = sorted_row_id[idx];
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;
Comment on lines 24 to 26
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 After the `uint32_t`-sort change, `-1` sentinel entries land at the tail of `sorted_row_id` (positions ≥ `num_out_tokens`). The drop branch then computes `source_token_id = (-1) / topK` and `source_topK_id = (-1) % topK`. For `topK > 1`, C++ integer division truncates toward zero, giving `source_token_id = 0` and `source_topK_id = -1`, so the write becomes `row_id_map[-1 * num_rows + 0]`, i.e. `num_rows` words before the start of the buffer. For `topK = 1` the result is `source_token_id = -1` and `source_topK_id = 0`, so the write lands at `row_id_map[-1]`. Both cases silently corrupt adjacent device memory. A simple early exit on `source_row < 0` closes this gap without touching the caller.

Suggested change
int source_row = sorted_row_id[idx];
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;
int source_row = sorted_row_id[idx];
if (source_row < 0) return; // skip -1 sentinel entries
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;


if (idx >= num_out_tokens) {
// Set the indices of dropped tokens to -1
row_id_map[source_topK_id * num_rows + source_token_id] = -1;
row_id_map[static_cast<int64_t>(source_topK_id) * num_rows + source_token_id] = -1;
} else {
// Create a row id map for subsequent unpermute operation
row_id_map[source_topK_id * num_rows + source_token_id] = idx;
row_id_map[static_cast<int64_t>(source_topK_id) * num_rows + source_token_id] = idx;
}
}

Expand All @@ -42,7 +42,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const
TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);

// Each block corresponds to one dest token
const int source_token = blockIdx.x;
const int64_t source_token = blockIdx.x;
const int tid = threadIdx.x;

if (hasProb) {
Expand All @@ -65,7 +65,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const
TCompute frag_elem[kElementsPerAccess];
TCompute frag_sum[kElementsPerAccess];

int source_row = row_id_map[source_token];
int64_t source_row = row_id_map[source_token];

// source_row == -1 represents a dropped token
if (source_row != -1) {
Expand Down Expand Up @@ -134,7 +134,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac
TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);

// Each block corresponds to one source token
const int source_token = blockIdx.x;
const int64_t source_token = blockIdx.x;
const int tid = threadIdx.x;

if (hasProb) {
Expand Down Expand Up @@ -172,7 +172,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac
for (int k = 0; k < topKTile; k++) {
if (k == topK) break;

int dest_row = row_id_map[index];
int64_t dest_row = row_id_map[index];
index += num_rows;

if (dest_row != -1) {
Expand Down Expand Up @@ -239,7 +239,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
// moe_permute_fwd

int threads = 64;
int blocks = (num_rows * topK + threads - 1) / threads;
int blocks = (static_cast<int64_t>(num_rows) * topK + threads - 1) / threads;

moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK,
num_out_tokens);
Expand Down Expand Up @@ -371,6 +371,13 @@ void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes
int *keys_out, int *values_in, int *values_out,
size_t num_items) {
NVTE_API_CALL(nvte_device_radix_sort_pairs);
cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, keys_in, keys_out, values_in,
values_out, num_items);
// Sort keys as uint32_t so any negative-int sentinel (e.g. `-1` placed by an
// expert-parallel rank mask) becomes a large unsigned value and lands at the
// tail of the sorted output, matching the existing capacity-drop convention
// (drops encoded as a large positive expert id) and the
// `idx >= num_out_tokens` drop branch in moe_permute_row_map.
auto *u_keys_in = reinterpret_cast<uint32_t *>(keys_in);
auto *u_keys_out = reinterpret_cast<uint32_t *>(keys_out);
cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, u_keys_in, u_keys_out,
values_in, values_out, num_items);
}