diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index fbba27941c..aa7cb50e8b 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -19,7 +19,7 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id const int tid = threadIdx.x; const int idx = bid * blockDim.x + tid; - if (idx >= num_rows * topK) return; + if (idx >= static_cast(num_rows) * topK) return; int source_row = sorted_row_id[idx]; int source_token_id = source_row / topK; @@ -27,10 +27,10 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id if (idx >= num_out_tokens) { // Set the indices of dropped tokens to -1 - row_id_map[source_topK_id * num_rows + source_token_id] = -1; + row_id_map[static_cast(source_topK_id) * num_rows + source_token_id] = -1; } else { // Create a row id map for subsequent unpermute operation - row_id_map[source_topK_id * num_rows + source_token_id] = idx; + row_id_map[static_cast(source_topK_id) * num_rows + source_token_id] = idx; } } @@ -42,7 +42,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const TCompute *s_prob = reinterpret_cast(s_mem); // Each block corresponds to one dest token - const int source_token = blockIdx.x; + const int64_t source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { @@ -65,7 +65,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const TCompute frag_elem[kElementsPerAccess]; TCompute frag_sum[kElementsPerAccess]; - int source_row = row_id_map[source_token]; + int64_t source_row = row_id_map[source_token]; // source_row == -1 represents a dropped token if (source_row != -1) { @@ -134,7 +134,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac TCompute *s_prob = reinterpret_cast(s_mem); // Each block corresponds to one source token - const int source_token = blockIdx.x; + const int64_t source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { @@ -172,7 +172,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac for (int k = 0; k < topKTile; k++) { if (k == topK) break; - int dest_row = row_id_map[index]; + int64_t dest_row = row_id_map[index]; index += num_rows; if (dest_row != -1) { @@ -239,7 +239,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, // moe_permute_fwd int threads = 64; - int blocks = (num_rows * topK + threads - 1) / threads; + int blocks = (static_cast(num_rows) * topK + threads - 1) / threads; moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, num_out_tokens); @@ -371,6 +371,13 @@ void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes int *keys_out, int *values_in, int *values_out, size_t num_items) { NVTE_API_CALL(nvte_device_radix_sort_pairs); - cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, keys_in, keys_out, values_in, - values_out, num_items); + // Sort keys as uint32_t so any negative-int sentinel (e.g. `-1` placed by an + // expert-parallel rank mask) becomes a large unsigned value and lands at the + // tail of the sorted output, matching the existing capacity-drop convention + // (drops encoded as a large positive expert id) and the + // `idx >= num_out_tokens` drop branch in moe_permute_row_map. + auto *u_keys_in = reinterpret_cast(keys_in); + auto *u_keys_out = reinterpret_cast(keys_out); + cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, u_keys_in, u_keys_out, + values_in, values_out, num_items); }