From 98d9771b5e73e28ccc137100c914ef044757d6ac Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 24 Dec 2025 13:51:54 -0800 Subject: [PATCH 1/4] reorder runtime functions --- runtime/block_quantization_kernels.cu | 114 +++++++++++++------------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index c950eeb31cb..ce67175f5b0 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -75,6 +75,63 @@ __device__ __inline__ void convertToFloatAndComputeLocalMax( } } +// Fast reciprocal of 2^biased_exp using bit manipulation +// Returns 1.0 for biased_exp==0, otherwise returns 2^(-biased_exp) +constexpr uint32_t FP32_MANTISSA_BITS = 23; +__device__ __forceinline__ float exp2f_rcp(uint8_t biased_exp) { + return (biased_exp == 0) + ? 1 + : __int_as_float( + (254 - biased_exp) + << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) +} + +template < + int ITEMS_PER_THREAD, + typename T, + int ALIGNMENT_1, + int ALIGNMENT_2, + int BLOCK_SCALE_DIM, + int BLOCK_SCALE_ALLOC> +__device__ void block_quantize_to_mxfp8( + const Array& input, + Array<__e4m3, ITEMS_PER_THREAD, ALIGNMENT_2>& output, + Tensor<__e8m0, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, + nvfuser_index_t logical_index) { + // Number of threads involved in computing one block scaling factor + constexpr int THREADS_PER_SCALING_FACTOR = 32 / ITEMS_PER_THREAD; + + Array vec_in; + float local_max; + convertToFloatAndComputeLocalMax( + input, vec_in, local_max); + + // Compute the max accross 32/ITEMS_PER_THREAD threads + // This assumes each thread has already computed is local max of 2, 4 (fp32) + // or 2,4, 8 (bf16/fp16) elements. + reduceAcrossThreads(local_max); + float block_max = local_max; + + static constexpr float max_norm_rcp = 1.0f / 448; + __e8m0 exponent = __float2e8m0(block_max * max_norm_rcp); + + // Write out the block scaling factor to global memory. + // This assumes block_size (32) elements in the input were contiguous. + // Only one block scaling factor is written out per 32(assumed block size) + // elements. + int offset = logical_index / 32; + if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { + block_scales[offset] = exponent; + } + + const float block_scale_inverse = exp2f_rcp(exponent.raw()); + +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + output[i] = __float2e4m3(vec_in[i] * block_scale_inverse); + } +} + // A runtime function to compute quantized nvfp4 output (output) and fp8 block // scaling (block_scales) factors from fp32, fp16, bf16 inputs (input). // The function is templatized over input type T (float, __half, __bfloat). @@ -191,62 +248,5 @@ __device__ void block_quantize_to_nvfp4( } } -// Fast reciprocal of 2^biased_exp using bit manipulation -// Returns 1.0 for biased_exp==0, otherwise returns 2^(-biased_exp) -constexpr uint32_t FP32_MANTISSA_BITS = 23; -__device__ __forceinline__ float exp2f_rcp(uint8_t biased_exp) { - return (biased_exp == 0) - ? 
1 - : __int_as_float( - (254 - biased_exp) - << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) -} - -template < - int ITEMS_PER_THREAD, - typename T, - int ALIGNMENT_1, - int ALIGNMENT_2, - int BLOCK_SCALE_DIM, - int BLOCK_SCALE_ALLOC> -__device__ void block_quantize_to_mxfp8( - const Array& input, - Array<__e4m3, ITEMS_PER_THREAD, ALIGNMENT_2>& output, - Tensor<__e8m0, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, - nvfuser_index_t logical_index) { - // Number of threads involved in computing one block scaling factor - constexpr int THREADS_PER_SCALING_FACTOR = 32 / ITEMS_PER_THREAD; - - Array vec_in; - float local_max; - convertToFloatAndComputeLocalMax( - input, vec_in, local_max); - - // Compute the max accross 32/ITEMS_PER_THREAD threads - // This assumes each thread has already computed is local max of 2, 4 (fp32) - // or 2,4, 8 (bf16/fp16) elements. - reduceAcrossThreads(local_max); - float block_max = local_max; - - static constexpr float max_norm_rcp = 1.0f / 448; - __e8m0 exponent = __float2e8m0(block_max * max_norm_rcp); - - // Write out the block scaling factor to global memory. - // This assumes block_size (32) elements in the input were contiguous. - // Only one block scaling factor is written out per 32(assumed block size) - // elements. - int offset = logical_index / 32; - if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { - block_scales[offset] = exponent; - } - - const float block_scale_inverse = exp2f_rcp(exponent.raw()); - -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; ++i) { - output[i] = __float2e4m3(vec_in[i] * block_scale_inverse); - } -} - } // namespace bq } // namespace nvf From cba3d60b1438ae7fb620243dc3707ff18fab6354 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 24 Dec 2025 15:14:21 -0800 Subject: [PATCH 2/4] quick refactor to move files around --- csrc/codegen.cpp | 2 +- csrc/runtime/compiled_kernel.cpp | 4 +- runtime/block_quantization_kernels.cu | 204 +++++++++++++++++++++++--- 3 files changed, 183 insertions(+), 27 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 810837c39e8..e12ed646540 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -4677,7 +4677,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(layout_op->g())); indent() << genCall( - "block_layout::preprocessGroupedMatmulInputSf", + "bq::preprocessGroupedMatmulInputSf", template_args, func_args) << ";\n"; diff --git a/csrc/runtime/compiled_kernel.cpp b/csrc/runtime/compiled_kernel.cpp index bc9eefebcc1..e9ef09974ad 100644 --- a/csrc/runtime/compiled_kernel.cpp +++ b/csrc/runtime/compiled_kernel.cpp @@ -1087,10 +1087,10 @@ std::string _getStructuredCode( code += nvfuser_resources::topk_cu; } if (has_block_layout) { - code += nvfuser_resources::block_layout_cu; + // code += nvfuser_resources::block_layout_cu; } - if (has_block_quantize_op) { + if (has_block_layout || has_block_quantize_op) { code += nvfuser_resources::block_quantization_kernels_cu; } diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index ce67175f5b0..e1d18eff40f 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -9,6 +9,55 @@ namespace nvf { namespace bq { +namespace { + +// TODO: support vectorized store +template +__device__ nvfuser_index_t outputOffsetAfterSwizzlePadding( + const nvfuser_index_t row_idx, + const nvfuser_index_t col_idx, + const nvfuser_index_t padded_col_size) { + constexpr nvfuser_index_t BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER; + + 
/* logical dimension of matrix [ row_size, col_size] + * + * while logical domain after padding can be viewed as + * [ (row_tile*BLOCK_ROW_INNER*BLOCK_ROW_OUTER), (col_tile*BLOCK_COL) ] + * where + * row_tile = ceilDiv(row_size / BLOCK_ROW_OUTER * BLOCK_ROW_INNER) + * col_tile = ceilDiv(col_size / BLOCK_COL) + */ + + // we first convert `row_idx` and `col_idx` to the logical index on the 5d + // tensor. + nvfuser_index_t row_tile_idx = row_idx / BLOCK_ROW_SIZE; + nvfuser_index_t row_block_idx = row_idx % BLOCK_ROW_SIZE; + nvfuser_index_t row_block_inner_idx = row_block_idx / BLOCK_ROW_OUTER; + nvfuser_index_t row_block_outer_idx = row_block_idx % BLOCK_ROW_OUTER; + nvfuser_index_t col_tile_idx = col_idx / BLOCK_COL; + nvfuser_index_t col_block_idx = col_idx % BLOCK_COL; + + /* layout for matrix [ row_size, col_size] + * it is viewed + * [row_tile, BLOCK_ROW_INNER, BLOCK_ROW_OUTER, col_tile, BLOCK_COL] + * then transposed with axis (1, 3) + * [row_tile, col_tile, BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL] + * and then made contiguous + * So we can compute the corresponding stride for each dimension + */ + constexpr nvfuser_index_t COL_TILE_STRIDE = BLOCK_ROW_SIZE * BLOCK_COL; + constexpr nvfuser_index_t BLOCK_ROW_OUTER_STRIDE = + BLOCK_ROW_INNER * BLOCK_COL; + constexpr nvfuser_index_t BLOCK_ROW_INNER_STRIDE = BLOCK_COL; + + return row_tile_idx * padded_col_size * BLOCK_ROW_SIZE + + col_tile_idx * COL_TILE_STRIDE + + row_block_outer_idx * BLOCK_ROW_OUTER_STRIDE + + row_block_inner_idx * BLOCK_ROW_INNER_STRIDE + col_block_idx; +} + +} // namespace + // This helper function finds the max of NUM_ELEMENTS (2, 4, or 8) values // using the same number of threads. template @@ -146,18 +195,12 @@ template < int ALIGNMENT_2, int BLOCK_SCALE_DIM, int BLOCK_SCALE_ALLOC> -__device__ void block_quantize_to_nvfp4( +__device__ void block_quantize_to_nvfp4_util( const Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, - nvfuser_index_t logical_index, - Tensor global_scale, - int64_t fp8_scaling_factors_inner_dim = -1, - int64_t alloc_dim0 = -1, - int64_t alloc_dim1 = -1, - int64_t alloc_dim2 = -1, - int64_t alloc_dim3 = -1, - int64_t alloc_dim4 = -1) { + Tensor& global_scale, + int64_t offset) { // Number of threads involved in computing one block scaling factor constexpr int THREADS_PER_SCALING_FACTOR = 16 / ITEMS_PER_THREAD; @@ -192,6 +235,49 @@ __device__ void block_quantize_to_nvfp4( scaled_max = fminf(1.0f / scaled_max, float_max); } + if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { + block_scales[offset] = clamped_max_fp8; + } + + Array scaled_vals; +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + scaled_vals[i] = vec_in[i] * scaled_max; + } + + Array<__e2m1, ITEMS_PER_THREAD, 1> fp4_vals; + *reinterpret_cast*>( + &fp4_vals[0]) = + __float2e2m1( + *reinterpret_cast*>( + &scaled_vals[0])); + +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + output[i] = fp4_vals[i]; + } +} + +template < + bool USE_GLOBAL_SCALE, + int ITEMS_PER_THREAD, + typename T, + int ALIGNMENT_1, + int ALIGNMENT_2, + int BLOCK_SCALE_DIM, + int BLOCK_SCALE_ALLOC> +__device__ void block_quantize_to_nvfp4( + const Array& input, + Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, + Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, + nvfuser_index_t logical_index, + Tensor global_scale, + int64_t fp8_scaling_factors_inner_dim = -1, + int64_t alloc_dim0 = -1, + int64_t alloc_dim1 = -1, + int64_t 
alloc_dim2 = -1, + int64_t alloc_dim3 = -1, + int64_t alloc_dim4 = -1) { // Write out the block scaling factor to global memory. // This assumes 16 elements in the input were contiguous. // Only one block scaling factor is written out per 16(assumed block size) @@ -224,28 +310,98 @@ __device__ void block_quantize_to_nvfp4( offset = pos_4 * stride_4 + pos_3 * stride_3 + pos_2 * stride_2 + pos_1 * stride_1 + pos_0 * stride_0; } + block_quantize_to_nvfp4_util(input, output, block_scales, global_scale, offset); +} - if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { - block_scales[offset] = clamped_max_fp8; +template < + typename T, + typename Index_T, + int BLOCK_ROW_OUTER, + int BLOCK_ROW_INNER, + int BLOCK_COL, + int UNROLL_FACTOR> +__device__ void preprocessGroupedMatmulInputSf( + T* output, + const T* input, + const nvfuser_index_t row_idx, + const nvfuser_index_t col_idx, + const Index_T* input_offsets, + const Index_T* output_offsets, + const nvfuser_index_t col_size, + const nvfuser_index_t group_size) { + // find corresponding expert_id + int expert_id = group_size - 1; + for (int i = 1; i < group_size; ++i) { + if (row_idx < input_offsets[i]) { + expert_id = i - 1; + break; + } } - Array scaled_vals; -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; ++i) { - scaled_vals[i] = vec_in[i] * scaled_max; + // row idx for current group + nvfuser_index_t c_row_idx = row_idx - input_offsets[expert_id]; + // compute output group offset for current group + nvfuser_index_t padded_col_size = + (col_size + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL; + T* out_group_offset = output + output_offsets[expert_id] * padded_col_size; + + // TODO: vectorized load/store instead of for loop + for (int i = 0; i < UNROLL_FACTOR && col_idx + i < col_size; ++i) { + nvfuser_index_t index = outputOffsetAfterSwizzlePadding< + BLOCK_ROW_OUTER, + BLOCK_ROW_INNER, + BLOCK_COL>(c_row_idx, col_idx + i, padded_col_size); + out_group_offset[index] = input[i]; } +} - Array<__e2m1, ITEMS_PER_THREAD, 1> fp4_vals; - *reinterpret_cast*>( - &fp4_vals[0]) = - __float2e2m1( - *reinterpret_cast*>( - &scaled_vals[0])); +template < + bool USE_GLOBAL_SCALE, + int ITEMS_PER_THREAD, + typename T, + int ALIGNMENT_1, + int ALIGNMENT_2, + int BLOCK_SCALE_DIM, + int BLOCK_SCALE_ALLOC> +__device__ void grouped_block_quantize_to_nvfp4( + const Array& input, + Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, + Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, + nvfuser_index_t logical_index, + Tensor global_scale) { + // Write out the block scaling factor to global memory. + // This assumes 16 elements in the input were contiguous. + // Only one block scaling factor is written out per 16(assumed block size) + // elements. 
+ int offset = logical_index / 16; -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; ++i) { - output[i] = fp4_vals[i]; + if (fp8_scaling_factors_inner_dim > 0) { + auto stride_4 = 1; + auto stride_3 = stride_4 * alloc_dim4; + auto stride_2 = stride_3 * alloc_dim3; + auto stride_1 = stride_2 * alloc_dim2; + auto stride_0 = stride_1 * alloc_dim1; + + auto logical_inner = offset % fp8_scaling_factors_inner_dim; + auto logical_outer = offset / fp8_scaling_factors_inner_dim; + + // The allocation domain swizzle logic is: + // m, k -> m, k/4, 4 + // m, k/4, 4 -> m/128, 128, k/4, 4 -> + // m/128, 4(m), 32, k/4, 4(k) -> + // m/128, k/4, 32, 4(m), 4(k) + + auto pos_4 = logical_inner % 4; + auto pos_1 = logical_inner / 4; + auto pos_t = logical_outer % 128; + auto pos_0 = logical_outer / 128; + auto pos_3 = pos_t / 32; + auto pos_2 = pos_t % 32; + + offset = pos_4 * stride_4 + pos_3 * stride_3 + pos_2 * stride_2 + + pos_1 * stride_1 + pos_0 * stride_0; } + block_quantize_to_nvfp4_util(input, output, block_scales, global_scale, offset); } } // namespace bq From f790909ef48f3efab0773468264ca3e9d29f8f41 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 Dec 2025 12:30:01 -0800 Subject: [PATCH 3/4] removing unused runtime header --- CMakeLists.txt | 1 - csrc/runtime/compiled_kernel.cpp | 3 - runtime/block_layout.cu | 102 ------------------------------- 3 files changed, 106 deletions(-) delete mode 100644 runtime/block_layout.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index ccda3d89fb5..0840779f26a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1435,7 +1435,6 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/block_sync_atomic.cu ${NVFUSER_ROOT}/runtime/block_sync_default.cu ${NVFUSER_ROOT}/runtime/block_welford_outer.cu - ${NVFUSER_ROOT}/runtime/block_layout.cu ${NVFUSER_ROOT}/runtime/block_quantization_kernels.cu ${NVFUSER_ROOT}/runtime/broadcast.cu ${NVFUSER_ROOT}/runtime/casts.cu diff --git a/csrc/runtime/compiled_kernel.cpp b/csrc/runtime/compiled_kernel.cpp index e9ef09974ad..4ce528a510e 100644 --- a/csrc/runtime/compiled_kernel.cpp +++ b/csrc/runtime/compiled_kernel.cpp @@ -1086,9 +1086,6 @@ std::string _getStructuredCode( if (has_topk) { code += nvfuser_resources::topk_cu; } - if (has_block_layout) { - // code += nvfuser_resources::block_layout_cu; - } if (has_block_layout || has_block_quantize_op) { code += nvfuser_resources::block_quantization_kernels_cu; diff --git a/runtime/block_layout.cu b/runtime/block_layout.cu deleted file mode 100644 index 22935dcb57a..00000000000 --- a/runtime/block_layout.cu +++ /dev/null @@ -1,102 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on - -namespace nvf::block_layout { - -namespace { - -// TODO: support vectorized store -template -__device__ nvfuser_index_t outputOffsetAfterSwizzlePadding( - const nvfuser_index_t row_idx, - const nvfuser_index_t col_idx, - const nvfuser_index_t padded_col_size) { - constexpr nvfuser_index_t BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER; - - /* logical dimension of matrix [ row_size, col_size] - * - * while logical domain after padding can be viewed as - * [ (row_tile*BLOCK_ROW_INNER*BLOCK_ROW_OUTER), (col_tile*BLOCK_COL) ] - * where - * row_tile = ceilDiv(row_size / BLOCK_ROW_OUTER * BLOCK_ROW_INNER) - * col_tile = ceilDiv(col_size / BLOCK_COL) - */ - - // we first convert `row_idx` and `col_idx` to the logical index on the 5d - // tensor. - nvfuser_index_t row_tile_idx = row_idx / BLOCK_ROW_SIZE; - nvfuser_index_t row_block_idx = row_idx % BLOCK_ROW_SIZE; - nvfuser_index_t row_block_inner_idx = row_block_idx / BLOCK_ROW_OUTER; - nvfuser_index_t row_block_outer_idx = row_block_idx % BLOCK_ROW_OUTER; - nvfuser_index_t col_tile_idx = col_idx / BLOCK_COL; - nvfuser_index_t col_block_idx = col_idx % BLOCK_COL; - - /* layout for matrix [ row_size, col_size] - * it is viewed - * [row_tile, BLOCK_ROW_INNER, BLOCK_ROW_OUTER, col_tile, BLOCK_COL] - * then transposed with axis (1, 3) - * [row_tile, col_tile, BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL] - * and then made contiguous - * So we can compute the corresponding stride for each dimension - */ - constexpr nvfuser_index_t COL_TILE_STRIDE = BLOCK_ROW_SIZE * BLOCK_COL; - constexpr nvfuser_index_t BLOCK_ROW_OUTER_STRIDE = - BLOCK_ROW_INNER * BLOCK_COL; - constexpr nvfuser_index_t BLOCK_ROW_INNER_STRIDE = BLOCK_COL; - - return row_tile_idx * padded_col_size * BLOCK_ROW_SIZE + - col_tile_idx * COL_TILE_STRIDE + - row_block_outer_idx * BLOCK_ROW_OUTER_STRIDE + - row_block_inner_idx * BLOCK_ROW_INNER_STRIDE + col_block_idx; -} - -} // namespace - -template < - typename T, - typename Index_T, - int BLOCK_ROW_OUTER, - int BLOCK_ROW_INNER, - int BLOCK_COL, - int UNROLL_FACTOR> -__device__ void preprocessGroupedMatmulInputSf( - T* output, - const T* input, - const nvfuser_index_t row_idx, - const nvfuser_index_t col_idx, - const Index_T* input_offsets, - const Index_T* output_offsets, - const nvfuser_index_t col_size, - const nvfuser_index_t group_size) { - // find corresponding expert_id - int expert_id = group_size - 1; - for (int i = 1; i < group_size; ++i) { - if (row_idx < input_offsets[i]) { - expert_id = i - 1; - break; - } - } - - // row idx for current group - nvfuser_index_t c_row_idx = row_idx - input_offsets[expert_id]; - // compute output group offset for current group - nvfuser_index_t padded_col_size = - (col_size + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL; - T* out_group_offset = output + output_offsets[expert_id] * padded_col_size; - - // TODO: vectorized load/store instead of for loop - for (int i = 0; i < UNROLL_FACTOR && col_idx + i < col_size; ++i) { - nvfuser_index_t index = outputOffsetAfterSwizzlePadding< - BLOCK_ROW_OUTER, - BLOCK_ROW_INNER, - BLOCK_COL>(c_row_idx, col_idx + i, padded_col_size); - out_group_offset[index] = input[i]; - } -} - -} // namespace nvf::block_layout From d51d8c52f1a41088ce6fa54f598462ce261eb62b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 Dec 2025 16:23:17 -0800 Subject: [PATCH 4/4] quick draft on the runtime function --- runtime/block_quantization_kernels.cu | 63 +++++++++++++-------------- 1 file changed, 31 
insertions(+), 32 deletions(-) diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index e1d18eff40f..995d1531d38 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -357,8 +357,13 @@ __device__ void preprocessGroupedMatmulInputSf( template < bool USE_GLOBAL_SCALE, + int BLOCK_ROW_OUTER, + int BLOCK_ROW_INNER, + int BLOCK_COL, + int UNROLL_FACTOR, int ITEMS_PER_THREAD, typename T, + typename Index_T, int ALIGNMENT_1, int ALIGNMENT_2, int BLOCK_SCALE_DIM, @@ -367,40 +372,34 @@ __device__ void grouped_block_quantize_to_nvfp4( const Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, - nvfuser_index_t logical_index, + const nvfuser_index_t row_idx, + const nvfuser_index_t col_idx, + const Index_T* input_offsets, + const Index_T* output_offsets, + const nvfuser_index_t col_size, + const nvfuser_index_t group_size, Tensor global_scale) { - // Write out the block scaling factor to global memory. - // This assumes 16 elements in the input were contiguous. - // Only one block scaling factor is written out per 16(assumed block size) - // elements. - int offset = logical_index / 16; - - if (fp8_scaling_factors_inner_dim > 0) { - auto stride_4 = 1; - auto stride_3 = stride_4 * alloc_dim4; - auto stride_2 = stride_3 * alloc_dim3; - auto stride_1 = stride_2 * alloc_dim2; - auto stride_0 = stride_1 * alloc_dim1; - - auto logical_inner = offset % fp8_scaling_factors_inner_dim; - auto logical_outer = offset / fp8_scaling_factors_inner_dim; - - // The allocation domain swizzle logic is: - // m, k -> m, k/4, 4 - // m, k/4, 4 -> m/128, 128, k/4, 4 -> - // m/128, 4(m), 32, k/4, 4(k) -> - // m/128, k/4, 32, 4(m), 4(k) - - auto pos_4 = logical_inner % 4; - auto pos_1 = logical_inner / 4; - auto pos_t = logical_outer % 128; - auto pos_0 = logical_outer / 128; - auto pos_3 = pos_t / 32; - auto pos_2 = pos_t % 32; - - offset = pos_4 * stride_4 + pos_3 * stride_3 + pos_2 * stride_2 + - pos_1 * stride_1 + pos_0 * stride_0; + // find corresponding expert_id + int expert_id = group_size - 1; + for (int i = 1; i < group_size; ++i) { + if (row_idx < input_offsets[i]) { + expert_id = i - 1; + break; + } } + // row idx for current group + nvfuser_index_t c_row_idx = row_idx - input_offsets[expert_id]; + // compute output group offset for current group + nvfuser_index_t padded_col_size = + (col_size + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL; + nvfuser_index_t out_group_offset = output_offsets[expert_id] * padded_col_size; + // compute the offset + nvfuser_index_t index = outputOffsetAfterSwizzlePadding< + BLOCK_ROW_OUTER, + BLOCK_ROW_INNER, + BLOCK_COL>(c_row_idx, col_idx, padded_col_size); + nvfuser_index_t offset = out_group_offset + index; + block_quantize_to_nvfp4_util(input, output, block_scales, global_scale, offset); }
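
After PATCH 3/4 deletes runtime/block_layout.cu and PATCH 4/4 reworks grouped_block_quantize_to_nvfp4, the grouped nvfp4 path computes its scale-factor offset in three steps: a linear scan over input_offsets to find the owning group, a group base of output_offsets[expert_id] * padded_col_size, and the swizzled per-element offset from outputOffsetAfterSwizzlePadding. The host-side C++ sketch below mirrors that arithmetic for illustration only; the concrete template values 32/4/4 (the 128x4 tiling described in the allocation-domain comment), the toy offset tables, and the helper name swizzledOffset are assumptions for this example, not values taken from the patches.

    // Host-side sketch of the swizzled scale-factor offset used by
    // preprocessGroupedMatmulInputSf / grouped_block_quantize_to_nvfp4.
    // Assumed, not from the patches: BLOCK_ROW_OUTER=32, BLOCK_ROW_INNER=4,
    // BLOCK_COL=4, int64_t as the index type, and the toy group offsets below.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using index_t = int64_t;

    template <int BLOCK_ROW_OUTER, int BLOCK_ROW_INNER, int BLOCK_COL>
    index_t swizzledOffset(index_t row_idx, index_t col_idx, index_t padded_col_size) {
      constexpr index_t BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER; // 128
      // Split the row/col indices into tile and in-tile coordinates.
      const index_t row_tile_idx = row_idx / BLOCK_ROW_SIZE;
      const index_t row_block_idx = row_idx % BLOCK_ROW_SIZE;
      const index_t row_block_inner_idx = row_block_idx / BLOCK_ROW_OUTER;
      const index_t row_block_outer_idx = row_block_idx % BLOCK_ROW_OUTER;
      const index_t col_tile_idx = col_idx / BLOCK_COL;
      const index_t col_block_idx = col_idx % BLOCK_COL;
      // Contiguous layout of
      // [row_tile, col_tile, BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL].
      return row_tile_idx * padded_col_size * BLOCK_ROW_SIZE +
          col_tile_idx * BLOCK_ROW_SIZE * BLOCK_COL +
          row_block_outer_idx * BLOCK_ROW_INNER * BLOCK_COL +
          row_block_inner_idx * BLOCK_COL + col_block_idx;
    }

    int main() {
      // Toy grouped case: two groups of rows; the scale-factor matrix has 6
      // columns, so padded_col_size rounds up to 8 with BLOCK_COL = 4.
      const std::vector<index_t> input_offsets = {0, 100};  // row start of each group
      const std::vector<index_t> output_offsets = {0, 128}; // padded row start of each group
      const index_t col_size = 6;
      const index_t padded_col_size = (col_size + 3) / 4 * 4; // 8

      const index_t row_idx = 130, col_idx = 5;
      // Same linear-scan group lookup as in grouped_block_quantize_to_nvfp4.
      int expert_id = static_cast<int>(input_offsets.size()) - 1;
      for (size_t i = 1; i < input_offsets.size(); ++i) {
        if (row_idx < input_offsets[i]) {
          expert_id = static_cast<int>(i) - 1;
          break;
        }
      }
      const index_t c_row_idx = row_idx - input_offsets[expert_id];
      const index_t offset = output_offsets[expert_id] * padded_col_size +
          swizzledOffset<32, 4, 4>(c_row_idx, col_idx, padded_col_size);
      std::printf("group=%d local_row=%lld offset=%lld\n",
                  expert_id, (long long)c_row_idx, (long long)offset);
      return 0;
    }

With these toy values the program prints group=1 local_row=30 offset=2017: row 130 falls in the second group, 30 rows past its input offset, and the swizzle places column 5 of that row 993 elements into the group's padded output region.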