From c7c1a765f4bdb50f18c919da501bf925931468f0 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 11 Feb 2026 17:29:35 +0000 Subject: [PATCH 01/31] Implemented the kernel with split dbias Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 114 +++++++++--------- transformer_engine/common/activation/gelu.cu | 8 +- transformer_engine/common/activation/relu.cu | 8 +- .../common/activation/swiglu.cu | 4 +- transformer_engine/common/cast/cast.cu | 2 +- .../common/cast/core/common.cuh | 94 +++++++++++++++ .../common/cast/dispatch/quantize.cuh | 11 +- .../cast/mxfp8/group_quantize_mxfp8.cuh | 58 ++++----- .../common/include/transformer_engine/cast.h | 12 +- 9 files changed, 199 insertions(+), 112 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 6557c83773..29a02124af 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -58,8 +58,7 @@ void compute_ref(const ProcessingMethod processing_method, const size_t rows, const size_t cols, const size_t scales_stride_rowwise, - const size_t scales_stride_colwise, - const bool is_single_tensor) + const size_t scales_stride_colwise) { const size_t tile_size_Y = 32; const size_t tile_size_X = 32; @@ -169,10 +168,8 @@ void compute_ref(const ProcessingMethod processing_method, } } - if (is_single_tensor) { - for (size_t j = 0; j < cols; ++j) { - output_dbias[j] = static_cast(output_dbias_fp32[j]); - } + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); } } @@ -250,12 +247,16 @@ void performTest(const ProcessingMethod processing_method, DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; + const bool compute_dbias = (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT); + const size_t rows = logical_shape_vec[0]; const size_t cols = logical_shape_vec[1]; size_t elts_num = 0; size_t rowwise_sfs_num = 0; size_t colwise_sfs_num = 0; + size_t sum_of_last_dims = 0; std::vector rowwise_scales_first_dim(num_tensors, 0); std::vector rowwise_scales_last_dim(num_tensors, 0); @@ -263,6 +264,7 @@ void performTest(const ProcessingMethod processing_method, std::vector colwise_scales_first_dim(num_tensors, 0); std::vector colwise_scales_last_dim(num_tensors, 0); std::vector colwise_scales_offset(num_tensors + 1, 0); + std::vector dbias_offsets(num_tensors + 1, 0); for (size_t t = 0; t < num_tensors; ++t) { const size_t M = first_dims_h[t]; @@ -285,13 +287,13 @@ void performTest(const ProcessingMethod processing_method, rowwise_sfs_num += rowwise_sfs; colwise_sfs_num += colwise_sfs; - + sum_of_last_dims += K; + rowwise_scales_offset[t+1] = rowwise_sfs_num; colwise_scales_offset[t+1] = colwise_sfs_num; + dbias_offsets[t+1] = sum_of_last_dims; } - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); - std::vector scales_rowwise_shape = {rowwise_sfs_num}; std::vector scales_colwise_shape = {colwise_sfs_num}; @@ -311,7 +313,7 @@ void performTest(const ProcessingMethod processing_method, std::vector out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0); std::vector out_scales_colwise_ref(colwise ? colwise_sfs_num : 0); - std::vector ref_output_dbias(is_single_tensor ? 
cols : 0); + std::vector ref_output_dbias(sum_of_last_dims, static_cast(0.0f)); for (size_t i = 0; i < elts_num; ++i) { const float val = dis(gen); @@ -336,6 +338,7 @@ void performTest(const ProcessingMethod processing_method, const size_t in_data_size = elts_num * sizeof(InputType); const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t dbias_data_size = sum_of_last_dims * sizeof(InputType); const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0); const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0); @@ -345,6 +348,7 @@ void performTest(const ProcessingMethod processing_method, InputType* grad_data_d; InputType* in_data_d; + InputType* dbias_out_data_d; OutputType* out_data_rowwise_d; OutputType* out_data_colwise_d; fp8e8m0* out_scales_rowwise_d; @@ -366,6 +370,10 @@ void performTest(const ProcessingMethod processing_method, cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + + std::vector dbias_logical_shape_vec= {num_tensors, cols}; + NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), + dbias_logical_shape_vec.size()); NVTEShape first_dims_shape_; NVTEShape last_dims_shape_; @@ -382,6 +390,7 @@ void performTest(const ProcessingMethod processing_method, NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor output_dbias_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, dbias_logical_shape_); NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; @@ -453,52 +462,40 @@ void performTest(const ProcessingMethod processing_method, &out_scales_colwise_tensor, sizeof(out_scales_colwise_tensor)); } - Tensor output_dbias("output_dbias", std::vector{ cols }, itype); + if (compute_dbias) { + cudaMalloc((void**)&dbias_out_data_d, dbias_data_size); + cudaMemset(dbias_out_data_d, 0, dbias_data_size); + NVTEBasicTensor output_dbias_data_tensor = {dbias_out_data_d, static_cast(itype), dbias_logical_shape_}; + nvte_set_grouped_tensor_param(&output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &output_dbias_data_tensor); + } // Reference (CPU) - if (is_single_tensor) { - - const size_t unpadded_rowwise_blocks_X = divide_round_up(cols, 32); - const size_t unpadded_colwise_blocks_X = cols; + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; - const size_t scales_stride_rowwise = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4); - const size_t scales_stride_colwise = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128); + const size_t scales_stride_rowwise = rowwise_scales_last_dim[t]; + const size_t scales_stride_colwise = colwise_scales_last_dim[t]; + const size_t data_offset = offsets_h[t]; + const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; + const size_t colwise_sfs_offset = colwise_scales_offset[t]; + const size_t dbias_offset = dbias_offsets[t]; + + const InputType* const grad_ptr = grad_data.data() + data_offset; + const InputType* 
const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; + InputType* const ref_output_dbias_ptr = ref_output_dbias.data() + dbias_offset; compute_ref( - processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(), - out_data_rowwise_ref.data(), out_data_colwise_ref.data(), - out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(), - ref_output_dbias.data(), rows, cols, + processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + ref_output_dbias_ptr, M, K, scales_stride_rowwise, - scales_stride_colwise, - is_single_tensor); - } else { - for (size_t t = 0; t < num_tensors; ++t) { - const size_t M = first_dims_h[t]; - const size_t K = last_dims_h[t]; - - const size_t scales_stride_rowwise = rowwise_scales_last_dim[t]; - const size_t scales_stride_colwise = colwise_scales_last_dim[t]; - const size_t data_offset = offsets_h[t]; - const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; - const size_t colwise_sfs_offset = colwise_scales_offset[t]; - - const InputType* const grad_ptr = grad_data.data() + data_offset; - const InputType* const in_ptr = in_data.data() + data_offset; - OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; - OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; - fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; - fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; - - compute_ref( - processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, - out_data_rowwise_ptr, out_data_colwise_ptr, - out_scales_rowwise_ptr, out_scales_colwise_ptr, - ref_output_dbias.data(), M, K, - scales_stride_rowwise, - scales_stride_colwise, - is_single_tensor); - } + scales_stride_colwise); } // GPU @@ -509,9 +506,9 @@ void performTest(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS: { - nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); break; } case ProcessingMethod::CAST_DBIAS_DACT: { @@ -522,10 +519,10 @@ void performTest(const ProcessingMethod processing_method, else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, - output_dbias.data(), workspace.data(), 0); + output_dbias_tensor, workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, - output_dbias.data(), workspace.data(), 0); + output_dbias_tensor, workspace.data(), 0); break; } case 
ProcessingMethod::CAST_ACT: { @@ -586,9 +583,10 @@ void performTest(const ProcessingMethod processing_method, out_data_colwise_h.data(), rows, cols, false, mismatches_elts); } - if (processing_method == ProcessingMethod::CAST_DBIAS - || processing_method == ProcessingMethod::CAST_DBIAS_DACT) - { + if (compute_dbias) { + Tensor output_dbias("output_dbias", std::vector{ sum_of_last_dims }, itype); + cudaMemcpy(output_dbias.rowwise_dptr(), dbias_out_data_d, dbias_data_size, cudaMemcpyDeviceToDevice); + auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -601,6 +599,7 @@ void performTest(const ProcessingMethod processing_method, cudaFree(grad_data_d); cudaFree(in_data_d); + cudaFree(dbias_out_data_d); cudaFree(first_dims_d); cudaFree(last_dims_d); cudaFree(offsets_d); @@ -648,7 +647,6 @@ std::vector> input_config = { {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - {VARYING_FIRST_DIM, 2, 384,160, 128,256}, {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index d209ea8d47..ea864813bf 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -32,7 +32,7 @@ void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dgelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); using namespace transformer_engine; @@ -110,7 +110,7 @@ void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dqgelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index b6f758caf6..fc9122b7ec 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -32,7 +32,7 @@ void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_drelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor 
workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_drelu); using namespace transformer_engine; @@ -110,7 +110,7 @@ void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsrelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 77d5b6867f..12478af4cf 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -32,7 +32,7 @@ void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsilu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 57404ae8a5..4f9ddb4fc5 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -70,7 +70,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d } void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 0997b01f7e..24a7e7fa79 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -22,6 +22,14 @@ namespace transformer_engine { namespace dispatch { namespace common { + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { const size_t N = product(t->data.shape); const bool 
isFullTile = (N % elems_per_block == 0); @@ -78,6 +86,61 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } stg_vec.store_to(thread_out_base); } + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, + const size_t num_tensors, + const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t *const offsets_ptr, + const int64_t *const first_dims_ptr, + const int64_t *const last_dims_ptr, + OType *const dbias_output, + const float *dbias_partial, + const size_t chunk_dim_Y) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const size_t tensor_id = blockIdx.y; + const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (first_logical_dim / num_tensors) + : first_dims_ptr[tensor_id]; + + const size_t rows = tensor_rows / chunk_dim_Y; + const size_t cols = last_logical_dim; + + const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (tensor_id * (tensor_rows / chunk_dim_Y)) + : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); + + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec; + OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} } // namespace kernel template @@ -96,6 +159,37 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void grouped_reduce_dbias(const ShapeRepresentation shape_rep, + const size_t num_tensors, + const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t *const data_tensor_offsets_ptr, + const int64_t *const data_tensor_first_dims_ptr, + const int64_t *const data_tensor_last_dims_ptr, + GroupedTensor *dbias, + const float *workspace_ptr, + const size_t chunk_dim_Y, + cudaStream_t stream) { + using namespace kernel; + constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(last_logical_dim % reduce_dbias_nvec == 0, "Unsupported shape."); + + const size_t blocks_X = DIVUP(last_logical_dim, THREADS_PER_BLOCK * reduce_dbias_nvec); + const size_t blocks_Y = num_tensors; + const dim3 grid(blocks_X, blocks_Y); + + group_reduce_dbias_kernel + <<>>( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + data_tensor_offsets_ptr, data_tensor_first_dims_ptr, data_tensor_last_dims_ptr, + reinterpret_cast(dbias->data.dptr), workspace_ptr, chunk_dim_Y); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace common } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 98a3fb8cba..f7823b4c58 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -382,13 +382,13 @@ void group_quantize_fwd_helper(const 
NVTEGroupedTensor input, NVTEGroupedTensor NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); const NVTEGroupedTensor activation = nullptr; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); - Tensor *dbias_tensor = convertNVTETensor(dbias); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); Tensor *workspace_tensor = convertNVTETensor(workspace); // Quantization config @@ -419,8 +419,9 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor template void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { using namespace detail; NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); @@ -428,7 +429,7 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - Tensor *dbias_tensor = convertNVTETensor(dbias); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); Tensor *workspace_tensor = convertNVTETensor(workspace); // Quantization config diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 6447fc4542..eac8f37387 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -28,19 +28,14 @@ namespace dispatch { namespace mxfp8 { namespace group_quantize_kernel { +using namespace dispatch::common; + constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; __device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; -enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, - VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, - VARYING_BOTH_DIMS = 3 -}; - constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; @@ -144,11 +139,14 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te if constexpr (is_blackwell) { const size_t global_stride_bytes = global_dim_X * data_type_size_bytes; if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) { - NVTE_DEVICE_ERROR("Shape not supported, as data stride must be 16B aligned."); + NVTE_DEVICE_ERROR("Shape not supported. 
Data stride must be 16B aligned."); } if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) { NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned"); } + if (global_dim_X % CHUNK_DIM_X != 0) { + NVTE_DEVICE_ERROR("The grouped tensor must be divisible by 128x128 tiles without a tail tile."); + } asm volatile( "{\n\t" @@ -782,8 +780,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel template void group_quantize(const GroupedTensor *input, const GroupedTensor *activations, - const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { + const Tensor *noop, GroupedTensor *output, GroupedTensor *dbias, + Tensor *workspace, cudaStream_t stream) { using namespace group_quantize_kernel; checkCuDriverContext(stream); @@ -834,23 +832,14 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - size_t blocks_X = 0; - size_t blocks_Y = 0; - - if (is_single_tensor) { - blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); - blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); - } else { - NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, + if (!is_single_tensor) { + NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); - // Only full tiles supported - NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0, - "Last dimension of a grouped tensor should be divisible by 128."); - blocks_Y = 1; - blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); } - const dim3 grid(blocks_X, blocks_Y); + + NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported."); + const dim3 grid(elts_total / ELTS_PER_CHUNK); const size_t block_size = THREADS_PER_CHUNK; const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; @@ -879,18 +868,20 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); - const size_t dbias_cols = last_logical_dim; if constexpr (IS_DBIAS) { NVTE_CHECK(is_single_tensor, "DBias is only supported for tensors with the const last dimension."); NVTE_CHECK(dbias->data.dtype == input->dtype(), "DBias must have the same type as input_tensor."); - NVTE_CHECK(dbias->data.shape == std::vector{last_logical_dim}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + std::vector expected_shape_dbias_tensor = {num_tensors, last_logical_dim}; + NVTE_CHECK(dbias->data.shape == expected_shape_dbias_tensor, "Wrong shape of DBias."); + + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + const size_t dbias_workspace_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_workspace_cols = last_logical_dim; if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols}; workspace->data.dtype = DType::kFloat32; return; } @@ -1006,9 +997,12 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); - if constexpr (IS_DBIAS) { - common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - } + if constexpr 
(IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, + dbias, workspace_ptr, CHUNK_DIM_Y, stream); + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 04712d3003..88c483d400 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -161,7 +161,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d * \param[in] stream CUDA stream used for the operation. */ void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the GeLU backward along columns. @@ -207,7 +207,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. @@ -253,7 +253,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. @@ -299,7 +299,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. @@ -345,7 +345,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Squared ReLU backward along columns. 
@@ -391,7 +391,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Casts input tensor from reduced to higher precision. * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, From 7abbc7b83a5f500882896b668b24704b98187a8e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:30:57 +0000 Subject: [PATCH 02/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 4 +- .../common/cast/core/common.cuh | 46 ++++++++----------- .../cast/mxfp8/group_quantize_mxfp8.cuh | 8 ++-- .../common/include/transformer_engine/cast.h | 15 ++++-- 4 files changed, 34 insertions(+), 39 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 29a02124af..5352a2c91a 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -288,7 +288,7 @@ void performTest(const ProcessingMethod processing_method, rowwise_sfs_num += rowwise_sfs; colwise_sfs_num += colwise_sfs; sum_of_last_dims += K; - + rowwise_scales_offset[t+1] = rowwise_sfs_num; colwise_scales_offset[t+1] = colwise_sfs_num; dbias_offsets[t+1] = sum_of_last_dims; @@ -370,7 +370,7 @@ void performTest(const ProcessingMethod processing_method, cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); - + std::vector dbias_logical_shape_vec= {num_tensors, cols}; NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), dbias_logical_shape_vec.size()); diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 24a7e7fa79..a4e033939b 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -89,30 +89,25 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) template __global__ void __launch_bounds__(THREADS_PER_BLOCK) - group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, - const size_t num_tensors, - const size_t first_logical_dim, - const size_t last_logical_dim, - const int64_t *const offsets_ptr, - const int64_t *const first_dims_ptr, - const int64_t *const last_dims_ptr, - OType *const dbias_output, - const float *dbias_partial, - const size_t chunk_dim_Y) { + group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const offsets_ptr, const int64_t *const first_dims_ptr, + const int64_t *const last_dims_ptr, OType *const dbias_output, + const float *dbias_partial, const size_t chunk_dim_Y) { using ComputeVec = Vec; using OutputVec = Vec; const size_t tensor_id = blockIdx.y; const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) - ? (first_logical_dim / num_tensors) - : first_dims_ptr[tensor_id]; - + ? 
(first_logical_dim / num_tensors) + : first_dims_ptr[tensor_id]; + const size_t rows = tensor_rows / chunk_dim_Y; const size_t cols = last_logical_dim; const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) - ? (tensor_id * (tensor_rows / chunk_dim_Y)) - : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); + ? (tensor_id * (tensor_rows / chunk_dim_Y)) + : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; @@ -160,16 +155,12 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, } template -void grouped_reduce_dbias(const ShapeRepresentation shape_rep, - const size_t num_tensors, - const size_t first_logical_dim, - const size_t last_logical_dim, +void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, const int64_t *const data_tensor_offsets_ptr, const int64_t *const data_tensor_first_dims_ptr, - const int64_t *const data_tensor_last_dims_ptr, - GroupedTensor *dbias, - const float *workspace_ptr, - const size_t chunk_dim_Y, + const int64_t *const data_tensor_last_dims_ptr, GroupedTensor *dbias, + const float *workspace_ptr, const size_t chunk_dim_Y, cudaStream_t stream) { using namespace kernel; constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 @@ -181,11 +172,10 @@ void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t blocks_Y = num_tensors; const dim3 grid(blocks_X, blocks_Y); - group_reduce_dbias_kernel - <<>>( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, - data_tensor_offsets_ptr, data_tensor_first_dims_ptr, data_tensor_last_dims_ptr, - reinterpret_cast(dbias->data.dptr), workspace_ptr, chunk_dim_Y); + group_reduce_dbias_kernel<<>>( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, data_tensor_offsets_ptr, + data_tensor_first_dims_ptr, data_tensor_last_dims_ptr, + reinterpret_cast(dbias->data.dptr), workspace_ptr, chunk_dim_Y); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index eac8f37387..e8d30d64fd 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -145,7 +145,8 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned"); } if (global_dim_X % CHUNK_DIM_X != 0) { - NVTE_DEVICE_ERROR("The grouped tensor must be divisible by 128x128 tiles without a tail tile."); + NVTE_DEVICE_ERROR( + "The grouped tensor must be divisible by 128x128 tiles without a tail tile."); } asm volatile( @@ -999,9 +1000,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations if constexpr (IS_DBIAS) { common::grouped_reduce_dbias( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, - dbias, workspace_ptr, CHUNK_DIM_Y, stream); + shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 88c483d400..95d01fd8bf 100644 --- 
a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -207,7 +207,8 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. @@ -253,7 +254,8 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. @@ -299,7 +301,8 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. @@ -345,7 +348,8 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Squared ReLU backward along columns. @@ -391,7 +395,8 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Casts input tensor from reduced to higher precision. 
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, From f820b21b1b86fa690b42a5f3f73ecb0fd448f2cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Feb 2026 15:20:33 +0000 Subject: [PATCH 03/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index e8d30d64fd..343c783bf4 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -998,11 +998,11 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); - if constexpr (IS_DBIAS) { - common::grouped_reduce_dbias( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, - first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); - } + if constexpr (IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) From 0c056325379de6045ada06d12b29ab42b4a43171 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Feb 2026 14:12:36 +0000 Subject: [PATCH 04/31] Relaxed constraints on the last dimension Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 19 ++++++++++----- .../cast/mxfp8/group_quantize_mxfp8.cuh | 23 +++++++++++-------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 5352a2c91a..7a93f504cf 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -647,6 +647,8 @@ std::vector> input_config = { {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, @@ -712,26 +714,31 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { } } offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t]; - // Skips tests if tensor shape is not as required by the kernel - if (first_dims[t] % 128 != 0) { + // Skip tests when the tensor shape is incompatible with the kernel. + // The TMA engine requires strides to be 16-byte aligned. + if ((first_dims[t] % 128 != 0) || (last_dims[t] % 16 != 0)) { GTEST_SKIP(); } - if (!is_single_tensor && (last_dims[t] % 128 != 0)) { + // If a grouped tensor has a varying last dimension, it must be a multiple of 128. + // Otherwise, computing the grid size adds runtime overhead in the non-persistent kernel, + // since the relevant tensor metadata resides in device memory. 
+ constexpr size_t CHUNK_DIM_X = 128; + if (!is_single_tensor && (last_dims[t] % CHUNK_DIM_X != 0)) { GTEST_SKIP(); } } - // Skips DBias tests if last dimension of tensors variates + // Skip dBias tests when tensors in the group have different last dimensions. if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) && !is_single_tensor) { GTEST_SKIP(); } - // Skips non Act tests if the Activation type is not an identity + // Skip non-activation tests when the activation type is not Identity. if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) && activation != ActivationKind::Identity) { GTEST_SKIP(); } - // Skips Act tests if the Activation is an identity + // Skip activation tests when the activation type is Identity. if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT || processing_method == ProcessingMethod::CAST_DACT || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 343c783bf4..dc03df0aac 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -106,6 +106,9 @@ __device__ __forceinline__ size_t get_tensor_rows_num( rows_num = static_cast(first_dims_ptr[tensor_id]); break; } + if (rows_num % 128 != 0) { + NVTE_DEVICE_ERROR("First dimension of each tensor in a group must be divisible by 128."); + } return rows_num; } @@ -144,10 +147,6 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) { NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned"); } - if (global_dim_X % CHUNK_DIM_X != 0) { - NVTE_DEVICE_ERROR( - "The grouped tensor must be divisible by 128x128 tiles without a tail tile."); - } asm volatile( "{\n\t" @@ -833,22 +832,28 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - if (!is_single_tensor) { + size_t blocks = 0; + if (is_single_tensor) { + const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + blocks = blocks_Y * blocks_X; + } else { NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); + // Only full tiles supported + NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported."); + blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); } - - NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported."); - const dim3 grid(elts_total / ELTS_PER_CHUNK); const size_t block_size = THREADS_PER_CHUNK; + const dim3 grid(blocks); const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; // Logical shape of a tensor with varying all dims is [1, M*K] if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { NVTE_CHECK(first_logical_dim % 128 == 0, - "First dimension of a grouped tensor should be divisible by 128."); + "First logical dimension of a grouped tensor must be divisible by 128."); } const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); From 4a85deade069ada724516c8769bc1a98f5c5d461 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 
Feb 2026 14:28:33 +0000 Subject: [PATCH 05/31] Added notes on group tensor restrictions into documentation Signed-off-by: Oleg Goncharov --- .../include/transformer_engine/activation.h | 10 ++++++++ .../common/include/transformer_engine/cast.h | 25 ++++++++++++------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 06f1c65ce2..854f52c203 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -56,6 +56,7 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the GeLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -76,6 +77,7 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the SiLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -96,6 +98,7 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the ReLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -116,6 +119,7 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the Quick GeLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -136,6 +140,7 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the Squared ReLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -158,6 +163,7 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the GeLU activation gradient of the grouped input. 
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -182,6 +188,7 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the SiLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -206,6 +213,7 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the ReLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -230,6 +238,7 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu /*! \brief Computes the Quick GeLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -254,6 +263,7 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu /*! \brief Computes the Squared ReLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 95d01fd8bf..755052d6dd 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -92,6 +92,7 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea /*! \brief Casts input grouped tensor to MXFP8. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. See file level comments. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor to be cast. * \param[in,out] output Output grouped MXFP8 tensor. @@ -146,6 +147,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d /*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns. 
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -190,6 +192,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the GeLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -237,6 +240,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the SiLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -284,6 +288,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the ReLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -331,6 +336,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp * Additionally, reduces the result of the Quick GeLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -378,6 +384,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp * Additionally, reduces the result of the Squared ReLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -412,11 +419,11 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str /*! \brief Casts multiple input tensors to quantized output tensors. * - * \param[in] inputs List of input tensors to be cast. - * \param[in,out] outputs List of output quantized tensors. - * \param[in] quant_config (Optional) Quantization configurations. - * \param[in] num_tensors Number of input and output tensors. - * \param[in] stream CUDA stream used for the operation. + * \param[in] inputs List of input tensors to be cast. + * \param[in,out] outputs List of output quantized tensors. 
+ * \param[in] quant_config (Optional) Quantization configurations. + * \param[in] num_tensors Number of input and output tensors. + * \param[in] stream CUDA stream used for the operation. */ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, const NVTEQuantizationConfig quant_config, const size_t num_tensors, @@ -425,11 +432,11 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, /*! \brief Casts grouped input tensor to quantized output tensors. * * \param[in] input Input tensor to be cast. - * \param[in,out] outputs Output quantized tensors. - * \param[in] split_sections Split sections of the input tensor. - * \param[in] num_tensors Number of output tensors. + * \param[in,out] outputs Output quantized tensors. + * \param[in] split_sections Split sections of the input tensor. + * \param[in] num_tensors Number of output tensors. * \param[in] quant_config (Optional) Quantization configurations. - * \param[in] stream CUDA stream used for the operation. + * \param[in] stream CUDA stream used for the operation. */ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections, size_t num_tensors, From aedd53dbcd7f2a895d045501b2c6074fec2df79d Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 27 Feb 2026 15:49:04 +0000 Subject: [PATCH 06/31] Fixes per the review Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 20 +++++++++---------- .../cast/mxfp8/group_quantize_mxfp8.cuh | 2 -- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 7a93f504cf..63fbdff627 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -346,16 +346,16 @@ void performTest(const ProcessingMethod processing_method, const size_t last_dims_size = num_tensors * sizeof(size_t); const size_t offsets_size = (num_tensors + 1) * sizeof(size_t); - InputType* grad_data_d; - InputType* in_data_d; - InputType* dbias_out_data_d; - OutputType* out_data_rowwise_d; - OutputType* out_data_colwise_d; - fp8e8m0* out_scales_rowwise_d; - fp8e8m0* out_scales_colwise_d; - size_t* first_dims_d; - size_t* last_dims_d; - size_t* offsets_d; + InputType* grad_data_d = nullptr; + InputType* in_data_d = nullptr; + InputType* dbias_out_data_d = nullptr; + OutputType* out_data_rowwise_d = nullptr; + OutputType* out_data_colwise_d = nullptr; + fp8e8m0* out_scales_rowwise_d = nullptr; + fp8e8m0* out_scales_colwise_d = nullptr; + size_t* first_dims_d = nullptr; + size_t* last_dims_d = nullptr; + size_t* offsets_d = nullptr; cudaMalloc((void**)&grad_data_d, in_data_size); cudaMalloc((void**)&in_data_d, in_data_size); diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index dc03df0aac..c9cc73974b 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -841,8 +841,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); - // Only full tiles supported - NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported."); blocks = DIVUP(elts_total, CHUNK_DIM_Y * 
CHUNK_DIM_X); } const size_t block_size = THREADS_PER_CHUNK; From 38288b184cf5f7c7f68e7620f47265a55d626c8d Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 27 Feb 2026 15:56:11 +0000 Subject: [PATCH 07/31] Fixed pointer Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 63fbdff627..204236da43 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -466,7 +466,7 @@ void performTest(const ProcessingMethod processing_method, cudaMalloc((void**)&dbias_out_data_d, dbias_data_size); cudaMemset(dbias_out_data_d, 0, dbias_data_size); NVTEBasicTensor output_dbias_data_tensor = {dbias_out_data_d, static_cast(itype), dbias_logical_shape_}; - nvte_set_grouped_tensor_param(&output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &output_dbias_data_tensor); + nvte_set_grouped_tensor_param(output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &output_dbias_data_tensor); } // Reference (CPU) From ce3a13731f7fa72cb1ac4ab689255b26585442f5 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 27 Feb 2026 16:13:00 +0000 Subject: [PATCH 08/31] More fixes Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 204236da43..e469ad0845 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -466,7 +466,8 @@ void performTest(const ProcessingMethod processing_method, cudaMalloc((void**)&dbias_out_data_d, dbias_data_size); cudaMemset(dbias_out_data_d, 0, dbias_data_size); NVTEBasicTensor output_dbias_data_tensor = {dbias_out_data_d, static_cast(itype), dbias_logical_shape_}; - nvte_set_grouped_tensor_param(output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &output_dbias_data_tensor); + nvte_set_grouped_tensor_param(output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &output_dbias_data_tensor, sizeof(output_dbias_data_tensor)); } // Reference (CPU) From bddd804e4e61ed03954301b55c36f89b358909a0 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Mon, 2 Mar 2026 20:14:10 +0000 Subject: [PATCH 09/31] Fixed kernel grid size Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index c9cc73974b..129d6724ac 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -832,19 +832,21 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - size_t blocks = 0; + size_t blocks_X = 0; + size_t blocks_Y = 0; + if (is_single_tensor) { - const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); - blocks = blocks_Y * blocks_X; + blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); } else { NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of 
tensors in a group is larger than " "the MAX number of supported descriptors (64)."); - blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + blocks_Y = 1; + blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); } + const dim3 grid(blocks_X, blocks_Y); const size_t block_size = THREADS_PER_CHUNK; - const dim3 grid(blocks); const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; From 87352bd488d25678a01c2e6f36838226348ad5e0 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 13:28:06 +0000 Subject: [PATCH 10/31] Enabled persistency with WorkID Query feature Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 893 ++++++++++-------- 1 file changed, 507 insertions(+), 386 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 129d6724ac..e0a1a1a814 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -39,7 +39,8 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_T constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; -constexpr size_t BUFFS_NUM = 2; +constexpr size_t PREFETCH_STAGES = 1; +constexpr size_t BUFFS_NUM = PREFETCH_STAGES + 1; constexpr size_t PACK_SIZE = 4; constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; @@ -261,93 +262,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - const size_t block_ID = blockIdx.y * gridDim.x + blockIdx.x; - const size_t block_global_offset = - is_single_tensor ? (blockIdx.y * CHUNK_DIM_Y * last_logical_dim + blockIdx.x * CHUNK_DIM_X) - : (block_ID * ELTS_PER_CHUNK); - - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, blockIdx.y, - first_logical_dim, last_logical_dim, offsets_ptr); - - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - - const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); - const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); - - // grouped tensor can be treated as continuous tensor for MXFP8 - const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); - // For grouped tensors represented as a single logical tensor, scale swizzle must still be - // computed per tensor (expert) and then concatenated along dim-0. - const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) - ? static_cast(offsets_ptr[tensor_id]) - : tensor_base; - - // In graph-safe paged stashing, the logical shape can include trailing garbage. Skip CTAs that - // map outside the current tensor's valid [rows, cols] region. 
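// [Editorial sketch, not part of the patch.] Host-side illustration of how a chunk (CTA work
// item) is mapped to its owning tensor in a grouped tensor and validated against that tensor's
// [rows, cols] region, mirroring the device logic above. `offsets` is assumed to be the
// (num_tensors + 1) prefix-sum of per-tensor element counts, as built in the test code; the
// function and struct names are hypothetical and only approximate get_current_tensor_id().
#include <cstddef>
#include <cstdint>
#include <vector>

struct ChunkMapping {
  size_t tensor_id;
  size_t row_in_tensor;
  size_t col_in_tensor;
  bool valid;
};

// elts_per_chunk corresponds to CHUNK_DIM_Y * CHUNK_DIM_X (128 * 128) in the kernel.
ChunkMapping map_chunk_to_tensor(size_t block_id, size_t elts_per_chunk,
                                 const std::vector<int64_t>& offsets,
                                 const std::vector<size_t>& rows,
                                 const std::vector<size_t>& cols) {
  const size_t global_offset = block_id * elts_per_chunk;
  // A linear scan stands in for the device-side tensor lookup; the kernel may search differently.
  size_t t = 0;
  while (t + 1 < offsets.size() && global_offset >= static_cast<size_t>(offsets[t + 1])) {
    ++t;
  }
  ChunkMapping m{t, 0, 0, false};
  if (t + 1 >= offsets.size() || rows[t] == 0 || cols[t] == 0) {
    return m;  // chunk starts past the last tensor, or the tensor is empty
  }
  const size_t offset_in_tensor = global_offset - static_cast<size_t>(offsets[t]);
  m.row_in_tensor = offset_in_tensor / cols[t];
  m.col_in_tensor = offset_in_tensor % cols[t];
  // A chunk whose start lies outside the valid region is skipped, as in the kernel above.
  m.valid = (m.row_in_tensor < rows[t]) && (m.col_in_tensor < cols[t]);
  return m;
}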
- if (rows == 0 || cols == 0) { - return; - } - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); - if (block_global_offset >= tensor_end_offset) { - return; - } - const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; - if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { - return; - } - } - - const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; - const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap &tensor_map_output_rowwise = - is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap &tensor_map_output_colwise = - is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; - const bool leading_thread = (threadIdx.x == 0); - if (leading_thread && (!is_single_tensor)) { - fence_acquire_tensormap(&tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { - fence_acquire_tensormap(&tensor_map_act_input); - } - if constexpr (ROWWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_rowwise); - } - if constexpr (COLWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_colwise); - } - } - - const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); - const size_t block_id_in_current_tensor = - is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); - - const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; - const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; - - const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; - const size_t block_offset_X = block_id_X * CHUNK_DIM_X; - - e8m0_t *const scales_rowwise = - scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); - e8m0_t *const scales_colwise = - scales_colwise_ptr + (is_single_tensor ? 
0 : tensor_base / SCALE_DIM_Y); - - const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; const size_t tid_X_rowwise = threadIdx.x % THREADS_X; const size_t tid_Y_colwise = 0; @@ -356,11 +272,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -390,374 +301,578 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } + constexpr size_t shmem_buff_size = (IS_DACT ? 2 : 1) * buff_size_aligned_in / BUFFS_NUM; float block_amax = 0.0f; -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; + __shared__ uint64_t workID_mbar; + __shared__ __uint128_t workID_response; + constexpr uint32_t workID_response_size = sizeof(workID_response); + static_assert(workID_response_size == 16); - initialize_barriers(mbar, leading_thread); + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], leading_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], leading_thread); + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::mbarrier_init(&workID_mbar, 1); + ptx::fence_proxy_async_shared_cta(); } + __syncthreads(); + + int IN_buff_readable_parity[BUFFS_NUM] = {0}; + int ctaid_parity = 0; + int32_t ctaid_X = blockIdx.x; + int32_t ctaid_Y = blockIdx.y; + bool job_finished = false; + int buff_in = 0; + + // Prefetch the first stage of the first job. + { + const size_t block_ID = static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t block_global_offset = + is_single_tensor ? 
(static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (block_ID * ELTS_PER_CHUNK); + + const size_t tensor_id = + get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + if (rows == 0 || cols == 0) { + return; + } + if (shape_rep != SAME_BOTH_DIMS) { + const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); + if (block_global_offset >= tensor_end_offset) { + return; + } + const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; + if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { + return; + } + } + + const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + const size_t block_id_in_current_tensor = + is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + } #pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const size_t buff = stage; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - leading_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread); + const size_t buff_offset = buff * BUFF_DIM; + uint64_t *barrier = &IN_buff_readable_mbar[buff]; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buff_offset]), + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + barrier); + if constexpr (IS_DACT) { + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&act_in_sh[buff_offset]), + reinterpret_cast(&tensor_map_act_input), global_offset_X, + global_offset_Y, barrier); + } } } + } - ptx::fence_proxy_async_shared_cta(); + while (!job_finished) { + const size_t block_ID = static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t block_global_offset = + is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (block_ID * ELTS_PER_CHUNK); + const size_t tensor_id = + get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + bool current_job_is_valid = (rows != 0) && (cols != 0); + if (shape_rep != SAME_BOTH_DIMS) { + const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); + if (block_global_offset >= tensor_end_offset) { + current_job_is_valid = false; + } + if (current_job_is_valid) { + const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; + if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { + current_job_is_valid = false; + } + } + } + if (!current_job_is_valid) { + // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. 
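// [Editorial sketch, not part of the patch.] The buffer/parity bookkeeping used by the
// prefetch pipeline above, modeled without the ptx:: barrier intrinsics so the indexing is easy
// to follow: BUFFS_NUM = PREFETCH_STAGES + 1, and each buffer carries a phase-parity bit that
// flips every time its "readable" barrier completes. Names here are illustrative only.
#include <cstddef>

constexpr size_t kPrefetchStages = 1;              // TunableConfig-style prefetch depth
constexpr size_t kBuffsNum = kPrefetchStages + 1;  // ping-pong buffering

struct PipelineState {
  int parity[kBuffsNum] = {0};  // expected phase of each buffer's "readable" barrier
  size_t buff_in = 0;           // buffer the consumer will read next
};

// Consumer side of one stage: wait on the current buffer, flip its parity, advance.
inline size_t consume_stage(PipelineState& s) {
  // In the kernel this is the mbarrier wait on IN_buff_readable_mbar[buff_in] with the
  // corresponding parity bit, followed by the same two updates.
  const size_t ready_buff = s.buff_in;
  s.parity[ready_buff] ^= 1;                // next wait on this buffer expects the other phase
  s.buff_in = (s.buff_in + 1) % kBuffsNum;  // rotate to the buffer the next TMA load fills
  return ready_buff;
}

// Producer side: the buffer PREFETCH_STAGES ahead of the consumer is the one safe to refill.
inline size_t next_prefetch_buffer(const PipelineState& s) {
  return (s.buff_in + kPrefetchStages) % kBuffsNum;
}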
+ ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); + break; + } - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); + const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); + const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); + + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) + ? static_cast(offsets_ptr[tensor_id]) + : tensor_base; + const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); + const size_t block_id_in_current_tensor = + is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = + is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = + is_single_tensor ? 
tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); + } + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); + } + } - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::try_cancel_cta(&workID_mbar, &workID_response); + } - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); #pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + for (int stage = 0; stage < STAGES; ++stage) { + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (stage == STAGES - PREFETCH_STAGES) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + ctaid_parity ^= 1; + } - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); + if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { + const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const size_t next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; + + const size_t prefetch_block_ID = + static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t prefetch_block_global_offset = + is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (prefetch_block_ID * ELTS_PER_CHUNK); + const size_t prefetch_tensor_id = + get_current_tensor_id(shape_rep, num_tensors, prefetch_block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + const size_t prefetch_tensor_base = + is_single_tensor ? 0 : static_cast(offsets_ptr[prefetch_tensor_id]); + const size_t prefetch_cols = + get_tensor_cols_num(prefetch_tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + const size_t prefetch_blocks_X_num_in_current_tensor = + DIVUP(prefetch_cols, static_cast(128)); + const size_t prefetch_block_id_in_current_tensor = + is_single_tensor ? 
prefetch_block_ID + : (prefetch_block_ID - prefetch_tensor_base / ELTS_PER_CHUNK); + const size_t prefetch_block_id_Y = + prefetch_block_id_in_current_tensor / prefetch_blocks_X_num_in_current_tensor; + const size_t prefetch_block_id_X = + prefetch_block_id_in_current_tensor % prefetch_blocks_X_num_in_current_tensor; + + const size_t global_offset_Y = prefetch_block_id_Y * CHUNK_DIM_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = prefetch_block_id_X * CHUNK_DIM_X; + const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; + + const CUtensorMap &prefetch_tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[prefetch_tensor_id]; + const CUtensorMap &prefetch_tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[prefetch_tensor_id]; + + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + if (leading_thread) { + if ((!is_single_tensor) && (stage == STAGES - PREFETCH_STAGES)) { + fence_acquire_tensormap(&prefetch_tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&prefetch_tensor_map_act_input); + } } + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[next_prefetch_buff_offset]), + reinterpret_cast(&prefetch_tensor_map_input), global_offset_X, + global_offset_Y, barrier); if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&act_in_sh[next_prefetch_buff_offset]), + reinterpret_cast(&prefetch_tensor_map_act_input), global_offset_X, + global_offset_Y, barrier); } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_colwise[i] = elt; } + ptx::fence_proxy_async_shared_cta(); } - // 2. 
Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - const size_t tensor_base_row = tensor_base_for_scales / cols; - const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; - const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; - const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; - scale_idx = tensor_scales_offset_colwise_base + - gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, - DIVUP(rows, static_cast(128))); - } else { - scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - } - scales_colwise[scale_idx] = biased_exponent; + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + const size_t buff = buff_in; + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; -// 3. Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_colwise[i] = elt; + } + } - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = 
scales_offset_X_colwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t tensor_base_row = tensor_base_for_scales / cols; + const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; + const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; + scale_idx = tensor_scales_offset_colwise_base + + gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } + scales_colwise[scale_idx] = biased_exponent; - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if constexpr (std::is_same_v) { + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + Vec in_IType[WAVES]; + + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); #pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); } } - } - if constexpr (!std::is_same_v) { thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { + } else if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (std::is_same_v) { #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); if constexpr (IS_DACT) { - float act_in_elt = 
static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_rowwise[j] = elt; } } - } - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, - DIVUP(cols, static_cast(128))); - } else { - scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - } - scales_rowwise[scale_idx] = biased_exponent; + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, + DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; -// 3. 
Scale elements #pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; + for (int w = 0; w < WAVES; ++w) { + Vec out; #pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } - } - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. 
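// [Editorial sketch, not part of the patch.] Scalar reference for the per-block MXFP8 scaling
// performed above: the E8M0 exponent is derived from the block amax scaled by 1/max_norm of the
// FP8 output type, and elements are multiplied by the reciprocal scale before the FP8 cast.
// The round-up-to-power-of-two behaviour assumed for float_to_e8m0 is an assumption; the
// device helper may round differently, and zero/NaN handling here is simplified.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline uint8_t float_to_e8m0_ref(float x) {
  if (!(x > 0.0f)) return 0;  // zero amax maps to the smallest exponent
  const int unbiased = static_cast<int>(std::ceil(std::log2(x)));
  return static_cast<uint8_t>(std::min(std::max(unbiased + 127, 0), 254));
}

// For an E4M3 output type, max_norm would be 448, matching Quantized_Limits in the kernel.
inline void quantize_block_ref(const float* in, float* scaled, int n, float max_norm,
                               uint8_t* e8m0_scale) {
  float amax = 0.0f;
  for (int i = 0; i < n; ++i) amax = std::fmax(amax, std::fabs(in[i]));
  // Mirrors: biased_exponent = float_to_e8m0(amax * max_norm_rcp);
  *e8m0_scale = float_to_e8m0_ref(amax / max_norm);
  // Mirrors: block_scale_inverse = exp2f_rcp(biased_exponent), i.e. 2^(127 - e8m0).
  const float scale_inv = std::exp2f(127.0f - static_cast<float>(*e8m0_scale));
  for (int i = 0; i < n; ++i) scaled[i] = in[i] * scale_inv;  // then cast to the FP8 type
}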
+ ptx::fence_proxy_async_shared_cta(); + __syncthreads(); - // Initiate TMA transfer to copy shared memory to global memory - if (leading_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + ptx::cp_async_bulk_commit_group(); } - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); + buff_in = (buff_in + 1) % BUFFS_NUM; } - } - - parity ^= 1; - if constexpr (IS_DBIAS) { - if (is_single_tensor) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element 
offset per MXFP8 scaling block [1x32] - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + for (int i = 0; i < THREADS_Y; ++i) { + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; } - } - const int dbias_stride = cols; - const int dbias_offset_Y = block_id_Y; - const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; } } } @@ -772,7 +887,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel atomicMaxFloat(amax_ptr, block_amax); } - destroy_barriers(mbar, leading_thread); + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + ptx::mbarrier_invalid(&workID_mbar); + } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } } // namespace group_quantize_kernel From e23f553cdb7e90e183128f157f53fc2332c35bdd Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 14:03:59 +0000 Subject: [PATCH 11/31] Added a struct with tunable parameters Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index e0a1a1a814..31e8645e07 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -36,17 +36,26 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +struct TunableConfig { + static constexpr size_t CHUNK_DIM_Y = 128; + static constexpr size_t CHUNK_DIM_X = 128; + static constexpr size_t THREADS_PER_CHUNK = 128; + static constexpr size_t PREFETCH_STAGES = 1; + // Set false to run one-CTA-per-block (non-persistent) mode. 
+ static constexpr bool PERSISTENT = true; +}; + constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; -constexpr size_t PREFETCH_STAGES = 1; +constexpr size_t PREFETCH_STAGES = TunableConfig::PREFETCH_STAGES; constexpr size_t BUFFS_NUM = PREFETCH_STAGES + 1; constexpr size_t PACK_SIZE = 4; constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y; +constexpr size_t CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK = TunableConfig::THREADS_PER_CHUNK; constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; @@ -511,9 +520,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - if (leading_thread) { - ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); - ptx::try_cancel_cta(&workID_mbar, &workID_response); + if constexpr (TunableConfig::PERSISTENT) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::try_cancel_cta(&workID_mbar, &workID_response); + } } #pragma unroll @@ -521,12 +532,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t stage_offset_Y = stage * BUFF_DIM_Y; if (stage == STAGES - PREFETCH_STAGES) { - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); - ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + if constexpr (TunableConfig::PERSISTENT) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + ctaid_parity ^= 1; + } else { + ctaid_X = -1; + ctaid_Y = -1; + } if (ctaid_X == -1 && ctaid_Y == -1) { job_finished = true; } - ctaid_parity ^= 1; } if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { From d185299e8eb28a0b042be91a1741f301bf116630 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 15:15:14 +0000 Subject: [PATCH 12/31] Added persistency with static scheduling Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 127 +++++++++++++----- 1 file changed, 96 insertions(+), 31 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 31e8645e07..791fe65098 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -19,6 +19,7 @@ #include "../../common.h" #include "../../util/math.h" #include "../../util/ptx.cuh" +#include "../../util/cuda_runtime.h" #include "../../utils.cuh" #include "../core/common.cuh" #include "swizzle.cuh" @@ -36,15 +37,30 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +enum class PersistentStrategy : int { + NONE = 0, + DYNAMIC_WORK_STEALING = 1, + STATIC_GRID_STRIDE = 2, +}; + struct TunableConfig { static constexpr size_t CHUNK_DIM_Y = 128; static constexpr size_t CHUNK_DIM_X = 128; static constexpr size_t THREADS_PER_CHUNK = 128; static constexpr size_t PREFETCH_STAGES = 1; - // Set false to run one-CTA-per-block (non-persistent) mode. 
- static constexpr bool PERSISTENT = true; + static constexpr PersistentStrategy PERSISTENT_STRATEGY = + PersistentStrategy::STATIC_GRID_STRIDE; + // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). + static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 1; }; +constexpr bool DYNAMIC_PERSISTENT = + TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::DYNAMIC_WORK_STEALING; +constexpr bool STATIC_PERSISTENT = + TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::STATIC_GRID_STRIDE; +static_assert(TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0, + "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero."); + constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; @@ -251,7 +267,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, - float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { + float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr, + const size_t work_blocks_X, const size_t work_blocks_Y) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -315,8 +332,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel float block_amax = 0.0f; __shared__ uint64_t workID_mbar; - __shared__ __uint128_t workID_response; - constexpr uint32_t workID_response_size = sizeof(workID_response); + [[maybe_unused]] __shared__ __uint128_t workID_response; + [[maybe_unused]] constexpr uint32_t workID_response_size = sizeof(workID_response); static_assert(workID_response_size == 16); __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; @@ -331,16 +348,32 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } __syncthreads(); + const size_t total_work_blocks = work_blocks_X * work_blocks_Y; + const size_t launch_block_id = + static_cast(blockIdx.y) * static_cast(gridDim.x) + static_cast(blockIdx.x); + int IN_buff_readable_parity[BUFFS_NUM] = {0}; - int ctaid_parity = 0; - int32_t ctaid_X = blockIdx.x; - int32_t ctaid_Y = blockIdx.y; + [[maybe_unused]] int ctaid_parity = 0; + int32_t ctaid_X = static_cast(blockIdx.x); + int32_t ctaid_Y = static_cast(blockIdx.y); + [[maybe_unused]] size_t static_next_block_id = 0; + [[maybe_unused]] size_t static_block_stride = 0; + if constexpr (STATIC_PERSISTENT) { + if (launch_block_id >= total_work_blocks) { + return; + } + ctaid_X = static_cast(launch_block_id % work_blocks_X); + ctaid_Y = static_cast(launch_block_id / work_blocks_X); + static_block_stride = static_cast(gridDim.x) * static_cast(gridDim.y); + static_next_block_id = launch_block_id + static_block_stride; + } bool job_finished = false; int buff_in = 0; + bool has_prefetched_current_job = true; // Prefetch the first stage of the first job. { - const size_t block_ID = static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t block_ID = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); const size_t block_global_offset = is_single_tensor ? 
(static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + static_cast(ctaid_X) * CHUNK_DIM_X) @@ -352,7 +385,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - if (rows == 0 || cols == 0) { + if (block_ID >= total_work_blocks || rows == 0 || cols == 0) { return; } if (shape_rep != SAME_BOTH_DIMS) { @@ -415,7 +448,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } while (!job_finished) { - const size_t block_ID = static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t block_ID = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); const size_t block_global_offset = is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + static_cast(ctaid_X) * CHUNK_DIM_X) @@ -428,7 +461,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - bool current_job_is_valid = (rows != 0) && (cols != 0); + bool current_job_is_valid = (block_ID < total_work_blocks) && (rows != 0) && (cols != 0); if (shape_rep != SAME_BOTH_DIMS) { const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); @@ -445,11 +478,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } if (!current_job_is_valid) { - // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); - IN_buff_readable_parity[buff_in] ^= 1; - ptx::cp_async_bulk_wait_group_read(); + if (has_prefetched_current_job) { + // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); + } break; } @@ -520,38 +555,56 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - if constexpr (TunableConfig::PERSISTENT) { + if constexpr (DYNAMIC_PERSISTENT) { if (leading_thread) { ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); ptx::try_cancel_cta(&workID_mbar, &workID_response); } } + bool prefetched_next_job = false; #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t stage_offset_Y = stage * BUFF_DIM_Y; + bool allow_next_job_prefetch = true; if (stage == STAGES - PREFETCH_STAGES) { - if constexpr (TunableConfig::PERSISTENT) { + if constexpr (DYNAMIC_PERSISTENT) { ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); ctaid_parity ^= 1; + } else if constexpr (STATIC_PERSISTENT) { + if (static_next_block_id < total_work_blocks) { + ctaid_X = static_cast(static_next_block_id % work_blocks_X); + ctaid_Y = static_cast(static_next_block_id / work_blocks_X); + static_next_block_id += static_block_stride; + } else { + // Next loop iteration exits via current_job_is_valid check. 
+ ctaid_X = 0; + ctaid_Y = static_cast(work_blocks_Y); + allow_next_job_prefetch = false; + } } else { ctaid_X = -1; ctaid_Y = -1; } - if (ctaid_X == -1 && ctaid_Y == -1) { - job_finished = true; + if constexpr (!STATIC_PERSISTENT) { + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; + } } } - if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { + if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; const size_t next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; + if (stage >= STAGES - PREFETCH_STAGES) { + prefetched_next_job = true; + } const size_t prefetch_block_ID = - static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); const size_t prefetch_block_global_offset = is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + static_cast(ctaid_X) * CHUNK_DIM_X) @@ -851,6 +904,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel buff_in = (buff_in + 1) % BUFFS_NUM; } + has_prefetched_current_job = prefetched_next_job; if constexpr (IS_DBIAS) { if (is_single_tensor) { @@ -969,20 +1023,30 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - size_t blocks_X = 0; - size_t blocks_Y = 0; + size_t work_blocks_X = 0; + size_t work_blocks_Y = 0; if (is_single_tensor) { - blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); - blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + work_blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + work_blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); } else { NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); - blocks_Y = 1; - blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + work_blocks_Y = 1; + work_blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + } + + size_t launch_blocks_X = work_blocks_X; + size_t launch_blocks_Y = work_blocks_Y; + if constexpr (STATIC_PERSISTENT) { + const size_t sm_num = static_cast(transformer_engine::cuda::sm_count()); + const size_t static_grid_size = sm_num * TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM; + NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero."); + launch_blocks_X = static_grid_size; + launch_blocks_Y = 1; } - const dim3 grid(blocks_X, blocks_Y); + const dim3 grid(launch_blocks_X, launch_blocks_Y); const size_t block_size = THREADS_PER_CHUNK; const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; @@ -1138,7 +1202,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, + work_blocks_Y); if constexpr (IS_DBIAS) { common::grouped_reduce_dbias( From 5e15f574375e2b06b9ade6dff4479d69ea842f2a Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 16:15:42 +0000 Subject: [PATCH 13/31] Fixed test cases Signed-off-by: Oleg Goncharov --- 
tests/cpp/operator/test_cast_mxfp8_grouped.cu | 51 +++++++++++-------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index e469ad0845..6cff159d51 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -554,6 +554,11 @@ void performTest(const ProcessingMethod processing_method, const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; + // Compare only allocated contiguous output range. + // In graph-safe mode logical shape may include trailing garbage beyond offsets_h.back(). + const size_t compare_rows = 1; + const size_t compare_cols = elts_num; + if (rowwise) { cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost); @@ -566,7 +571,8 @@ void performTest(const ProcessingMethod processing_method, const size_t mismatches_elts = 32 * mismatches_scales; compare_scaled_elts("rowwise_output", out_data_rowwise_ref.data(), - out_data_rowwise_h.data(), rows, cols, true, mismatches_elts); + out_data_rowwise_h.data(), compare_rows, compare_cols, + true, mismatches_elts); } if (colwise) { @@ -581,7 +587,8 @@ void performTest(const ProcessingMethod processing_method, const size_t mismatches_elts = 32 * mismatches_scales; compare_scaled_elts("colwise_output", out_data_colwise_ref.data(), - out_data_colwise_h.data(), rows, cols, false, mismatches_elts); + out_data_colwise_h.data(), compare_rows, compare_cols, + false, mismatches_elts); } if (compute_dbias) { @@ -616,15 +623,15 @@ void performTest(const ProcessingMethod processing_method, std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - ProcessingMethod::CAST_DBIAS, - ProcessingMethod::CAST_DBIAS_DACT, - ProcessingMethod::CAST_DACT, - ProcessingMethod::CAST_ACT, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, }; std::vector activation_kinds = { ActivationKind::Identity, - ActivationKind::GeLU, + // ActivationKind::GeLU, // ActivationKind::SiLU, // ActivationKind::ReLU, // ActivationKind::QGeLU, @@ -639,21 +646,23 @@ enum ScalingDirection { std::vector scaling_directions = { ScalingDirection::ROWWISE, - ScalingDirection::COLWISE, - ScalingDirection::BOTH, + // ScalingDirection::COLWISE, + // ScalingDirection::BOTH, }; // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - {SAME_BOTH_DIMS, 1, 128,128}, - {SAME_BOTH_DIMS, 2, 256,128}, - {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, - {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, - {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, - {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, - {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, - {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + // {SAME_BOTH_DIMS, 1, 128,128}, + // {SAME_BOTH_DIMS, 2, 256,128}, + // {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + // {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + // {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + // {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 4096,4096, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 16 * 4096,4096, 128,256,384,1024,2304}, + // 
{VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; } // namespace @@ -815,8 +824,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(activation_kinds), ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(DType::kBFloat16), + ::testing::Values(DType::kFloat8E4M3)), + // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); std::string name = to_string(method); From 98e95582bcc97cea7a8c5e3430917dcf77df0543 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 16:18:07 +0000 Subject: [PATCH 14/31] Ready for benchmarking Signed-off-by: Oleg Goncharov --- tests/cpp/CMakeLists.txt | 3 +- tests/cpp/operator/CMakeLists.txt | 56 +- .../common/activation/activation_template.h | 30 +- .../common/cast/dispatch/dequantize.cuh | 52 +- .../common/cast/dispatch/gated.cuh | 304 ++++---- .../common/cast/dispatch/quantize.cuh | 729 +++++++++--------- 6 files changed, 589 insertions(+), 585 deletions(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 6f4f163f08..2092975b2a 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -8,7 +8,8 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) else () - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + set(CMAKE_CUDA_ARCHITECTURES 100) endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 56880a428d..a04cc3c38c 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,35 +3,35 @@ # See LICENSE for license information. 
add_executable(test_operator - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu - test_cast_mxfp8.cu + # test_cast.cu + # test_cast_current_scaling.cu + # test_cast_dbias.cu + # test_cast_dbias_dgelu.cu + # test_cast_gated_swiglu.cu + # test_cast_mxfp8_gated_swiglu.cu + # test_qdq.cu + # test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu - test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu - test_dequantize_mxfp8.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_memset.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_multi_unpadding.cu - test_causal_softmax.cu - test_swizzle.cu - test_swap_first_dims.cu - test_grouped_gemm.cu + # test_cast_nvfp4_transpose.cu + # test_cast_float8blockwise.cu + # test_dequantize_mxfp8.cu + # test_transpose.cu + # test_cast_transpose.cu + # test_cast_transpose_current_scaling.cu + # test_cast_transpose_dbias.cu + # test_cast_transpose_dbias_dgelu.cu + # test_cast_transpose_dgeglu.cu + # test_act.cu + # test_normalization.cu + # test_normalization_mxfp8.cu + # test_memset.cu + # test_multi_cast_transpose.cu + # test_multi_padding.cu + # test_multi_unpadding.cu + # test_causal_softmax.cu + # test_swizzle.cu + # test_swap_first_dims.cu + # test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index ffbffafd1a..caf6cbda65 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -22,36 +22,36 @@ namespace transformer_engine { template void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - using namespace detail; - constexpr bool IS_ACT = true; - dispatch::quantize_fwd_helper(input, output, nullptr, stream); + // using namespace detail; + // constexpr bool IS_ACT = true; + // dispatch::quantize_fwd_helper(input, output, nullptr, stream); } template void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { - using namespace detail; - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = true; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - - dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, - nullptr, stream); + // using namespace detail; + // constexpr bool IS_DBIAS = false; + // constexpr bool IS_DACT = true; + // constexpr NVTETensor dbias = nullptr; + // constexpr NVTETensor workspace = nullptr; + + // dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, + // nullptr, stream); } template void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { - using namespace detail; - dispatch::quantize_gated_fwd_helper(input, output, p, stream); + // using namespace detail; + // dispatch::quantize_gated_fwd_helper(input, output, p, stream); } template void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { - using namespace detail; - dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); + // using namespace detail; + // 
dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 81304981d3..db2ad285a8 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -22,32 +22,32 @@ namespace transformer_engine { namespace dispatch { inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - switch (input.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); - NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); - NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); - fp8::dequantize(input, output, stream); - break; - } - case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { - mxfp8::dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - break; - } - case NVTE_NVFP4_1D_SCALING: { - nvfp4::dequantize(input, output, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } + // CheckInputTensor(input, "cast_input"); + // CheckOutputTensor(*output, "cast_output"); + + // switch (input.scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); + // NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); + // NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); + // fp8::dequantize(input, output, stream); + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // if (is_supported_by_CC_100()) { + // mxfp8::dequantize(input, output, stream); + // } else { + // NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + // } + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // nvfp4::dequantize(input, output, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + // } } } // namespace dispatch diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 06e8f0e306..c2087533a6 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -25,164 +25,164 @@ namespace dispatch { template void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - const Tensor input = *convertNVTETensorCheck(nvte_input); - Tensor *output = convertNVTETensorCheck(nvte_output); - - CheckInputTensor(input, "input"); - CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim() / 2; - - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == cols, - "Wrong output shape. 
Expected (after flattening) [*, ", cols, "], got [", - output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - if (use_tma_kernels) { - Tensor dummy_grad_tensor; - fp8::cast_gated_tma(input, dummy_grad_tensor, - output, p, stream); - } else { - fp8::cast_gated_fwd(input, output, p, stream); - } - if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // FP8 kernel only populates row-wise data, so perform - // transpose separately if needed - Tensor transpose_in, transpose_out, dummy; - transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_in.data.dptr = output->data.dptr; - transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - transpose_in.data.dtype = output->data.dtype; - transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_out.data.dptr = output->columnwise_data.dptr; - transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - transpose_out.data.dtype = output->data.dtype; - detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - NVTE_CHECK(cols % 32 == 0, - "Invalid input shape. Expected the last dimension to be " - "divisible by 32, but got ", - cols, "."); - if (output->has_data()) { - NVTE_CHECK(is_fp8_dtype(output->data.dtype), - "The type of the output tensor should be FP8."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - "The type of the columnwise output tensor should be FP8."); - } - NVTE_CHECK(is_supported_by_CC_100(), - "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - Tensor dummy_grad_tensor; - mxfp8::quantize_gated(input, dummy_grad_tensor, - output, p, stream); - break; - } - default: - NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - } + // const Tensor input = *convertNVTETensorCheck(nvte_input); + // Tensor *output = convertNVTETensorCheck(nvte_output); + + // CheckInputTensor(input, "input"); + // CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + // const size_t rows = input.flat_first_dim(); + // const size_t cols = input.flat_last_dim() / 2; + + // NVTE_CHECK(input.flat_last_dim() % 2 == 0, + // "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + // input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + // NVTE_CHECK(output->flat_last_dim() == cols, + // "Wrong output shape. 
Expected (after flattening) [*, ", cols, "], got [", + // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // switch (output->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // if (use_tma_kernels) { + // Tensor dummy_grad_tensor; + // fp8::cast_gated_tma(input, dummy_grad_tensor, + // output, p, stream); + // } else { + // fp8::cast_gated_fwd(input, output, p, stream); + // } + // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // // FP8 kernel only populates row-wise data, so perform + // // transpose separately if needed + // Tensor transpose_in, transpose_out, dummy; + // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_in.data.dptr = output->data.dptr; + // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + // transpose_in.data.dtype = output->data.dtype; + // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_out.data.dptr = output->columnwise_data.dptr; + // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + // transpose_out.data.dtype = output->data.dtype; + // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // NVTE_CHECK(cols % 32 == 0, + // "Invalid input shape. Expected the last dimension to be " + // "divisible by 32, but got ", + // cols, "."); + // if (output->has_data()) { + // NVTE_CHECK(is_fp8_dtype(output->data.dtype), + // "The type of the output tensor should be FP8."); + // } + // if (output->has_columnwise_data()) { + // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + // "The type of the columnwise output tensor should be FP8."); + // } + // NVTE_CHECK(is_supported_by_CC_100(), + // "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + // Tensor dummy_grad_tensor; + // mxfp8::quantize_gated(input, dummy_grad_tensor, + // output, p, stream); + // break; + // } + // default: + // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + // } } template void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); - const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); - Tensor *output = convertNVTETensorCheck(nvte_output); - - CheckInputTensor(grad, "grad"); - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", - gated_input.flat_last_dim(), "."); - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - - NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); - NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); - - NVTE_CHECK(grad.flat_first_dim() == rows, - "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", - grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - NVTE_CHECK(grad.flat_last_dim() == cols, - "Wrong Grad shape. 
Expected last dimension (after flattening) [", cols, ", *], got [", - grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", - rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == cols * 2, - "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", - output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(gated_input.shape() == output->shape(), - "Gated input and output shapes must match. Input shape: ", gated_input.shape(), - ", output shape: ", output->shape(), "."); - - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - if (use_tma_kernels) { - fp8::cast_gated_tma(gated_input, grad, output, p, - stream); - } else { - fp8::cast_gated_bwd(gated_input, grad, output, p, stream); - } - if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // FP8 kernel only populates row-wise data, so perform - // transpose separately if needed - Tensor transpose_in, transpose_out, dummy; - transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_in.data.dptr = output->data.dptr; - transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - transpose_in.data.dtype = output->data.dtype; - transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_out.data.dptr = output->columnwise_data.dptr; - transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - transpose_out.data.dtype = output->data.dtype; - detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - NVTE_CHECK(cols % 32 == 0, - "Invalid input shape. Expected the last dimension to be " - "divisible by 32, but got ", - cols, "."); - if (output->has_data()) { - NVTE_CHECK(is_fp8_dtype(output->data.dtype), - "The type of the output tensor should be FP8."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - "The type of the columnwise output tensor should be FP8."); - } - NVTE_CHECK(is_supported_by_CC_100(), - "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - - mxfp8::quantize_gated(gated_input, grad, output, p, - stream); - break; - } - default: - NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - } + // const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + // const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + // Tensor *output = convertNVTETensorCheck(nvte_output); + + // CheckInputTensor(grad, "grad"); + // CheckInputTensor(gated_input, "gated_input"); + // CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + // NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", + // gated_input.flat_last_dim(), "."); + + // const size_t rows = gated_input.flat_first_dim(); + // const size_t cols = gated_input.flat_last_dim() / 2; + + // NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); + // NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); + + // NVTE_CHECK(grad.flat_first_dim() == rows, + // "Wrong Grad shape. 
Expected first dimension (after flattening) [", rows, ", *], got [", + // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + // NVTE_CHECK(grad.flat_last_dim() == cols, + // "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", + // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + + // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", + // rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + // NVTE_CHECK(output->flat_last_dim() == cols * 2, + // "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", + // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + // NVTE_CHECK(gated_input.shape() == output->shape(), + // "Gated input and output shapes must match. Input shape: ", gated_input.shape(), + // ", output shape: ", output->shape(), "."); + + // switch (output->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // if (use_tma_kernels) { + // fp8::cast_gated_tma(gated_input, grad, output, p, + // stream); + // } else { + // fp8::cast_gated_bwd(gated_input, grad, output, p, stream); + // } + // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // // FP8 kernel only populates row-wise data, so perform + // // transpose separately if needed + // Tensor transpose_in, transpose_out, dummy; + // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_in.data.dptr = output->data.dptr; + // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + // transpose_in.data.dtype = output->data.dtype; + // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_out.data.dptr = output->columnwise_data.dptr; + // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + // transpose_out.data.dtype = output->data.dtype; + // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // NVTE_CHECK(cols % 32 == 0, + // "Invalid input shape. 
Expected the last dimension to be " + // "divisible by 32, but got ", + // cols, "."); + // if (output->has_data()) { + // NVTE_CHECK(is_fp8_dtype(output->data.dtype), + // "The type of the output tensor should be FP8."); + // } + // if (output->has_columnwise_data()) { + // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + // "The type of the columnwise output tensor should be FP8."); + // } + // NVTE_CHECK(is_supported_by_CC_100(), + // "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + + // mxfp8::quantize_gated(gated_input, grad, output, p, + // stream); + // break; + // } + // default: + // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + // } } } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index f7823b4c58..0aadffa940 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -30,282 +30,282 @@ namespace dispatch { template void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *input_tensor = convertNVTETensorCheck(input); - Tensor *output_tensor = convertNVTETensorCheck(output); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const Tensor *dummy_input_tensor = nullptr; - Tensor *dummy_dbias_tensor = nullptr; - Tensor *dummy_workspace_tensor = nullptr; - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_ACT) { - cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - } - } else if (output_tensor->has_data()) { - fp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - const Tensor *dummy_input_tensor = nullptr; - Tensor *dummy_dbias_tensor = nullptr; - Tensor *dummy_workspace_tensor = nullptr; - mxfp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - 
CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4::quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4::quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - quantize_transpose_vector_blockwise_fp4( - /*input=*/input_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. 
- NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *input_tensor = convertNVTETensorCheck(input); + // Tensor *output_tensor = convertNVTETensorCheck(output); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // // Dispatch to quantization kernel depending on data format + // switch (output_tensor->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // const Tensor *dummy_input_tensor = nullptr; + // Tensor *dummy_dbias_tensor = nullptr; + // Tensor *dummy_workspace_tensor = nullptr; + // if (output_tensor->has_columnwise_data()) { + // NVTE_CHECK(output_tensor->has_data(), + // "Quantizing in only the columnwise direction not supported yet!"); + // if constexpr (!IS_ACT) { + // cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + // } else { + // cast_transpose_fused( + // *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // } + // } else if (output_tensor->has_data()) { + // fp8::quantize( + // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // const Tensor *dummy_input_tensor = nullptr; + // Tensor *dummy_dbias_tensor = nullptr; + // Tensor *dummy_workspace_tensor = nullptr; + // mxfp8::quantize( + // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*input_tensor, "input"); + // CheckOutputTensor(*output_tensor, "output", false); + + // // Choose kernel + // int32_t rows = 
input_tensor->flat_first_dim(); + // int32_t cols = input_tensor->flat_last_dim(); + // auto dtype = input_tensor->dtype(); + // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + // (cols % 32 == 0) && output_tensor->has_data(); + + // // Launch NVFP4 quantize kernel + // if (use_optimized_kernel) { + // if (quant_config_cpp.nvfp4_2d_quantization) { + // nvfp4::quantize_transpose( + // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } else { + // nvfp4::quantize_transpose( + // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } + // } else { + // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + // : output_tensor->columnwise_amax; + // quantize_transpose_vector_blockwise_fp4( + // /*input=*/input_tensor->data, /*global_amax=*/global_amax, + // /*scale_inv=*/output_tensor->scale_inv, + // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + // /*swizzled_scale=*/false, + // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + // /*rng_state=*/quant_config_cpp.rng_state, + // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + // } + // break; + // } + // case NVTE_BLOCK_SCALING_2D: { + // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // quantize_transpose_square_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, + // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + // /*noop_tensor=*/noop_tensor->data, stream); + // break; + // } + // case NVTE_BLOCK_SCALING_1D: { + // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. 
+ // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // if (output_tensor->has_data()) { + // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // } + // if (output_tensor->has_columnwise_data()) { + // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // } + // quantize_transpose_vector_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + // } } template void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *grad_tensor = convertNVTETensorCheck(grad); - const Tensor *input_tensor = convertNVTETensor(input); - - Tensor *output_tensor = convertNVTETensorCheck(output); - Tensor *dbias_tensor = convertNVTETensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT) { - cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); - } - } else if (output_tensor->has_data()) { - fp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*grad_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = grad_tensor->flat_first_dim(); 
- int32_t cols = grad_tensor->flat_last_dim(); - auto dtype = grad_tensor->dtype(); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4::quantize_transpose( - *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4::quantize_transpose( - *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - quantize_transpose_vector_blockwise_fp4( - /*input=*/grad_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. 
- NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *grad_tensor = convertNVTETensorCheck(grad); + // const Tensor *input_tensor = convertNVTETensor(input); + + // Tensor *output_tensor = convertNVTETensorCheck(output); + // Tensor *dbias_tensor = convertNVTETensor(dbias); + // Tensor *workspace_tensor = convertNVTETensor(workspace); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // // Dispatch to quantization kernel depending on data format + // switch (output_tensor->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // if (output_tensor->has_columnwise_data()) { + // NVTE_CHECK(output_tensor->has_data(), + // "Quantizing in only the columnwise direction not supported yet!"); + // if constexpr (!IS_DBIAS && !IS_DACT) { + // cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + // } else { + // cast_transpose_fused( + // *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + // } + // } else if (output_tensor->has_data()) { + // fp8::quantize( + // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8::quantize( + // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*grad_tensor, "input"); + // CheckOutputTensor(*output_tensor, "output", false); + + // // Choose kernel + // int32_t rows = grad_tensor->flat_first_dim(); + // int32_t cols = grad_tensor->flat_last_dim(); + // auto 
dtype = grad_tensor->dtype(); + // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + // (cols % 32 == 0) && output_tensor->has_data(); + + // // Launch NVFP4 quantize kernel + // if (use_optimized_kernel) { + // if (quant_config_cpp.nvfp4_2d_quantization) { + // nvfp4::quantize_transpose( + // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } else { + // nvfp4::quantize_transpose( + // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } + // } else { + // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + // : output_tensor->columnwise_amax; + // quantize_transpose_vector_blockwise_fp4( + // /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + // /*scale_inv=*/output_tensor->scale_inv, + // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + // /*swizzled_scale=*/false, + // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + // /*rng_state=*/quant_config_cpp.rng_state, + // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + // } + // break; + // } + // case NVTE_BLOCK_SCALING_2D: { + // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // quantize_transpose_square_blockwise( + // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, + // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + // /*noop_tensor=*/noop_tensor->data, stream); + // break; + // } + // case NVTE_BLOCK_SCALING_1D: { + // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // if (output_tensor->has_data()) { + // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // } + // if (output_tensor->has_columnwise_data()) { + // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // } + // quantize_transpose_vector_blockwise( + // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + // } } // Host-aware and not graph-safe: group quantization with split section info from the host. 
@@ -314,64 +314,64 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *input_tensor = convertNVTETensorCheck(input); - std::vector output_tensors; - for (size_t i = 0; i < num_tensors; ++i) { - output_tensors.push_back(convertNVTETensorCheck(outputs[i])); - } - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - // Take the scaling mode of the first output tensor - auto scaling_mode = output_tensors[0]->scaling_mode; - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - // Skip checking output tensor list - // output list here is allowed to have empty tensor - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - - NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "2D quantization is not supported for group quantize."); - - // Launch NVFP4 group quantize kernel - nvfp4::group_quantize_transpose( - *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, - &quant_config_cpp, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *input_tensor = convertNVTETensorCheck(input); + // std::vector output_tensors; + // for (size_t i = 0; i < num_tensors; ++i) { + // output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + // } + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // // Take the scaling mode of the first output tensor + // auto scaling_mode = output_tensors[0]->scaling_mode; + + // // Dispatch to quantization kernel depending on data format + // switch (scaling_mode) { + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*input_tensor, "input"); + // // Skip checking output tensor list + // // output list here is allowed to have empty tensor + + 
// // Choose kernel + // int32_t rows = input_tensor->flat_first_dim(); + // int32_t cols = input_tensor->flat_last_dim(); + // auto dtype = input_tensor->dtype(); + + // NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + // "2D quantization is not supported for group quantize."); + + // // Launch NVFP4 group quantize kernel + // nvfp4::group_quantize_transpose( + // *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + // &quant_config_cpp, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + // } } template @@ -407,7 +407,10 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - mxfp8::group_quantize( + // mxfp8::group_quantize( + // IS_ACT is set to false + // OP is set to nullptr + mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; @@ -422,40 +425,40 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); - - const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); - const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); - GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: { - mxfp8::group_quantize( - grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } + // using namespace detail; + + // NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + // const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + // const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + // GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + // GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + // Tensor *workspace_tensor = convertNVTETensor(workspace); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Dispatch to quantization kernel depending on data format + // switch (scaling_mode) { + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8::group_quantize( + // grad_tensor, input_tensor, noop_tensor, output_tensor, 
dbias_tensor, workspace_tensor, + // stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + // } } } // namespace dispatch From ab816cbbed6b8e3b9855e7365e46b6b34ec6f16a Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 17:41:17 +0000 Subject: [PATCH 15/31] Fixed out-of-boundary error Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 791fe65098..2cbcfe8218 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -51,7 +51,7 @@ struct TunableConfig { static constexpr PersistentStrategy PERSISTENT_STRATEGY = PersistentStrategy::STATIC_GRID_STRIDE; // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). - static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 1; + static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4; }; constexpr bool DYNAMIC_PERSISTENT = @@ -595,6 +595,44 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } + // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. + // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. + if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { + const size_t next_block_ID = + static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); + const size_t next_block_global_offset = + is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (next_block_ID * ELTS_PER_CHUNK); + const size_t next_tensor_id = + get_current_tensor_id(shape_rep, num_tensors, next_block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + const size_t next_rows = + get_tensor_rows_num(next_tensor_id, shape_rep, first_logical_dim, first_dims_ptr, + num_tensors); + const size_t next_cols = + get_tensor_cols_num(next_tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + bool next_job_is_valid = + (next_block_ID < total_work_blocks) && (next_rows != 0) && (next_cols != 0); + if (shape_rep != SAME_BOTH_DIMS) { + const size_t tensor_start_offset = static_cast(offsets_ptr[next_tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[next_tensor_id + 1]); + if (next_block_global_offset >= tensor_end_offset) { + next_job_is_valid = false; + } + if (next_job_is_valid) { + const size_t tensor_offset_from_start = next_block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / next_cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % next_cols; + if (block_offset_Y_in_tensor >= next_rows || block_offset_X_in_tensor >= next_cols) { + next_job_is_valid = false; + } + } + } + allow_next_job_prefetch = next_job_is_valid; + } + if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; const size_t next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; From 8a429ad8f3ae42fcb35ee452b0530ae4ecc70cb0 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 17:59:35 +0000 Subject: [PATCH 16/31] Tuned kernel parameters 
Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 2cbcfe8218..26baa86ae4 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -48,16 +48,13 @@ struct TunableConfig { static constexpr size_t CHUNK_DIM_X = 128; static constexpr size_t THREADS_PER_CHUNK = 128; static constexpr size_t PREFETCH_STAGES = 1; - static constexpr PersistentStrategy PERSISTENT_STRATEGY = - PersistentStrategy::STATIC_GRID_STRIDE; + static constexpr PersistentStrategy PERSISTENT_STRATEGY = PersistentStrategy::STATIC_GRID_STRIDE; // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4; }; -constexpr bool DYNAMIC_PERSISTENT = - TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::DYNAMIC_WORK_STEALING; -constexpr bool STATIC_PERSISTENT = - TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::STATIC_GRID_STRIDE; +constexpr bool DYNAMIC_PERSISTENT = TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::DYNAMIC_WORK_STEALING; +constexpr bool STATIC_PERSISTENT = TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::STATIC_GRID_STRIDE; static_assert(TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0, "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero."); From ab3f911a73124ef108456e2f29d810a8613c1777 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 18:12:36 +0000 Subject: [PATCH 17/31] Refactoring Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 261 +++++++++--------- 1 file changed, 123 insertions(+), 138 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 26baa86ae4..47f7131e76 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -152,6 +152,84 @@ __device__ __forceinline__ size_t get_tensor_cols_num( return cols_num; } +// Logical work-item decoded from CTA coordinates. +struct JobDescriptor { + size_t block_id = 0; + size_t block_global_offset = 0; + size_t tensor_id = 0; + size_t rows = 0; + size_t cols = 0; +}; + +// Tensor-local coordinates for a work-item. +struct BlockDescriptor { + size_t tensor_base = 0; + size_t block_id_in_current_tensor = 0; + size_t block_id_Y = 0; + size_t block_id_X = 0; + size_t block_offset_Y = 0; + size_t block_offset_X = 0; +}; + +__device__ __forceinline__ JobDescriptor decode_job( + const ShapeRepresentation shape_rep, const bool is_single_tensor, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X, + const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { + JobDescriptor job{}; + job.block_id = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); + job.block_global_offset = + is_single_tensor ? 
(static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (job.block_id * ELTS_PER_CHUNK); + job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + job.rows = + get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + return job; +} + +__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, + const ShapeRepresentation shape_rep, + const size_t total_work_blocks, + const int64_t *const __restrict__ offsets_ptr) { + bool is_valid = (job.block_id < total_work_blocks) && (job.rows != 0) && (job.cols != 0); + if (!is_valid || shape_rep == SAME_BOTH_DIMS) { + return is_valid; + } + + const size_t tensor_start_offset = static_cast(offsets_ptr[job.tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[job.tensor_id + 1]); + if (job.block_global_offset >= tensor_end_offset) { + return false; + } + + const size_t tensor_offset_from_start = job.block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / job.cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % job.cols; + if (block_offset_Y_in_tensor >= job.rows || block_offset_X_in_tensor >= job.cols) { + return false; + } + + return true; +} + +__device__ __forceinline__ BlockDescriptor decode_block(const JobDescriptor &job, + const bool is_single_tensor, + const int64_t *const __restrict__ offsets_ptr) { + BlockDescriptor block{}; + block.tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); + const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(128)); + block.block_id_in_current_tensor = + is_single_tensor ? job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK); + block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor; + block.block_id_X = block.block_id_in_current_tensor % blocks_X_num_in_current_tensor; + block.block_offset_Y = block.block_id_Y * CHUNK_DIM_Y; + block.block_offset_X = block.block_id_X * CHUNK_DIM_X; + return block; +} + // Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, CUtensorMap *global_tensor_map, @@ -335,6 +413,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + // Initialize barriers shared by the entire CTA: + // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. + // - workID_mbar synchronizes WorkID query response in dynamic persistent mode. if (leading_thread) { #pragma unroll for (int buff = 0; buff < BUFFS_NUM; ++buff) { @@ -355,6 +436,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel int32_t ctaid_Y = static_cast(blockIdx.y); [[maybe_unused]] size_t static_next_block_id = 0; [[maybe_unused]] size_t static_block_stride = 0; + // In STATIC_PERSISTENT mode physical CTAs iterate over a virtual work grid via grid-stride. 
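  // (Illustrative sketch, not part of the original patch: the grid-stride walk referred to
  //  above is the conventional one, i.e. each physical CTA visits virtual work blocks as
  //
  //    for (size_t id = launch_block_id; id < total_work_blocks;
  //         id += gridDim.x * gridDim.y) {
  //      ctaid_X = static_cast<int32_t>(id % work_blocks_X);
  //      ctaid_Y = static_cast<int32_t>(id / work_blocks_X);
  //      // ... process one CHUNK_DIM_Y x CHUNK_DIM_X chunk ...
  //    }
  //
  //  The kernel below unrolls this loop into the explicit static_next_block_id /
  //  static_block_stride bookkeeping so the next chunk can be prefetched while the
  //  current one is still being processed.)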
if constexpr (STATIC_PERSISTENT) { if (launch_block_id >= total_work_blocks) { return; @@ -368,50 +450,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel int buff_in = 0; bool has_prefetched_current_job = true; - // Prefetch the first stage of the first job. + // Prime the pipeline with stage-0 of the first job assigned to this CTA. { - const size_t block_ID = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - const size_t block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (block_ID * ELTS_PER_CHUNK); - - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - if (block_ID >= total_work_blocks || rows == 0 || cols == 0) { + const JobDescriptor first_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + if (!is_job_valid(first_job, shape_rep, total_work_blocks, offsets_ptr)) { return; } - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); - if (block_global_offset >= tensor_end_offset) { - return; - } - const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; - if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { - return; - } - } - - const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); - const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); - const size_t block_id_in_current_tensor = - is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); - const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; - const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; - const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; - const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + const BlockDescriptor first_block = decode_block(first_job, is_single_tensor, offsets_ptr); const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[first_job.tensor_id]; const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + is_single_tensor ? 
tensor_map_act_input_static : g_tensor_maps_act_input[first_job.tensor_id]; if (leading_thread && (!is_single_tensor)) { fence_acquire_tensormap(&tensor_map_input); @@ -424,8 +476,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { const size_t buff = stage; const size_t stage_offset_Y = stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; + const size_t global_offset_Y = first_block.block_offset_Y + stage_offset_Y; + const size_t global_offset_X = first_block.block_offset_X; const size_t buff_offset = buff * BUFF_DIM; uint64_t *barrier = &IN_buff_readable_mbar[buff]; if (leading_thread) { @@ -444,36 +496,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } + // Main persistent loop: decode current job, run all 32-row stages, schedule/prefetch next job. while (!job_finished) { - const size_t block_ID = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - const size_t block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (block_ID * ELTS_PER_CHUNK); - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - - bool current_job_is_valid = (block_ID < total_work_blocks) && (rows != 0) && (cols != 0); - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); - if (block_global_offset >= tensor_end_offset) { - current_job_is_valid = false; - } - if (current_job_is_valid) { - const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; - if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { - current_job_is_valid = false; - } - } - } + // Decode CTA assignment into logical tensor coordinates and validate bounds. + const JobDescriptor current_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const bool current_job_is_valid = + is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr); if (!current_job_is_valid) { if (has_prefetched_current_job) { // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. @@ -485,21 +515,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel break; } + const size_t tensor_id = current_job.tensor_id; + const size_t rows = current_job.rows; + const size_t cols = current_job.cols; + const BlockDescriptor current_block = decode_block(current_job, is_single_tensor, offsets_ptr); + const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); - const size_t tensor_base = is_single_tensor ? 
0 : static_cast(offsets_ptr[tensor_id]); + const size_t tensor_base = current_block.tensor_base; const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) ? static_cast(offsets_ptr[tensor_id]) : tensor_base; - const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); - const size_t block_id_in_current_tensor = - is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); - const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; - const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; - - const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; - const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + const size_t block_id_Y = current_block.block_id_Y; + const size_t block_id_X = current_block.block_id_X; + const size_t block_offset_Y = current_block.block_offset_Y; + const size_t block_offset_X = current_block.block_offset_X; e8m0_t *const scales_rowwise = scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); @@ -560,6 +591,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } bool prefetched_next_job = false; + // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t stage_offset_Y = stage * BUFF_DIM_Y; @@ -595,39 +627,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { - const size_t next_block_ID = - static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - const size_t next_block_global_offset = - is_single_tensor ? 
(static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (next_block_ID * ELTS_PER_CHUNK); - const size_t next_tensor_id = - get_current_tensor_id(shape_rep, num_tensors, next_block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - const size_t next_rows = - get_tensor_rows_num(next_tensor_id, shape_rep, first_logical_dim, first_dims_ptr, - num_tensors); - const size_t next_cols = - get_tensor_cols_num(next_tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - - bool next_job_is_valid = - (next_block_ID < total_work_blocks) && (next_rows != 0) && (next_cols != 0); - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[next_tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[next_tensor_id + 1]); - if (next_block_global_offset >= tensor_end_offset) { - next_job_is_valid = false; - } - if (next_job_is_valid) { - const size_t tensor_offset_from_start = next_block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / next_cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % next_cols; - if (block_offset_Y_in_tensor >= next_rows || block_offset_X_in_tensor >= next_cols) { - next_job_is_valid = false; - } - } - } - allow_next_job_prefetch = next_job_is_valid; + const JobDescriptor next_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + allow_next_job_prefetch = is_job_valid(next_job, shape_rep, total_work_blocks, offsets_ptr); } if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { @@ -638,37 +641,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel prefetched_next_job = true; } - const size_t prefetch_block_ID = - static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - const size_t prefetch_block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (prefetch_block_ID * ELTS_PER_CHUNK); - const size_t prefetch_tensor_id = - get_current_tensor_id(shape_rep, num_tensors, prefetch_block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - const size_t prefetch_tensor_base = - is_single_tensor ? 0 : static_cast(offsets_ptr[prefetch_tensor_id]); - const size_t prefetch_cols = - get_tensor_cols_num(prefetch_tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - const size_t prefetch_blocks_X_num_in_current_tensor = - DIVUP(prefetch_cols, static_cast(128)); - const size_t prefetch_block_id_in_current_tensor = - is_single_tensor ? 
prefetch_block_ID - : (prefetch_block_ID - prefetch_tensor_base / ELTS_PER_CHUNK); - const size_t prefetch_block_id_Y = - prefetch_block_id_in_current_tensor / prefetch_blocks_X_num_in_current_tensor; - const size_t prefetch_block_id_X = - prefetch_block_id_in_current_tensor % prefetch_blocks_X_num_in_current_tensor; - - const size_t global_offset_Y = prefetch_block_id_Y * CHUNK_DIM_Y + next_prefetch_stage_offset_Y; - const size_t global_offset_X = prefetch_block_id_X * CHUNK_DIM_X; + const JobDescriptor prefetch_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const BlockDescriptor prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); + + const size_t global_offset_Y = prefetch_block.block_offset_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = prefetch_block.block_offset_X; const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; const CUtensorMap &prefetch_tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[prefetch_tensor_id]; + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[prefetch_job.tensor_id]; const CUtensorMap &prefetch_tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[prefetch_tensor_id]; + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[prefetch_job.tensor_id]; uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; if (leading_thread) { From 92720ac69c475b737dd6d429ae5235d331a9f0a9 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 18:40:59 +0000 Subject: [PATCH 18/31] Refactoring 2 Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 121 +++++++++++------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 47f7131e76..fde1bf02c6 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -329,6 +329,51 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +// Issue TMA global->shared transfer for one stage of input (and optional activation input). +template +__device__ __forceinline__ void prefetch_input_stage( + IType *in_sh, IType *act_in_sh, const CUtensorMap &tensor_map_input, + const CUtensorMap &tensor_map_act_input, const size_t global_offset_X, const size_t global_offset_Y, + const size_t buff_offset, const size_t shmem_buff_size, uint64_t *barrier, const bool leading_thread) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buff_offset]), + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + barrier); + if constexpr (IS_DACT) { + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&act_in_sh[buff_offset]), + reinterpret_cast(&tensor_map_act_input), global_offset_X, global_offset_Y, + barrier); + } + } +} + +// Issue TMA shared->global transfer for one stage of outputs. 
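//  (Illustrative usage, not from the original patch; the exact template-argument order is
//   an assumption.) In the stage loop further down this helper is invoked roughly as
//
//     store_output_stage<OType, ROWWISE_SCALING, COLWISE_SCALING>(
//         out_rowwise_data_sh, out_colwise_data_sh,
//         tensor_map_output_rowwise, tensor_map_output_colwise,
//         block_offset_X, block_offset_Y + stage_offset_Y, buff * BUFF_DIM, leading_thread);
//
//   so that only the leading thread issues the shared->global bulk copies and commits the
//   TMA group.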
+template +__device__ __forceinline__ void store_output_stage( + OType *out_rowwise_data_sh, OType *out_colwise_data_sh, + const CUtensorMap &tensor_map_output_rowwise, const CUtensorMap &tensor_map_output_colwise, + const int global_offset_X, const int global_offset_Y, const int buff_offset, + const bool leading_thread) { + if (!leading_thread) { + return; + } + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, global_offset_Y, + reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, global_offset_Y, + reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + ptx::cp_async_bulk_commit_group(); +} + template @@ -480,19 +525,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t global_offset_X = first_block.block_offset_X; const size_t buff_offset = buff * BUFF_DIM; uint64_t *barrier = &IN_buff_readable_mbar[buff]; - if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_sh[buff_offset]), - reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, - barrier); - if constexpr (IS_DACT) { - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&act_in_sh[buff_offset]), - reinterpret_cast(&tensor_map_act_input), global_offset_X, - global_offset_Y, barrier); - } - } + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, + global_offset_X, global_offset_Y, buff_offset, + shmem_buff_size, barrier, leading_thread); } } @@ -596,6 +631,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel for (int stage = 0; stage < STAGES; ++stage) { const size_t stage_offset_Y = stage * BUFF_DIM_Y; bool allow_next_job_prefetch = true; + JobDescriptor prefetch_job = current_job; + BlockDescriptor prefetch_block = current_block; if (stage == STAGES - PREFETCH_STAGES) { if constexpr (DYNAMIC_PERSISTENT) { @@ -627,10 +664,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. 
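  // (Illustrative note, not from the original patch: with PREFETCH_STAGES = 1 the last
  //  stage of chunk N overlaps the stage-0 load of chunk N + 1, i.e.
  //
  //    stage 0 .. STAGES-2 : compute stage s, prefetch stage s + 1 of the same chunk
  //    stage STAGES-1      : compute the last stage, prefetch stage 0 of the *next* chunk
  //
  //  When the CTA has no further chunk, the next-job coordinates would point past the end
  //  of the tensors, so the validity check below gates the TMA issue.)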
if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { - const JobDescriptor next_job = + prefetch_job = decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - allow_next_job_prefetch = is_job_valid(next_job, shape_rep, total_work_blocks, offsets_ptr); + allow_next_job_prefetch = is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); + if (allow_next_job_prefetch) { + prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); + } } if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { @@ -641,11 +681,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel prefetched_next_job = true; } - const JobDescriptor prefetch_job = - decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, - work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - const BlockDescriptor prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); - const size_t global_offset_Y = prefetch_block.block_offset_Y + next_prefetch_stage_offset_Y; const size_t global_offset_X = prefetch_block.block_offset_X; const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; @@ -663,18 +698,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel fence_acquire_tensormap(&prefetch_tensor_map_act_input); } } - ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_sh[next_prefetch_buff_offset]), - reinterpret_cast(&prefetch_tensor_map_input), global_offset_X, - global_offset_Y, barrier); - if constexpr (IS_DACT) { - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&act_in_sh[next_prefetch_buff_offset]), - reinterpret_cast(&prefetch_tensor_map_act_input), global_offset_X, - global_offset_Y, barrier); - } } + prefetch_input_stage( + in_sh, act_in_sh, prefetch_tensor_map_input, prefetch_tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); ptx::fence_proxy_async_shared_cta(); } @@ -686,6 +713,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t buff = buff_in; float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { + // Column-wise path: + // 1) load/compute values for one [32x1] stripe per thread + // 2) compute/write E8M0 scale + // 3) scale and write FP8 values into shared output buffer const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; thread_amax = 0.0f; float in_compute_colwise[BUFF_DIM_Y]; @@ -766,6 +797,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } if constexpr (ROWWISE_SCALING) { + // Row-wise path: + // 1) load/compute values for one [1x32] stripe per thread + // 2) compute/write E8M0 scale + // 3) scale and write FP8 values into shared output buffer const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; thread_amax = 0.0f; @@ -904,23 +939,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ptx::fence_proxy_async_shared_cta(); __syncthreads(); - if (leading_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; - - if 
constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); - } - ptx::cp_async_bulk_commit_group(); - } + // Publish the stage from shared memory into global outputs via TMA. + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + store_output_stage( + out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, tensor_map_output_colwise, + global_offset_X, global_offset_Y, buff_offset, leading_thread); buff_in = (buff_in + 1) % BUFFS_NUM; } From 46d9811122e1188294b5b50d9806d1b4793d52a4 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 19:17:28 +0000 Subject: [PATCH 19/31] Refactoring 3 Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 497 ++++++++++-------- 1 file changed, 276 insertions(+), 221 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index fde1bf02c6..67537e8dfd 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -374,6 +374,271 @@ __device__ __forceinline__ void store_output_stage( ptx::cp_async_bulk_commit_group(); } +template +__device__ __forceinline__ float process_colwise_stage( + const size_t buff, const int stage, const size_t tid_X_colwise, + const size_t scales_offset_Y_colwise, const size_t scales_offset_X_colwise, + const size_t scale_stride_colwise, const size_t tensor_base_for_scales, const size_t rows, + const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, + OType *out_colwise_data_sh, e8m0_t *scales_colwise, float &partial_dbias_colwise) { + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; + if constexpr (!IS_CACHED_ACT_OP) { + (void)cached_act_sh; + } + if constexpr (!IS_DACT) { + (void)act_in_sh; + } + if constexpr (!IS_DBIAS) { + (void)partial_dbias_colwise; + } + if constexpr (!WITH_GEMM_SWIZZLED_SCALES) { + (void)tensor_base_for_scales; + (void)rows; + } + + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + float thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = 
static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_colwise[i] = elt; + } + } + + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t tensor_base_row = tensor_base_for_scales / cols; + const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; + const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; + scale_idx = tensor_scales_offset_colwise_base + + transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + global_scales_offset_X, local_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + + return thread_amax; +} + +template +__device__ __forceinline__ float process_rowwise_stage( + const size_t buff, const size_t stage_offset_Y, const size_t thread_offset_Y_rowwise, + const size_t thread_offset_X_rowwise, const int bank_group, + const size_t scales_offset_Y_rowwise, const size_t scales_offset_X_rowwise, + const size_t scale_stride_rowwise, const bool rowwise_scale_is_within_bounds, + const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, + OType *out_rowwise_data_sh, e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; + if constexpr (!IS_DACT) { + (void)act_in_sh; + } + if constexpr (!IS_CACHED_ACT_OP) { + (void)cached_act_sh; + } + if constexpr (!(IS_DBIAS && (!COLWISE_SCALING))) { + (void)thread_dbias_rowwise; + } + if constexpr (!WITH_GEMM_SWIZZLED_SCALES) { + (void)cols; + } + + const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + float thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + Vec in_IType[WAVES]; + + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + 
swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; + } + } + } + + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + stage_scales_offset_Y, stage_scales_offset_X, DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = 
in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + + return thread_amax; +} + template @@ -393,19 +658,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx; - if constexpr (NO_ACTIVATIONS) { if (noop != nullptr && noop[0] == 1.0f) { return; } } - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); const bool leading_thread = (threadIdx.x == 0); @@ -713,223 +971,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t buff = buff_in; float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { - // Column-wise path: - // 1) load/compute values for one [32x1] stripe per thread - // 2) compute/write E8M0 scale - // 3) scale and write FP8 values into shared output buffer - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_colwise[i] = elt; - } - } - - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - const size_t tensor_base_row = tensor_base_for_scales / cols; - const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; - const size_t 
tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; - const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; - scale_idx = tensor_scales_offset_colwise_base + - gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, - DIVUP(rows, static_cast(128))); - } else { - scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - } - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } + thread_amax = process_colwise_stage( + buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, cached_act_sh, + out_colwise_data_sh, scales_colwise, partial_dbias_colwise); } if constexpr (ROWWISE_SCALING) { - // Row-wise path: - // 1) load/compute values for one [1x32] stripe per thread - // 2) compute/write E8M0 scale - // 3) scale and write FP8 values into shared output buffer - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - Vec in_IType[WAVES]; - - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int 
w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_rowwise[j] = elt; - } - } - } - - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, - DIVUP(cols, static_cast(128))); - } else { - scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - } - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - } + thread_amax = process_rowwise_stage( + buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, + scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, + rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, out_rowwise_data_sh, + scales_rowwise, thread_dbias_rowwise); } __builtin_assume(block_amax >= 0); From 71724007c5e9a200b2266ee687a59b269c119b67 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 5 Mar 2026 13:38:12 +0000 Subject: [PATCH 20/31] Removed the dynamic (WorkID Query) persistency Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 92 ++++--------------- 1 file changed, 20 insertions(+), 72 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 
67537e8dfd..7a510c1295 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -37,26 +37,21 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; -enum class PersistentStrategy : int { - NONE = 0, - DYNAMIC_WORK_STEALING = 1, - STATIC_GRID_STRIDE = 2, -}; - struct TunableConfig { static constexpr size_t CHUNK_DIM_Y = 128; static constexpr size_t CHUNK_DIM_X = 128; static constexpr size_t THREADS_PER_CHUNK = 128; static constexpr size_t PREFETCH_STAGES = 1; - static constexpr PersistentStrategy PERSISTENT_STRATEGY = PersistentStrategy::STATIC_GRID_STRIDE; + // true -> static persistent grid-stride scheduler + // false -> non-persistent one-job-per-CTA execution + static constexpr bool PERSISTENT = true; // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4; }; -constexpr bool DYNAMIC_PERSISTENT = TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::DYNAMIC_WORK_STEALING; -constexpr bool STATIC_PERSISTENT = TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::STATIC_GRID_STRIDE; -static_assert(TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0, - "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero."); +constexpr bool PERSISTENT = TunableConfig::PERSISTENT; +static_assert(!PERSISTENT || (TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0), + "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero in persistent mode."); constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; @@ -177,11 +172,10 @@ __device__ __forceinline__ JobDescriptor decode_job( const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { JobDescriptor job{}; - job.block_id = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - job.block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (job.block_id * ELTS_PER_CHUNK); + job.block_id = ctaid_Y * work_blocks_X + ctaid_X; + job.block_global_offset = is_single_tensor + ? 
(ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) + : (job.block_id * ELTS_PER_CHUNK); job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y, first_logical_dim, last_logical_dim, offsets_ptr); job.rows = @@ -386,19 +380,6 @@ __device__ __forceinline__ float process_colwise_stage( constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; - if constexpr (!IS_CACHED_ACT_OP) { - (void)cached_act_sh; - } - if constexpr (!IS_DACT) { - (void)act_in_sh; - } - if constexpr (!IS_DBIAS) { - (void)partial_dbias_colwise; - } - if constexpr (!WITH_GEMM_SWIZZLED_SCALES) { - (void)tensor_base_for_scales; - (void)rows; - } const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; float thread_amax = 0.0f; @@ -496,18 +477,6 @@ __device__ __forceinline__ float process_rowwise_stage( constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; - if constexpr (!IS_DACT) { - (void)act_in_sh; - } - if constexpr (!IS_CACHED_ACT_OP) { - (void)cached_act_sh; - } - if constexpr (!(IS_DBIAS && (!COLWISE_SCALING))) { - (void)thread_dbias_rowwise; - } - if constexpr (!WITH_GEMM_SWIZZLED_SCALES) { - (void)cols; - } const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; float thread_amax = 0.0f; @@ -709,44 +678,35 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel float block_amax = 0.0f; - __shared__ uint64_t workID_mbar; - [[maybe_unused]] __shared__ __uint128_t workID_response; - [[maybe_unused]] constexpr uint32_t workID_response_size = sizeof(workID_response); - static_assert(workID_response_size == 16); - __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; // Initialize barriers shared by the entire CTA: // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. - // - workID_mbar synchronizes WorkID query response in dynamic persistent mode. if (leading_thread) { #pragma unroll for (int buff = 0; buff < BUFFS_NUM; ++buff) { ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); } - ptx::mbarrier_init(&workID_mbar, 1); ptx::fence_proxy_async_shared_cta(); } __syncthreads(); const size_t total_work_blocks = work_blocks_X * work_blocks_Y; - const size_t launch_block_id = - static_cast(blockIdx.y) * static_cast(gridDim.x) + static_cast(blockIdx.x); + const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; int IN_buff_readable_parity[BUFFS_NUM] = {0}; - [[maybe_unused]] int ctaid_parity = 0; int32_t ctaid_X = static_cast(blockIdx.x); int32_t ctaid_Y = static_cast(blockIdx.y); - [[maybe_unused]] size_t static_next_block_id = 0; - [[maybe_unused]] size_t static_block_stride = 0; - // In STATIC_PERSISTENT mode physical CTAs iterate over a virtual work grid via grid-stride. - if constexpr (STATIC_PERSISTENT) { + size_t static_next_block_id = 0; + size_t static_block_stride = 0; + // In persistent mode, physical CTAs iterate over a virtual work grid via grid-stride. 
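  // A minimal standalone sketch of that grid-stride mapping (illustrative only, not part of
  // the kernel; variable names follow the code below, prefetch and barrier details omitted):
  //
  //   const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x;
  //   if (launch_block_id >= total_work_blocks) return;        // surplus CTAs exit early
  //   size_t block_id = launch_block_id;                       // first virtual work block
  //   const size_t stride = gridDim.x * gridDim.y;             // physical grid size
  //   while (block_id < total_work_blocks) {
  //     const int32_t ctaid_X = block_id % work_blocks_X;      // virtual column
  //     const int32_t ctaid_Y = block_id / work_blocks_X;      // virtual row
  //     /* ...process the [CHUNK_DIM_Y x CHUNK_DIM_X] block at (ctaid_Y, ctaid_X)... */
  //     block_id += stride;                                    // hop over the whole launch grid
  //   }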
+ if constexpr (PERSISTENT) { if (launch_block_id >= total_work_blocks) { return; } ctaid_X = static_cast(launch_block_id % work_blocks_X); ctaid_Y = static_cast(launch_block_id / work_blocks_X); - static_block_stride = static_cast(gridDim.x) * static_cast(gridDim.y); + static_block_stride = gridDim.x * gridDim.y; static_next_block_id = launch_block_id + static_block_stride; } bool job_finished = false; @@ -789,7 +749,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - // Main persistent loop: decode current job, run all 32-row stages, schedule/prefetch next job. + // Main work loop: decode current job, run all 32-row stages, schedule/prefetch next job. while (!job_finished) { // Decode CTA assignment into logical tensor coordinates and validate bounds. const JobDescriptor current_job = @@ -876,13 +836,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - if constexpr (DYNAMIC_PERSISTENT) { - if (leading_thread) { - ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); - ptx::try_cancel_cta(&workID_mbar, &workID_response); - } - } - bool prefetched_next_job = false; // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). #pragma unroll @@ -893,11 +846,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel BlockDescriptor prefetch_block = current_block; if (stage == STAGES - PREFETCH_STAGES) { - if constexpr (DYNAMIC_PERSISTENT) { - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); - ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); - ctaid_parity ^= 1; - } else if constexpr (STATIC_PERSISTENT) { + if constexpr (PERSISTENT) { if (static_next_block_id < total_work_blocks) { ctaid_X = static_cast(static_next_block_id % work_blocks_X); ctaid_Y = static_cast(static_next_block_id / work_blocks_X); @@ -912,7 +861,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ctaid_X = -1; ctaid_Y = -1; } - if constexpr (!STATIC_PERSISTENT) { + if constexpr (!PERSISTENT) { if (ctaid_X == -1 && ctaid_Y == -1) { job_finished = true; } @@ -1062,7 +1011,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel for (int buff = 0; buff < BUFFS_NUM; ++buff) { ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); } - ptx::mbarrier_invalid(&workID_mbar); } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -1139,7 +1087,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations size_t launch_blocks_X = work_blocks_X; size_t launch_blocks_Y = work_blocks_Y; - if constexpr (STATIC_PERSISTENT) { + if constexpr (PERSISTENT) { const size_t sm_num = static_cast(transformer_engine::cuda::sm_count()); const size_t static_grid_size = sm_num * TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM; NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero."); From 4344627e32c3b14279b07aca1d16294b965277e9 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 5 Mar 2026 15:57:41 +0000 Subject: [PATCH 21/31] Ready for PR Signed-off-by: Oleg Goncharov --- tests/cpp/CMakeLists.txt | 3 +- tests/cpp/operator/CMakeLists.txt | 56 +- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 41 +- .../common/activation/activation_template.h | 30 +- .../common/cast/dispatch/dequantize.cuh | 52 +- .../common/cast/dispatch/gated.cuh | 304 ++++---- .../common/cast/dispatch/quantize.cuh | 729 +++++++++--------- 
7 files changed, 604 insertions(+), 611 deletions(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 2092975b2a..6f4f163f08 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -8,8 +8,7 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) else () - # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) - set(CMAKE_CUDA_ARCHITECTURES 100) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index a04cc3c38c..56880a428d 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,35 +3,35 @@ # See LICENSE for license information. add_executable(test_operator - # test_cast.cu - # test_cast_current_scaling.cu - # test_cast_dbias.cu - # test_cast_dbias_dgelu.cu - # test_cast_gated_swiglu.cu - # test_cast_mxfp8_gated_swiglu.cu - # test_qdq.cu - # test_cast_mxfp8.cu + test_cast.cu + test_cast_current_scaling.cu + test_cast_dbias.cu + test_cast_dbias_dgelu.cu + test_cast_gated_swiglu.cu + test_cast_mxfp8_gated_swiglu.cu + test_qdq.cu + test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu - # test_cast_nvfp4_transpose.cu - # test_cast_float8blockwise.cu - # test_dequantize_mxfp8.cu - # test_transpose.cu - # test_cast_transpose.cu - # test_cast_transpose_current_scaling.cu - # test_cast_transpose_dbias.cu - # test_cast_transpose_dbias_dgelu.cu - # test_cast_transpose_dgeglu.cu - # test_act.cu - # test_normalization.cu - # test_normalization_mxfp8.cu - # test_memset.cu - # test_multi_cast_transpose.cu - # test_multi_padding.cu - # test_multi_unpadding.cu - # test_causal_softmax.cu - # test_swizzle.cu - # test_swap_first_dims.cu - # test_grouped_gemm.cu + test_cast_nvfp4_transpose.cu + test_cast_float8blockwise.cu + test_dequantize_mxfp8.cu + test_transpose.cu + test_cast_transpose.cu + test_cast_transpose_current_scaling.cu + test_cast_transpose_dbias.cu + test_cast_transpose_dbias_dgelu.cu + test_cast_transpose_dgeglu.cu + test_act.cu + test_normalization.cu + test_normalization_mxfp8.cu + test_memset.cu + test_multi_cast_transpose.cu + test_multi_padding.cu + test_multi_unpadding.cu + test_causal_softmax.cu + test_swizzle.cu + test_swap_first_dims.cu + test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 6cff159d51..647737171a 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -623,15 +623,15 @@ void performTest(const ProcessingMethod processing_method, std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - // ProcessingMethod::CAST_DBIAS, - // ProcessingMethod::CAST_DBIAS_DACT, - // ProcessingMethod::CAST_DACT, - // ProcessingMethod::CAST_ACT, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, }; std::vector activation_kinds = { ActivationKind::Identity, - // ActivationKind::GeLU, + ActivationKind::GeLU, // ActivationKind::SiLU, // ActivationKind::ReLU, // ActivationKind::QGeLU, @@ -646,23 +646,22 @@ enum ScalingDirection { std::vector scaling_directions = { ScalingDirection::ROWWISE, - // ScalingDirection::COLWISE, - // ScalingDirection::BOTH, + ScalingDirection::COLWISE, + ScalingDirection::BOTH, }; // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} 
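// For example, {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512} (an entry from the list below)
// describes a group of 3 tensors sharing the last dimension K = 144 with first dimensions
// M_i = 128, 384 and 512, whose sum 1024 is the logical first dimension of the group.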
std::vector> input_config = { - // {SAME_BOTH_DIMS, 1, 128,128}, - // {SAME_BOTH_DIMS, 2, 256,128}, - // {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - // {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, - // {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, - // {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, - {VARYING_FIRST_DIM, 5, 4096,4096, 128,256,384,1024,2304}, - {VARYING_FIRST_DIM, 5, 16 * 4096,4096, 128,256,384,1024,2304}, - // {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, - // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, - // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; } // namespace @@ -824,10 +823,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(activation_kinds), ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), - ::testing::Values(DType::kBFloat16), - ::testing::Values(DType::kFloat8E4M3)), - // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); std::string name = to_string(method); diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index caf6cbda65..ffbffafd1a 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -22,36 +22,36 @@ namespace transformer_engine { template void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - // using namespace detail; - // constexpr bool IS_ACT = true; - // dispatch::quantize_fwd_helper(input, output, nullptr, stream); + using namespace detail; + constexpr bool IS_ACT = true; + dispatch::quantize_fwd_helper(input, output, nullptr, stream); } template void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { - // using namespace detail; - // constexpr bool IS_DBIAS = false; - // constexpr bool IS_DACT = true; - // constexpr NVTETensor dbias = nullptr; - // constexpr NVTETensor workspace = nullptr; - - // dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, - // nullptr, stream); + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + + dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, + nullptr, stream); } template void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { - // using namespace detail; - // dispatch::quantize_gated_fwd_helper(input, output, p, stream); + using namespace detail; + dispatch::quantize_gated_fwd_helper(input, output, p, stream); } template void dgated_act_fn(const NVTETensor grad, const NVTETensor 
input, NVTETensor output, Param &p, cudaStream_t stream) { - // using namespace detail; - // dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); + using namespace detail; + dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index db2ad285a8..81304981d3 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -22,32 +22,32 @@ namespace transformer_engine { namespace dispatch { inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - // CheckInputTensor(input, "cast_input"); - // CheckOutputTensor(*output, "cast_output"); - - // switch (input.scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); - // NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); - // NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); - // fp8::dequantize(input, output, stream); - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // if (is_supported_by_CC_100()) { - // mxfp8::dequantize(input, output, stream); - // } else { - // NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - // } - // break; - // } - // case NVTE_NVFP4_1D_SCALING: { - // nvfp4::dequantize(input, output, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - // } + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); + NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); + fp8::dequantize(input, output, stream); + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + mxfp8::dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + nvfp4::dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } } } // namespace dispatch diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index c2087533a6..06e8f0e306 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -25,164 +25,164 @@ namespace dispatch { template void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - // const Tensor input = *convertNVTETensorCheck(nvte_input); - // Tensor *output = convertNVTETensorCheck(nvte_output); - - // CheckInputTensor(input, "input"); - // CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - // const size_t rows = input.flat_first_dim(); - // const size_t cols = input.flat_last_dim() / 2; - - // NVTE_CHECK(input.flat_last_dim() % 2 == 0, - // "Wrong input shape. 
Expected (after flattening) last dimension to be even, ", "got [", - // input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - // NVTE_CHECK(output->flat_last_dim() == cols, - // "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", - // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // switch (output->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - // if (use_tma_kernels) { - // Tensor dummy_grad_tensor; - // fp8::cast_gated_tma(input, dummy_grad_tensor, - // output, p, stream); - // } else { - // fp8::cast_gated_fwd(input, output, p, stream); - // } - // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // // FP8 kernel only populates row-wise data, so perform - // // transpose separately if needed - // Tensor transpose_in, transpose_out, dummy; - // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_in.data.dptr = output->data.dptr; - // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - // transpose_in.data.dtype = output->data.dtype; - // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_out.data.dptr = output->columnwise_data.dptr; - // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - // transpose_out.data.dtype = output->data.dtype; - // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // NVTE_CHECK(cols % 32 == 0, - // "Invalid input shape. Expected the last dimension to be " - // "divisible by 32, but got ", - // cols, "."); - // if (output->has_data()) { - // NVTE_CHECK(is_fp8_dtype(output->data.dtype), - // "The type of the output tensor should be FP8."); - // } - // if (output->has_columnwise_data()) { - // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - // "The type of the columnwise output tensor should be FP8."); - // } - // NVTE_CHECK(is_supported_by_CC_100(), - // "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - // Tensor dummy_grad_tensor; - // mxfp8::quantize_gated(input, dummy_grad_tensor, - // output, p, stream); - // break; - // } - // default: - // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - // } + const Tensor input = *convertNVTETensorCheck(nvte_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim() / 2; + + NVTE_CHECK(input.flat_last_dim() % 2 == 0, + "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols, + "Wrong output shape. 
Expected (after flattening) [*, ", cols, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + Tensor dummy_grad_tensor; + fp8::cast_gated_tma(input, dummy_grad_tensor, + output, p, stream); + } else { + fp8::cast_gated_fwd(input, output, p, stream); + } + if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // FP8 kernel only populates row-wise data, so perform + // transpose separately if needed + Tensor transpose_in, transpose_out, dummy; + transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_in.data.dptr = output->data.dptr; + transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + transpose_in.data.dtype = output->data.dtype; + transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_out.data.dptr = output->columnwise_data.dptr; + transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + transpose_out.data.dtype = output->data.dtype; + detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + Tensor dummy_grad_tensor; + mxfp8::quantize_gated(input, dummy_grad_tensor, + output, p, stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } } template void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - // const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); - // const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); - // Tensor *output = convertNVTETensorCheck(nvte_output); - - // CheckInputTensor(grad, "grad"); - // CheckInputTensor(gated_input, "gated_input"); - // CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - // NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", - // gated_input.flat_last_dim(), "."); - - // const size_t rows = gated_input.flat_first_dim(); - // const size_t cols = gated_input.flat_last_dim() / 2; - - // NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); - // NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); - - // NVTE_CHECK(grad.flat_first_dim() == rows, - // "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", - // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - // NVTE_CHECK(grad.flat_last_dim() == cols, - // "Wrong Grad shape. 
Expected last dimension (after flattening) [", cols, ", *], got [", - // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - - // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", - // rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - // NVTE_CHECK(output->flat_last_dim() == cols * 2, - // "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", - // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - // NVTE_CHECK(gated_input.shape() == output->shape(), - // "Gated input and output shapes must match. Input shape: ", gated_input.shape(), - // ", output shape: ", output->shape(), "."); - - // switch (output->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - // if (use_tma_kernels) { - // fp8::cast_gated_tma(gated_input, grad, output, p, - // stream); - // } else { - // fp8::cast_gated_bwd(gated_input, grad, output, p, stream); - // } - // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // // FP8 kernel only populates row-wise data, so perform - // // transpose separately if needed - // Tensor transpose_in, transpose_out, dummy; - // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_in.data.dptr = output->data.dptr; - // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - // transpose_in.data.dtype = output->data.dtype; - // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_out.data.dptr = output->columnwise_data.dptr; - // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - // transpose_out.data.dtype = output->data.dtype; - // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // NVTE_CHECK(cols % 32 == 0, - // "Invalid input shape. 
Expected the last dimension to be " - // "divisible by 32, but got ", - // cols, "."); - // if (output->has_data()) { - // NVTE_CHECK(is_fp8_dtype(output->data.dtype), - // "The type of the output tensor should be FP8."); - // } - // if (output->has_columnwise_data()) { - // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - // "The type of the columnwise output tensor should be FP8."); - // } - // NVTE_CHECK(is_supported_by_CC_100(), - // "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - - // mxfp8::quantize_gated(gated_input, grad, output, p, - // stream); - // break; - // } - // default: - // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - // } + const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(grad, "grad"); + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", + gated_input.flat_last_dim(), "."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + + NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); + NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); + + NVTE_CHECK(grad.flat_first_dim() == rows, + "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + NVTE_CHECK(grad.flat_last_dim() == cols, + "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", + rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols * 2, + "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(gated_input.shape() == output->shape(), + "Gated input and output shapes must match. 
Input shape: ", gated_input.shape(), + ", output shape: ", output->shape(), "."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + fp8::cast_gated_tma(gated_input, grad, output, p, + stream); + } else { + fp8::cast_gated_bwd(gated_input, grad, output, p, stream); + } + if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // FP8 kernel only populates row-wise data, so perform + // transpose separately if needed + Tensor transpose_in, transpose_out, dummy; + transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_in.data.dptr = output->data.dptr; + transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + transpose_in.data.dtype = output->data.dtype; + transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_out.data.dptr = output->columnwise_data.dptr; + transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + transpose_out.data.dtype = output->data.dtype; + detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + + mxfp8::quantize_gated(gated_input, grad, output, p, + stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } } } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 0aadffa940..f7823b4c58 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -30,282 +30,282 @@ namespace dispatch { template void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // const Tensor *input_tensor = convertNVTETensorCheck(input); - // Tensor *output_tensor = convertNVTETensorCheck(output); - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Check for unsupported options - // if (quant_config_cpp.stochastic_rounding) { - // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - // "Stochastic rounding is only supported for NVFP4 quantization."); - // } - - // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // // Dispatch to quantization kernel depending on data format - // switch (output_tensor->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // 
const Tensor *dummy_input_tensor = nullptr; - // Tensor *dummy_dbias_tensor = nullptr; - // Tensor *dummy_workspace_tensor = nullptr; - // if (output_tensor->has_columnwise_data()) { - // NVTE_CHECK(output_tensor->has_data(), - // "Quantizing in only the columnwise direction not supported yet!"); - // if constexpr (!IS_ACT) { - // cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - // } else { - // cast_transpose_fused( - // *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, - // dummy_workspace_tensor, stream); - // } - // } else if (output_tensor->has_data()) { - // fp8::quantize( - // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - // dummy_workspace_tensor, stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // const Tensor *dummy_input_tensor = nullptr; - // Tensor *dummy_dbias_tensor = nullptr; - // Tensor *dummy_workspace_tensor = nullptr; - // mxfp8::quantize( - // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - // dummy_workspace_tensor, stream); - // break; - // } - // case NVTE_NVFP4_1D_SCALING: { - // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // // Check tensors - // CheckNoopTensor(*noop_tensor, "cast_noop"); - // CheckInputTensor(*input_tensor, "input"); - // CheckOutputTensor(*output_tensor, "output", false); - - // // Choose kernel - // int32_t rows = input_tensor->flat_first_dim(); - // int32_t cols = input_tensor->flat_last_dim(); - // auto dtype = input_tensor->dtype(); - // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - // (cols % 32 == 0) && output_tensor->has_data(); - - // // Launch NVFP4 quantize kernel - // if (use_optimized_kernel) { - // if (quant_config_cpp.nvfp4_2d_quantization) { - // nvfp4::quantize_transpose( - // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } else { - // nvfp4::quantize_transpose( - // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } - // } else { - // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - // : output_tensor->columnwise_amax; - // quantize_transpose_vector_blockwise_fp4( - // /*input=*/input_tensor->data, /*global_amax=*/global_amax, - // /*scale_inv=*/output_tensor->scale_inv, - // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - // /*swizzled_scale=*/false, - // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - // /*rng_state=*/quant_config_cpp.rng_state, - // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - // } - // break; - // } - // case NVTE_BLOCK_SCALING_2D: { - // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. 
- // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // quantize_transpose_square_blockwise( - // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, - // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - // /*noop_tensor=*/noop_tensor->data, stream); - // break; - // } - // case NVTE_BLOCK_SCALING_1D: { - // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - // if (output_tensor->has_data()) { - // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - // } - // if (output_tensor->has_columnwise_data()) { - // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - // } - // quantize_transpose_vector_blockwise( - // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - // } + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + Tensor *output_tensor = convertNVTETensorCheck(output); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = nullptr; + Tensor *dummy_workspace_tensor = nullptr; + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_ACT) { + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = 
nullptr; + Tensor *dummy_workspace_tensor = nullptr; + mxfp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. 
+ NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } } template void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // const Tensor *grad_tensor = convertNVTETensorCheck(grad); - // const Tensor *input_tensor = convertNVTETensor(input); - - // Tensor *output_tensor = convertNVTETensorCheck(output); - // Tensor *dbias_tensor = convertNVTETensor(dbias); - // Tensor *workspace_tensor = convertNVTETensor(workspace); - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Check for unsupported options - // if (quant_config_cpp.stochastic_rounding) { - // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - // "Stochastic rounding is only supported for NVFP4 quantization."); - // } - - // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // // Dispatch to quantization kernel depending on data format - // switch (output_tensor->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // if (output_tensor->has_columnwise_data()) { - // NVTE_CHECK(output_tensor->has_data(), - // "Quantizing in only the columnwise direction not supported yet!"); - // if constexpr (!IS_DBIAS && !IS_DACT) { - // cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); - // } else { - // cast_transpose_fused( - // *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); - // } - // } else if (output_tensor->has_data()) { - // fp8::quantize( - // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // mxfp8::quantize( - // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // stream); - // break; - // } - // case NVTE_NVFP4_1D_SCALING: { - // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); - - // // Check tensors - // CheckNoopTensor(*noop_tensor, "cast_noop"); - // CheckInputTensor(*grad_tensor, "input"); - // 
CheckOutputTensor(*output_tensor, "output", false); - - // // Choose kernel - // int32_t rows = grad_tensor->flat_first_dim(); - // int32_t cols = grad_tensor->flat_last_dim(); - // auto dtype = grad_tensor->dtype(); - // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - // (cols % 32 == 0) && output_tensor->has_data(); - - // // Launch NVFP4 quantize kernel - // if (use_optimized_kernel) { - // if (quant_config_cpp.nvfp4_2d_quantization) { - // nvfp4::quantize_transpose( - // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } else { - // nvfp4::quantize_transpose( - // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } - // } else { - // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - // : output_tensor->columnwise_amax; - // quantize_transpose_vector_blockwise_fp4( - // /*input=*/grad_tensor->data, /*global_amax=*/global_amax, - // /*scale_inv=*/output_tensor->scale_inv, - // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - // /*swizzled_scale=*/false, - // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - // /*rng_state=*/quant_config_cpp.rng_state, - // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - // } - // break; - // } - // case NVTE_BLOCK_SCALING_2D: { - // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // quantize_transpose_square_blockwise( - // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, - // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - // /*noop_tensor=*/noop_tensor->data, stream); - // break; - // } - // case NVTE_BLOCK_SCALING_1D: { - // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. 
- // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - // if (output_tensor->has_data()) { - // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - // } - // if (output_tensor->has_columnwise_data()) { - // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - // } - // quantize_transpose_vector_blockwise( - // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - // } + using namespace detail; + + const Tensor *grad_tensor = convertNVTETensorCheck(grad); + const Tensor *input_tensor = convertNVTETensor(input); + + Tensor *output_tensor = convertNVTETensorCheck(output); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT) { + cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*grad_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = grad_tensor->flat_first_dim(); + int32_t cols = grad_tensor->flat_last_dim(); + auto dtype = grad_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols 
% 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } } // Host-aware and not graph-safe: group quantization with split section info from the host. 
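// Hypothetical usage sketch (not part of the patch; template arguments elided as in the
// surrounding code): the per-tensor split sizes are read from host memory at call time,
// which is presumably why this path is marked as not graph-safe.
//
//   const size_t split_sections[] = {128, 384, 512};   // per-tensor sections, host memory
//   group_quantize_fwd_host_aware_helper(input, outputs, split_sections,
//                                         /*num_tensors=*/3, quant_config, stream);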
@@ -314,64 +314,64 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // const Tensor *input_tensor = convertNVTETensorCheck(input); - // std::vector output_tensors; - // for (size_t i = 0; i < num_tensors; ++i) { - // output_tensors.push_back(convertNVTETensorCheck(outputs[i])); - // } - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Check for unsupported options - // if (quant_config_cpp.stochastic_rounding) { - // NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, - // "Stochastic rounding is only supported for NVFP4 quantization."); - // } - - // // Take the scaling mode of the first output tensor - // auto scaling_mode = output_tensors[0]->scaling_mode; - - // // Dispatch to quantization kernel depending on data format - // switch (scaling_mode) { - // case NVTE_NVFP4_1D_SCALING: { - // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // // Check tensors - // CheckNoopTensor(*noop_tensor, "cast_noop"); - // CheckInputTensor(*input_tensor, "input"); - // // Skip checking output tensor list - // // output list here is allowed to have empty tensor - - // // Choose kernel - // int32_t rows = input_tensor->flat_first_dim(); - // int32_t cols = input_tensor->flat_last_dim(); - // auto dtype = input_tensor->dtype(); - - // NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - // "2D quantization is not supported for group quantize."); - - // // Launch NVFP4 group quantize kernel - // nvfp4::group_quantize_transpose( - // *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, - // &quant_config_cpp, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - // } + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_tensors; + for (size_t i = 0; i < num_tensors; ++i) { + output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + } + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + // Take the scaling mode of the first output tensor + auto scaling_mode = output_tensors[0]->scaling_mode; + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + // Skip checking output tensor list + // output 
list here is allowed to have empty tensor + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "2D quantization is not supported for group quantize."); + + // Launch NVFP4 group quantize kernel + nvfp4::group_quantize_transpose( + *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + &quant_config_cpp, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } } template @@ -407,10 +407,7 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - // mxfp8::group_quantize( - // IS_ACT is set to false - // OP is set to nullptr - mxfp8::group_quantize( + mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; @@ -425,40 +422,40 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); - - // const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); - // const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); - // GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - // GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); - // Tensor *workspace_tensor = convertNVTETensor(workspace); - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Dispatch to quantization kernel depending on data format - // switch (scaling_mode) { - // case NVTE_MXFP8_1D_SCALING: { - // mxfp8::group_quantize( - // grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - // } + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + grad_tensor, input_tensor, noop_tensor, 
output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } } } // namespace dispatch From ede33b43c4aa7b7d86de9c35444e1bde36b8ac8b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:17:04 +0000 Subject: [PATCH 22/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 115 ++++++++++-------- 1 file changed, 63 insertions(+), 52 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 7a510c1295..b28fe1d820 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -17,9 +17,9 @@ #include #include "../../common.h" +#include "../../util/cuda_runtime.h" #include "../../util/math.h" #include "../../util/ptx.cuh" -#include "../../util/cuda_runtime.h" #include "../../utils.cuh" #include "../core/common.cuh" #include "swizzle.cuh" @@ -170,12 +170,13 @@ __device__ __forceinline__ JobDescriptor decode_job( const ShapeRepresentation shape_rep, const bool is_single_tensor, const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X, const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, - const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr) { JobDescriptor job{}; job.block_id = ctaid_Y * work_blocks_X + ctaid_X; job.block_global_offset = is_single_tensor - ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) - : (job.block_id * ELTS_PER_CHUNK); + ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) + : (job.block_id * ELTS_PER_CHUNK); job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y, first_logical_dim, last_logical_dim, offsets_ptr); job.rows = @@ -209,9 +210,9 @@ __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, return true; } -__device__ __forceinline__ BlockDescriptor decode_block(const JobDescriptor &job, - const bool is_single_tensor, - const int64_t *const __restrict__ offsets_ptr) { +__device__ __forceinline__ BlockDescriptor +decode_block(const JobDescriptor &job, const bool is_single_tensor, + const int64_t *const __restrict__ offsets_ptr) { BlockDescriptor block{}; block.tensor_base = is_single_tensor ? 
0 : static_cast(offsets_ptr[job.tensor_id]); const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(128)); @@ -327,8 +328,9 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso template __device__ __forceinline__ void prefetch_input_stage( IType *in_sh, IType *act_in_sh, const CUtensorMap &tensor_map_input, - const CUtensorMap &tensor_map_act_input, const size_t global_offset_X, const size_t global_offset_Y, - const size_t buff_offset, const size_t shmem_buff_size, uint64_t *barrier, const bool leading_thread) { + const CUtensorMap &tensor_map_act_input, const size_t global_offset_X, + const size_t global_offset_Y, const size_t buff_offset, const size_t shmem_buff_size, + uint64_t *barrier, const bool leading_thread) { if (leading_thread) { ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( @@ -338,39 +340,41 @@ __device__ __forceinline__ void prefetch_input_stage( if constexpr (IS_DACT) { ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&act_in_sh[buff_offset]), - reinterpret_cast(&tensor_map_act_input), global_offset_X, global_offset_Y, - barrier); + reinterpret_cast(&tensor_map_act_input), global_offset_X, + global_offset_Y, barrier); } } } // Issue TMA shared->global transfer for one stage of outputs. template -__device__ __forceinline__ void store_output_stage( - OType *out_rowwise_data_sh, OType *out_colwise_data_sh, - const CUtensorMap &tensor_map_output_rowwise, const CUtensorMap &tensor_map_output_colwise, - const int global_offset_X, const int global_offset_Y, const int buff_offset, - const bool leading_thread) { +__device__ __forceinline__ void store_output_stage(OType *out_rowwise_data_sh, + OType *out_colwise_data_sh, + const CUtensorMap &tensor_map_output_rowwise, + const CUtensorMap &tensor_map_output_colwise, + const int global_offset_X, + const int global_offset_Y, const int buff_offset, + const bool leading_thread) { if (!leading_thread) { return; } if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, global_offset_Y, - reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); } if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, global_offset_Y, - reinterpret_cast(&out_colwise_data_sh[buff_offset])); + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); } ptx::cp_async_bulk_commit_group(); } template + float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING, + bool WITH_GEMM_SWIZZLED_SCALES> __device__ __forceinline__ float process_colwise_stage( const size_t buff, const int stage, const size_t tid_X_colwise, const size_t scales_offset_Y_colwise, const size_t scales_offset_X_colwise, @@ -434,10 +438,10 @@ __device__ __forceinline__ float process_colwise_stage( const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; - scale_idx = tensor_scales_offset_colwise_base + - 
transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( - global_scales_offset_X, local_scales_offset_Y, - DIVUP(rows, static_cast(128))); + scale_idx = + tensor_scales_offset_colwise_base + + transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + global_scales_offset_X, local_scales_offset_Y, DIVUP(rows, static_cast(128))); } else { scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; } @@ -463,15 +467,15 @@ __device__ __forceinline__ float process_colwise_stage( } template + float (*OP)(float, const ParamOP &), typename IType, typename OType, bool COLWISE_SCALING, + bool WITH_GEMM_SWIZZLED_SCALES> __device__ __forceinline__ float process_rowwise_stage( const size_t buff, const size_t stage_offset_Y, const size_t thread_offset_Y_rowwise, const size_t thread_offset_X_rowwise, const int bank_group, const size_t scales_offset_Y_rowwise, const size_t scales_offset_X_rowwise, - const size_t scale_stride_rowwise, const bool rowwise_scale_is_within_bounds, - const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, - OType *out_rowwise_data_sh, e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { + const size_t scale_stride_rowwise, const bool rowwise_scale_is_within_bounds, const size_t cols, + IType *in_sh, IType *act_in_sh, IType *cached_act_sh, OType *out_rowwise_data_sh, + e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; @@ -725,8 +729,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const CUtensorMap &tensor_map_input = is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[first_job.tensor_id]; - const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[first_job.tensor_id]; + const CUtensorMap &tensor_map_act_input = is_single_tensor + ? tensor_map_act_input_static + : g_tensor_maps_act_input[first_job.tensor_id]; if (leading_thread && (!is_single_tensor)) { fence_acquire_tensormap(&tensor_map_input); @@ -809,10 +814,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; const CUtensorMap &tensor_map_act_input = is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap &tensor_map_output_rowwise = - is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap &tensor_map_output_colwise = - is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = is_single_tensor + ? tensor_map_output_rowwise_static + : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = is_single_tensor + ? tensor_map_output_colwise_static + : g_tensor_maps_output_colwise[tensor_id]; if (leading_thread && (!is_single_tensor)) { fence_acquire_tensormap(&tensor_map_input); @@ -871,10 +878,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. 
if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { - prefetch_job = - decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, - work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - allow_next_job_prefetch = is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); + prefetch_job = decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, + last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, + first_dims_ptr, last_dims_ptr); + allow_next_job_prefetch = + is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); if (allow_next_job_prefetch) { prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); } @@ -893,9 +901,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; const CUtensorMap &prefetch_tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[prefetch_job.tensor_id]; + is_single_tensor ? tensor_map_input_static + : g_tensor_maps_input[prefetch_job.tensor_id]; const CUtensorMap &prefetch_tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[prefetch_job.tensor_id]; + is_single_tensor ? tensor_map_act_input_static + : g_tensor_maps_act_input[prefetch_job.tensor_id]; uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; if (leading_thread) { @@ -906,9 +916,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } } - prefetch_input_stage( - in_sh, act_in_sh, prefetch_tensor_map_input, prefetch_tensor_map_act_input, global_offset_X, - global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); + prefetch_input_stage(in_sh, act_in_sh, prefetch_tensor_map_input, + prefetch_tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, + shmem_buff_size, barrier, leading_thread); ptx::fence_proxy_async_shared_cta(); } @@ -923,8 +934,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel thread_amax = process_colwise_stage( buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, - scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, cached_act_sh, - out_colwise_data_sh, scales_colwise, partial_dbias_colwise); + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, + cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); } if constexpr (ROWWISE_SCALING) { @@ -932,8 +943,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel COLWISE_SCALING, WITH_GEMM_SWIZZLED_SCALES>( buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, - rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, out_rowwise_data_sh, - scales_rowwise, thread_dbias_rowwise); + rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, + out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); } __builtin_assume(block_amax >= 0); @@ -948,8 +959,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int global_offset_X = block_offset_X; const int buff_offset = buff * BUFF_DIM; store_output_stage( - out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, 
tensor_map_output_colwise, - global_offset_X, global_offset_Y, buff_offset, leading_thread); + out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, + tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, leading_thread); buff_in = (buff_in + 1) % BUFFS_NUM; } From 325181bd23e61923c2e1cf174ba9997249fed800 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 6 Mar 2026 10:37:56 +0000 Subject: [PATCH 23/31] Fixes per the review Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 2 +- transformer_engine/common/cast/core/common.cuh | 4 ++-- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 7 ++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 647737171a..e54ceebaa3 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -371,7 +371,7 @@ void performTest(const ProcessingMethod processing_method, NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); - std::vector dbias_logical_shape_vec= {num_tensors, cols}; + std::vector dbias_logical_shape_vec = {num_tensors, cols}; NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), dbias_logical_shape_vec.size()); diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index a4e033939b..ce9fce6285 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -100,14 +100,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t tensor_id = blockIdx.y; const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) ? (first_logical_dim / num_tensors) - : first_dims_ptr[tensor_id]; + : static_cast(first_dims_ptr[tensor_id]); const size_t rows = tensor_rows / chunk_dim_Y; const size_t cols = last_logical_dim; const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) ? (tensor_id * (tensor_rows / chunk_dim_Y)) - : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); + : (static_cast(offsets_ptr[tensor_id]) / cols / chunk_dim_Y); const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index b28fe1d820..8c452f2a7a 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -142,6 +142,10 @@ __device__ __forceinline__ size_t get_tensor_cols_num( case ShapeRepresentation::VARYING_LAST_DIM: case ShapeRepresentation::VARYING_BOTH_DIMS: cols_num = static_cast(last_dims_ptr[tensor_id]); + if (cols_num % 128 != 0) { + NVTE_DEVICE_ERROR("For non-single tensors, the last dimension of each tensor in a group " + "must be divisible by 128."); + } break; } return cols_num; @@ -215,7 +219,8 @@ decode_block(const JobDescriptor &job, const bool is_single_tensor, const int64_t *const __restrict__ offsets_ptr) { BlockDescriptor block{}; block.tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); - const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(128)); + const size_t CHUNK_DIM_X_ = CHUNK_DIM_X; + const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_); block.block_id_in_current_tensor = is_single_tensor ? 
job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK); block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor; From 5815335bd1d016595d14d77b87e73e6fc314c8c7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 10:40:30 +0000 Subject: [PATCH 24/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/core/common.cuh | 7 ++++--- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index ce9fce6285..9c16666db0 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -105,9 +105,10 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t rows = tensor_rows / chunk_dim_Y; const size_t cols = last_logical_dim; - const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) - ? (tensor_id * (tensor_rows / chunk_dim_Y)) - : (static_cast(offsets_ptr[tensor_id]) / cols / chunk_dim_Y); + const size_t dbias_in_offset_Y = + (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (tensor_id * (tensor_rows / chunk_dim_Y)) + : (static_cast(offsets_ptr[tensor_id]) / cols / chunk_dim_Y); const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 8c452f2a7a..6e9bd3dc5e 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -143,8 +143,9 @@ __device__ __forceinline__ size_t get_tensor_cols_num( case ShapeRepresentation::VARYING_BOTH_DIMS: cols_num = static_cast(last_dims_ptr[tensor_id]); if (cols_num % 128 != 0) { - NVTE_DEVICE_ERROR("For non-single tensors, the last dimension of each tensor in a group " - "must be divisible by 128."); + NVTE_DEVICE_ERROR( + "For non-single tensors, the last dimension of each tensor in a group " + "must be divisible by 128."); } break; } From 0bd837c4c2842a976ed26709d3e9464ab79ac0ac Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 6 Mar 2026 13:28:39 +0000 Subject: [PATCH 25/31] Added the test suite Signed-off-by: Oleg Goncharov --- tests/cpp/CMakeLists.txt | 3 +- tests/cpp/operator/CMakeLists.txt | 59 +- .../test_cast_nvfp4_transpose_grouped.cu | 567 ++++++++++++++++++ 3 files changed, 599 insertions(+), 30 deletions(-) create mode 100644 tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 6f4f163f08..2092975b2a 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -8,7 +8,8 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) else () - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + set(CMAKE_CUDA_ARCHITECTURES 100) endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 56880a428d..f88402195b 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,35 +3,36 @@ # See LICENSE for license information. 
add_executable(test_operator - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu - test_cast_mxfp8.cu - test_cast_mxfp8_grouped.cu - test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu - test_dequantize_mxfp8.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_memset.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_multi_unpadding.cu - test_causal_softmax.cu - test_swizzle.cu - test_swap_first_dims.cu - test_grouped_gemm.cu + # test_cast.cu + # test_cast_current_scaling.cu + # test_cast_dbias.cu + # test_cast_dbias_dgelu.cu + # test_cast_gated_swiglu.cu + # test_cast_mxfp8_gated_swiglu.cu + # test_qdq.cu + # test_cast_mxfp8.cu + # test_cast_mxfp8_grouped.cu + # test_cast_nvfp4_transpose.cu + test_cast_nvfp4_transpose_grouped.cu + # test_cast_float8blockwise.cu + # test_dequantize_mxfp8.cu + # test_transpose.cu + # test_cast_transpose.cu + # test_cast_transpose_current_scaling.cu + # test_cast_transpose_dbias.cu + # test_cast_transpose_dbias_dgelu.cu + # test_cast_transpose_dgeglu.cu + # test_act.cu + # test_normalization.cu + # test_normalization_mxfp8.cu + # test_memset.cu + # test_multi_cast_transpose.cu + # test_multi_padding.cu + # test_multi_unpadding.cu + # test_causal_softmax.cu + # test_swizzle.cu + # test_swap_first_dims.cu + # test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu b/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu new file mode 100644 index 0000000000..def327b8c9 --- /dev/null +++ b/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu @@ -0,0 +1,567 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" +#include +#include +#include + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { + const __half2_raw raw_truncated_to_fp4e2m1_pair = + __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); + + const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); + const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); + const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); + return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; +} + +template +std::vector create_transpose(const InputType* const input, const size_t rows, size_t cols) { + std::vector input_t(cols * rows); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const size_t idx_t = j * rows + i; + input_t[idx_t] = input[idx]; + } + } + return input_t; +} + +// Compute the global encode scale factor for a given global amax +float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) { + constexpr float fp8_max = 448.0f; // 448.0f; + constexpr float fp4_max = 6.0f; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return the max normalized value + const float max_norm_clamp = use_fast_math + ? Numeric_Traits::maxNorm + : Numeric_Traits::maxNorm; + + global_encode_scale = fminf(global_encode_scale, max_norm_clamp); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +// 1D Scaling: Original implementation with 1x16 blocks +template +void quantize_nvfp4_1d(const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax, + const bool use_fast_math) { + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); + + constexpr size_t block_size_X = 16; + const size_t blocks_X = divide_round_up(cols, block_size_X); + + std::array cache_buffer; + for (size_t i = 0; i < block_size_X; ++i) { + cache_buffer[i] = 0.0f; + } + + for (size_t i = 0; i < rows; ++i) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t j_min = block_X * block_size_X; + const size_t j_max = j_min + block_size_X; + + // Find block amax + float block_amax = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = input_elt; + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx] = elt; + block_amax = std::max(block_amax, std::abs(elt)); + } + + // 2. 
Compute E4M3 scaling factor + // Compute per-block encoding/decoding scaling factor + const float S_dec_b = block_amax / 6.0f; + + // Scale & Store per-block decoding scaling factor + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); + + // Compute "correct" per-block encoding scaling factor + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32; + + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = S_dec_b_fp8; + + float scale_reciprocal = S_enc_b_fp8; + if (use_fast_math) { + // Numerical truncation to match GPU implementation, if mixed precision FMA instruction is used + scale_reciprocal = static_cast(static_cast(scale_reciprocal)); + } + + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const int cache_idx_x = j - j_min; + const int cache_idx_y = cache_idx_x + 1; + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + + const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + } + } + } +} + +template +void quantize_nvfp4(const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax, + const bool use_fast_math) { + quantize_nvfp4_1d(input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); +} + +template +void compute_ref(const InputType* input, + fp4e2m1x2* output, + fp4e2m1x2* output_t, + fp8e4m3* scales, + fp8e4m3* scales_t, + const float global_amax, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const size_t scales_stride_t, + const bool use_fast_math) +{ + std::vector input_t = create_transpose(input, rows, cols); + + quantize_nvfp4(input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); + quantize_nvfp4(input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_fast_math); +} + +void compare_nvfp4_tensors(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8) { + constexpr int max_mismatches_to_print = 3; + + std::vector mismatch_messages; + size_t total_mismatches = 0; + + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + + const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol); + if (mismatch) { + total_mismatches++; + // Optional: limit number of detailed messages to avoid overwhelming output + if (total_mismatches <= max_mismatches_to_print) { + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 
0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); + std::cout << "Error in tensor " << name << ": " << msg << std::endl; + } + } + } + } + } + + // Always report summary - either success or failure + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total elements checked: " << (rows * cols) << std::endl; + + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatches found: " << total_mismatches << std::endl; + std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; + if (mismatch_messages.size() > max_mismatches_to_print) { + std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print) + << " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl; + } + std::cout << "============================" << std::endl; + + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + } else { + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "All elements match within tolerance!" << std::endl; + std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; + std::cout << "============================" << std::endl; + } +} + +// Optional: Function to dump tensor data to files for detailed analysis +void dump_nvfp4_tensor_data(const std::string& prefix, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + std::string test_file = prefix + "_test.txt"; + std::string ref_file = prefix + "_ref.txt"; + std::string diff_file = prefix + "_diff.txt"; + + std::ofstream test_out(test_file); + std::ofstream ref_out(ref_file); + std::ofstream diff_out(diff_file); + + if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + const int pos = idx + k; + + test_out << "pos[" << pos << "] = " << t << std::endl; + ref_out << "pos[" << pos << "] = " << r << std::endl; + diff_out << "pos[" << pos << "] test=" << t << " ref=" << r + << " abs_diff=" << fabs(t - r) + << " rel_diff=" << (r == 0 ? 
0.0 : fabs((t - r) / r)) << std::endl; + } + } + } + std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl; + } else { + std::cout << "WARNING: Could not open files for tensor data dump" << std::endl; + } +} + +void print_detailed_tensor_comparison(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n", + name.c_str(), rows, cols, rows * cols); + + const int total_elements = rows * cols; + const int check_count = 128; + + printf("--- FIRST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = 0; i < std::min(check_count, total_elements); ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗"); + } + + if (total_elements > 2 * check_count) { + printf("\n--- LAST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = total_elements - check_count; i < total_elements; ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? 
"✓" : "✗"); + } + } + printf("==================================\n"); +} + +void compareResults_nvfp4(const Tensor &test, + const void *ref, const void *ref_t, const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + if (if_on_gpus) test.to_cpu(); + + const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data = reinterpret_cast(ref); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + + // Print detailed element-by-element comparison + // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); + // print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows); + + // Optionally dump tensor data to files for detailed analysis + if (dump_data) { + dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + + compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); +} + +template +void performTest(const ShapeRepresentation shape_rep, + const size_t num_tensors, + const std::vector& logical_shape, + const std::vector& first_dims, + const std::vector& last_dims, + const std::vector& offsets, + const bool use_fast_math) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = DType::kFloat4E2M1; + const size_t total_elts = offsets.back(); + std::vector grouped_input(total_elts); + + // Validate logical shape against the offsets-based flattened size. + size_t expected_total_elts = logical_shape[0] * logical_shape[1]; + if (shape_rep == VARYING_LAST_DIM) { + expected_total_elts = logical_shape[0] + * std::accumulate(last_dims.begin(), last_dims.end(), static_cast(0)); + } + ASSERT_EQ(expected_total_elts, total_elts); + + Tensor grouped_input_tensor("grouped_input", std::vector{total_elts}, itype); + fillCase(&grouped_input_tensor, InputsFillCase::uniform); + std::copy(grouped_input_tensor.rowwise_cpu_dptr(), + grouped_input_tensor.rowwise_cpu_dptr() + total_elts, + grouped_input.begin()); + + const double atol = 1.0E-6; + const double rtol = 1.0E-6; + + QuantizationConfigWrapper quant_config; + quant_config.set_use_fast_math(use_fast_math); + quant_config.set_stochastic_rounding(false); + quant_config.set_nvfp4_2d_quantization(false); + + // Grouped NVFP4 kernel is not available yet. + // Validate grouped metadata/configuration by quantizing each tensor independently. 
+ for (size_t t = 0; t < num_tensors; ++t) { + const size_t rows = first_dims[t]; + const size_t cols = last_dims[t]; + const size_t tensor_offset = offsets[t]; + const size_t tensor_numel = rows * cols; + ASSERT_EQ(offsets[t + 1] - offsets[t], tensor_numel); + ASSERT_LE(tensor_offset + tensor_numel, total_elts); + + const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + const size_t unpadded_blocks_Y_t = scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t blocks_Y_t = scale_dims_t[2]; + const size_t blocks_X_t = scale_dims_t[3]; + const size_t scales_stride_t = blocks_X_t; + + std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); + std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); + std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); + std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); + + float amax = 0.0f; + for (size_t idx = 0; idx < tensor_numel; ++idx) { + amax = fmaxf(amax, fabsf(static_cast(grouped_input[tensor_offset + idx]))); + } + + Tensor input("input_tensor_" + std::to_string(t), std::vector{rows, cols}, itype); + std::copy(grouped_input.begin() + tensor_offset, + grouped_input.begin() + tensor_offset + tensor_numel, + input.rowwise_cpu_dptr()); + input.from_cpu(); + + Tensor output("output_tensor_" + std::to_string(t), std::vector{rows, cols}, otype, + true, true, NVTE_NVFP4_1D_SCALING); + output.set_scale(amax); + + compute_ref(grouped_input.data() + tensor_offset, + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + output.scale(), + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math); + + nvte_quantize_v2(input.data(), output.data(), quant_config, 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), + static_cast(rows), static_cast(cols), atol, rtol, true, false); + + size_t scale_mismatches_num = 0; + compare_scaling_factors("scales_" + std::to_string(t), + output.rowwise_cpu_scale_inv_ptr(), + ref_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + scale_mismatches_num); + + compare_scaling_factors("scales_t_" + std::to_string(t), + output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); + } +} + +// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} +std::vector> grouped_input_config = { + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 3, 1024,160, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, +}; + +} // namespace + +class GroupedFusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam + , // Config + transformer_engine::DType, + bool>> {}; + 
+TEST_P(GroupedFusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const std::vector input_config = std::get<0>(GetParam()); + const DType input_type = std::get<1>(GetParam()); + const bool use_fast_math = std::get<2>(GetParam()); + + const ShapeRepresentation shape_rep = static_cast(input_config[0]); + const size_t num_tensors = input_config[1]; + const std::vector logical_shape = {input_config[2], input_config[3]}; + + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + std::vector offsets(num_tensors + 1, 0); + for (size_t t = 0; t < num_tensors; ++t) { + switch (shape_rep) { + case SAME_BOTH_DIMS: { + first_dims[t] = logical_shape[0] / num_tensors; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_FIRST_DIM: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_LAST_DIM: { + first_dims[t] = logical_shape[0]; + last_dims[t] = input_config[t + 4]; + break; + } + case VARYING_BOTH_DIMS: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = input_config[t + (4 + num_tensors)]; + break; + } + } + offsets[t + 1] = offsets[t] + first_dims[t] * last_dims[t]; + + // FP4 kernels in this test assume 16-wide chunks and packed pairs. + if ((first_dims[t] % 16 != 0) || (last_dims[t] % 16 != 0)) { + GTEST_SKIP(); + } + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + performTest(shape_rep, num_tensors, logical_shape, + first_dims, last_dims, offsets, use_fast_math); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedFusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(grouped_input_config), + ::testing::Values(DType::kBFloat16), + ::testing::Values(false)), + [](const testing::TestParamInfo& info) { + std::string name = "CAST_ONLY"; + const std::vector input = std::get<0>(info.param); + + switch (static_cast(input[0])) { + case ShapeRepresentation::SAME_BOTH_DIMS: name += "_SAME_BOTH_DIMS"; break; + case ShapeRepresentation::VARYING_FIRST_DIM: name += "_VARYING_FIRST_DIM"; break; + case ShapeRepresentation::VARYING_LAST_DIM: name += "_VARYING_LAST_DIM"; break; + case ShapeRepresentation::VARYING_BOTH_DIMS: name += "_VARYING_BOTH_DIMS"; break; + }; + + name += "_N_" + std::to_string(input[1]); + name += "_SHAPE_" + std::to_string(input[2]) + "X" + std::to_string(input[3]); + name += "_" + test::typeName(std::get<1>(info.param)); + if (std::get<2>(info.param)) { + name += "_FAST_SCALING"; + } + return name; + }); From 0c5849c4cbcdbd79ac1a86585e620bf5dde5030b Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 6 Mar 2026 14:08:06 +0000 Subject: [PATCH 26/31] Initial kernel draft Signed-off-by: Oleg Goncharov --- ...roup_quantize_transpose_nvfp4_tuned_1D.cuh | 1190 +++++++++++++++++ 1 file changed, 1190 insertions(+) create mode 100644 transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh diff --git a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh new file mode 100644 index 0000000000..905c834558 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh @@ -0,0 +1,1190 @@ 
+/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file group_quantize_transpose_nvfp4_tuned_1D.cuh + * \brief Tuned grouped kernel to cast to NVFP4 and transpose. + */ + +#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ +#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ + +#include +#include +#include +#include + +#include "../../../common.h" +#include "../../../util/cuda_runtime.h" +#include "../../../util/math.h" +#include "../../../util/ptx.cuh" +#include "../../../utils.cuh" +#include "../core/common.cuh" +#include "../core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace group_quantize_transpose_tuned_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; +using namespace dispatch::common; + +#if FP4_TYPE_SUPPORTED + +struct TunableConfig { + static constexpr int CHUNK_DIM_Y = 128; + static constexpr int CHUNK_DIM_X = 128; + static constexpr int PREFETCH_STAGES = 1; + static constexpr bool PERSISTENT = true; + static constexpr int STATIC_PERSISTENT_BLOCKS_PER_SM = 4; +}; + +static_assert(!TunableConfig::PERSISTENT || (TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0), + "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero in persistent mode."); + +constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; +__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_t[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; + +constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts) +constexpr int THREADS_NUM = 128; +constexpr int ELTS_PER_THREAD = 16; +constexpr int TILE_DIM_Y = 64; +constexpr int TILE_DIM_X = 64; + +static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0"); + +static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) && + "Unbalanced threads workload\0"); + +static_assert((TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0) && + "Chunk size Y must be evenly divisible by the tile size Y\0"); +static_assert((TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0) && + "Chunk size X must be evenly divisible by the tile size X\0"); + +static_assert((TILE_DIM_Y % SCALE_DIM == 0) && + "Tile size Y must be evenly divisible by the scale dim\0"); +static_assert((TILE_DIM_X % SCALE_DIM == 0) && + "Tile size X must be evenly divisible by the scale dim\0"); + +constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y; +constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X; + +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; + +constexpr int SCALES_PER_CHUNK_Y = TunableConfig::CHUNK_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_X = TunableConfig::CHUNK_DIM_X / SCALE_DIM; + +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; + +constexpr int STAGES_Y = TILES_Y; +constexpr int STAGES_X = TILES_X; +constexpr int STAGES = STAGES_Y * STAGES_X; + +constexpr int BUFFS_NUM = TunableConfig::PREFETCH_STAGES + 1; +constexpr int BUFFS_NUM_IN = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT_TR 
= 2; +constexpr int BUFF_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_DIM_X = TILE_DIM_X; +constexpr int BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr int BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr int BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; +constexpr int BUFF_IN_ELTS_NUM = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr int BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr int BUFF_OUT_TR_DIM_Y = BUFF_DIM_X; +constexpr int BUFF_OUT_TR_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; + +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; + +constexpr int THREADS_X_TR = TILE_DIM_X / 2; +constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; + +constexpr int ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; +constexpr int ITERATIONS_TR = SCALES_PER_TILE_Y / THREADS_Y_TR; +static_assert(ITERATIONS_TR >= 1 && "Number of transpose iterations should be >=1\0"); +static_assert((SCALES_PER_TILE_Y % THREADS_Y_TR == 0) && + "Partial transpose iterations are not supported\0"); + +constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_TR_DIM_X / ITERATIONS_TR / STAGES; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(TunableConfig::CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; + +using IType = bf16; +using IType2 = typename ptx::FPx2; +using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; +using ScalesType2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; +using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>; + +template +struct SCALING_COEFFICIENT_TYPE {}; +template <> +struct SCALING_COEFFICIENT_TYPE { + using type = float; +}; +template <> +struct SCALING_COEFFICIENT_TYPE { + using type = bf16; +}; + +__device__ __forceinline__ float get_amax_of_pair(const IType2 pair) { + return static_cast(__hmax(__habs(pair.x), __habs(pair.y))); +} + +// Compute "correct" per-block encoding scaling factor +template +__device__ __forceinline__ SF_TYPE +compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) { + NVTE_DEVICE_ERROR("Unsupported scaling-factor type. 
Only FP32 and BF16 are supported."); +} + +template <> +__device__ __forceinline__ float compute_nvfp4_scaling_coefficient( + const nvfp4_scale_t S_dec_block, const float S_enc) { + const float S_dec = 1.0f / S_enc; + const float scale_rcp = + fminf(1.0f / (static_cast(S_dec_block) * S_dec), detail::TypeExtrema::max); + return scale_rcp; +} + +template <> +__device__ __forceinline__ bf16 +compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) { + const float scale_rcp = + fminf(S_enc / (static_cast(S_dec_block)), detail::TypeExtrema::max); + return static_cast(scale_rcp); +} + +template +__device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_tr_ptr, + nvfp4_scale_t *__restrict__ sSFcolwise_ptr, + const float S_enc_colwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out_tr, RNG_t &rng, + uint4 &random_uint4, int &rnd_idx) { + using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; + + const auto &sIn2x = *reinterpret_cast(sIn_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + const int warp = threadIdx.x / THREADS_PER_WARP; + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; + const int tid_X_colwise = thread_lane; + + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; + const int thread_offset_X_colwise = tid_X_colwise * 2; + + const int in_thread_offset_Y = thread_offset_Y_colwise; + const int in_thread_offset_X = thread_offset_X_colwise / 2; + + const int out_tr_thread_offset_Y = thread_offset_X_colwise; + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; + + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; + + __align__(8) IType rIn[2][SCALE_DIM]; + // Read (cache) a pair of input elements (S2R). 
Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const IType2 elt_pair = + ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + rIn[0][i] = elt_pair.x; + rIn[1][i] = elt_pair.y; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), + static_cast(__habs(thread_amax_2x.y))}; +#pragma unroll + for (int w = 0; w < 2; ++w) { + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); + + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); + + // Scale elements + __align__(8) uint32_t rOut[SCALE_DIM / 8]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 8; ++e) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + } + uint64_t &out_pack_16x = *reinterpret_cast(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); + } +} + +template +__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, + const float S_enc_rowwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out, RNG_t &rng, uint4 &random_uint4, + int &rnd_idx) { + using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; + + const auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; + + const int SF_thread_offset_rowwise_Y = tid_Y_rowwise; + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; + + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_Y = SF_thread_offset_rowwise_Y + stage_Y * TILE_DIM_Y; + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; +#pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + // Read (cache) input elements (S2R). 
Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + + // Load elements + __uint128_t &elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = get_amax_of_pair(thread_amax_2x); + + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } + +// Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); + + uint32_t out_x8; + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t block_Y, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = block_Y * TunableConfig::CHUNK_DIM_Y; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } + + size_t low = 1; + size_t hi = num_tensors; // [low, hi] + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + return low - 1; +} + +__device__ __forceinline__ size_t get_tensor_rows_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, + const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { + size_t rows_num = 0; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: + rows_num = first_logical_dim / num_tensors; + break; + case ShapeRepresentation::VARYING_LAST_DIM: + rows_num = first_logical_dim; + break; + case ShapeRepresentation::VARYING_FIRST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: + rows_num = static_cast(first_dims_ptr[tensor_id]); + break; + } + if (rows_num % 128 != 0) { + NVTE_DEVICE_ERROR("First dimension of each 
tensor in a group must be divisible by 128."); + } + return rows_num; +} + +__device__ __forceinline__ size_t get_tensor_cols_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim, + const int64_t *const __restrict__ last_dims_ptr) { + size_t cols_num = 0; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: + case ShapeRepresentation::VARYING_FIRST_DIM: + cols_num = last_logical_dim; + break; + case ShapeRepresentation::VARYING_LAST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: + cols_num = static_cast(last_dims_ptr[tensor_id]); + if (cols_num % 128 != 0) { + NVTE_DEVICE_ERROR( + "For non-single tensors, the last dimension of each tensor in a group " + "must be divisible by 128."); + } + break; + } + return cols_num; +} + +__device__ __forceinline__ size_t get_tensor_base_offset( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, + const size_t last_logical_dim, const size_t num_tensors, + const int64_t *const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return tensor_id * rows_per_tensor * last_logical_dim; + } + return static_cast(offsets_ptr[tensor_id]); +} + +struct JobDescriptor { + size_t block_id = 0; + size_t block_global_offset = 0; + size_t tensor_id = 0; + size_t rows = 0; + size_t cols = 0; +}; + +struct BlockDescriptor { + size_t tensor_base = 0; + size_t block_id_Y = 0; + size_t block_id_X = 0; + size_t block_offset_Y = 0; + size_t block_offset_X = 0; +}; + +__device__ __forceinline__ JobDescriptor decode_job( + const ShapeRepresentation shape_rep, const bool use_single_work_grid, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X, + const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr) { + JobDescriptor job{}; + job.block_id = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); + job.block_global_offset = + use_single_work_grid + ? 
(static_cast(ctaid_Y) * TunableConfig::CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * TunableConfig::CHUNK_DIM_X) + : (job.block_id * TunableConfig::CHUNK_DIM_Y * TunableConfig::CHUNK_DIM_X); + job.tensor_id = + get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, static_cast(ctaid_Y), + first_logical_dim, last_logical_dim, offsets_ptr); + job.rows = + get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + return job; +} + +__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, + const ShapeRepresentation shape_rep, + const size_t total_work_blocks, + const int64_t *const __restrict__ offsets_ptr) { + bool is_valid = (job.block_id < total_work_blocks) && (job.rows != 0) && (job.cols != 0); + if (!is_valid || shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + return is_valid; + } + + const size_t tensor_start_offset = static_cast(offsets_ptr[job.tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[job.tensor_id + 1]); + if (job.block_global_offset >= tensor_end_offset) { + return false; + } + + return true; +} + +__device__ __forceinline__ BlockDescriptor decode_block( + const JobDescriptor &job, const ShapeRepresentation shape_rep, const bool use_single_work_grid, + const size_t first_logical_dim, const size_t last_logical_dim, const size_t num_tensors, + const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr) { + BlockDescriptor block{}; + block.tensor_base = get_tensor_base_offset(job.tensor_id, shape_rep, first_logical_dim, + last_logical_dim, num_tensors, offsets_ptr); + + const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(TunableConfig::CHUNK_DIM_X)); + if (use_single_work_grid) { + block.block_id_X = static_cast(ctaid_X); + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t rows_per_tensor = first_logical_dim / num_tensors; + const size_t blocks_Y_per_tensor = + DIVUP(rows_per_tensor, static_cast(TunableConfig::CHUNK_DIM_Y)); + block.block_id_Y = static_cast(ctaid_Y) - job.tensor_id * blocks_Y_per_tensor; + } else { + const size_t tensor_base_row = block.tensor_base / job.cols; + block.block_id_Y = + static_cast(ctaid_Y) - tensor_base_row / static_cast(TunableConfig::CHUNK_DIM_Y); + } + } else { + const size_t block_id_in_current_tensor = + job.block_id - block.tensor_base / (TunableConfig::CHUNK_DIM_Y * TunableConfig::CHUNK_DIM_X); + block.block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + block.block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + } + + block.block_offset_Y = block.block_id_Y * TunableConfig::CHUNK_DIM_Y; + block.block_offset_X = block.block_id_X * TunableConfig::CHUNK_DIM_X; + return block; +} + +__device__ __forceinline__ uintptr_t get_pointer_with_offset_bits(const uintptr_t base_ptr, + const size_t offset_elts, + const size_t data_type_bits) { + const size_t offset_bits = offset_elts * data_type_bits; + if (offset_bits % 8 != 0) { + NVTE_DEVICE_ERROR("Data offset is not byte-aligned."); + } + return base_ptr + offset_bits / 8; +} + +__device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, + CUtensorMap *global_tensor_map, + const uintptr_t global_data_ptr, + const size_t global_dim_Y, + const size_t global_dim_X, + const size_t data_type_bits) { + __shared__ CUtensorMap shared_tensor_map; + 
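// The steps below build the per-tensor TMA descriptor entirely on the device: the kernel-wide
// template descriptor is first copied into shared memory, then (on Blackwell-family GPUs only)
// its global address, global dims, and global stride are patched in place with
// tensormap.replace, and the patched copy is finally written back to the per-tensor descriptor
// slot passed in by the caller.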
shared_tensor_map = base_tensor_map; + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + const size_t global_stride_bits = global_dim_X * data_type_bits; + if (global_stride_bits % 8 != 0) { + NVTE_DEVICE_ERROR("Shape not supported. Data stride must be byte-aligned."); + } + const size_t global_stride_bytes = global_stride_bits / 8; + if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) { + NVTE_DEVICE_ERROR("Shape not supported. Data stride must be 16B aligned."); + } + if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) { + NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned."); + } + asm volatile( + "{\n\t" + ".reg.b64 tensor_map_ptr; \n\t" + "mov.b64 tensor_map_ptr, %0; \n\t" + "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" + "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n\t" + "}\n" + : + : "l"(reinterpret_cast(&shared_tensor_map)), "l"(global_data_ptr), + "r"(static_cast(global_dim_Y)), "r"(static_cast(global_dim_X)), + "l"(static_cast(global_stride_bytes)) + : "memory"); + *global_tensor_map = shared_tensor_map; + } else { + NVTE_DEVICE_ERROR( + "tensormap.replace is architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + +template +__global__ void update_tma_descriptors( + const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_output, + const __grid_constant__ CUtensorMap base_tensor_map_output_t, + const InType *const __restrict__ input_data_ptr, const void *const output_data_ptr, + const void *const output_t_data_ptr, const ShapeRepresentation shape_rep, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise) { + const bool leading_thread = (threadIdx.x == 0); + const size_t tensor_id = blockIdx.x; + if (!leading_thread || tensor_id >= num_tensors) { + return; + } + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + const size_t offset_elts = get_tensor_base_offset(tensor_id, shape_rep, first_logical_dim, + last_logical_dim, num_tensors, offsets_ptr); + + { + const uintptr_t global_data_ptr = get_pointer_with_offset_bits( + reinterpret_cast(input_data_ptr), offset_elts, TypeInfo::size); + modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], global_data_ptr, + rows, cols, TypeInfo::size); + } + + if (rowwise) { + const uintptr_t global_data_ptr = get_pointer_with_offset_bits( + reinterpret_cast(output_data_ptr), offset_elts, 4); + modify_base_tensor_map(base_tensor_map_output, &g_tensor_maps_output[tensor_id], global_data_ptr, + rows, cols, 4); + } + + if (colwise) { + const uintptr_t global_data_ptr = get_pointer_with_offset_bits( + reinterpret_cast(output_t_data_ptr), offset_elts, 4); + modify_base_tensor_map(base_tensor_map_output_t, &g_tensor_maps_output_t[tensor_id], + global_data_ptr, cols, rows, 4); + } +} + +__device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) { +#if (defined __CUDA_ARCH__) && 
(__CUDA_ARCH__ >= 900) + asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map)); +#else + NVTE_DEVICE_ERROR("fence_acquire_tensormap is only supported on SM 9.0+."); +#endif +} + +template +__global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tuned_1D_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_static, + const __grid_constant__ CUtensorMap tensor_map_output_t_static, + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, + const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t work_blocks_X, const size_t work_blocks_Y, + const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + (void)tensor_map_input_static; + (void)tensor_map_output_static; + (void)tensor_map_output_t_static; + + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const size_t launch_block_id = + static_cast(blockIdx.y) * static_cast(gridDim.x) + static_cast(blockIdx.x); + const size_t rng_sequence = threadIdx.x + launch_block_id * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG_t rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = 0; + + const bool leading_thread = (threadIdx.x == 0); + const bool use_single_work_grid = + (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || + shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int in_mem = buff_size_aligned_in; + constexpr int out_mem_rowwise_data = buff_size_aligned_out; + constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? 
buff_size_aligned_out_t : 0; + constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + IType *sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2 *sOut_ptr = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *sOut_tr_ptr = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + + nvfp4_scale_t *sSFrowwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *sSFcolwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const float S_enc_rowwise = + (amax_rowwise_ptr == nullptr) + ? 1.0f + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + const float S_enc_colwise = + (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + const size_t total_work_blocks = work_blocks_X * work_blocks_Y; + int32_t ctaid_X = static_cast(blockIdx.x); + int32_t ctaid_Y = static_cast(blockIdx.y); + size_t static_next_block_id = 0; + size_t static_block_stride = 0; + if constexpr (TunableConfig::PERSISTENT) { + if (launch_block_id >= total_work_blocks) { + return; + } + ctaid_X = static_cast(launch_block_id % work_blocks_X); + ctaid_Y = static_cast(launch_block_id / work_blocks_X); + static_block_stride = static_cast(gridDim.x) * static_cast(gridDim.y); + static_next_block_id = launch_block_id + static_block_stride; + } + + bool job_finished = false; + int buff_in = 0; + int buff_out = 0; + int buff_out_tr = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0}; + bool has_prefetched_current_job = true; + + { + const JobDescriptor first_job = decode_job(shape_rep, use_single_work_grid, num_tensors, + first_logical_dim, last_logical_dim, work_blocks_X, + ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, + last_dims_ptr); + if (!is_job_valid(first_job, shape_rep, total_work_blocks, offsets_ptr)) { + return; + } + const BlockDescriptor first_block = + decode_block(first_job, shape_rep, use_single_work_grid, first_logical_dim, + last_logical_dim, num_tensors, ctaid_X, ctaid_Y, offsets_ptr); + const CUtensorMap &tensor_map_input = g_tensor_maps_input[first_job.tensor_id]; + if (leading_thread) { + fence_acquire_tensormap(&tensor_map_input); + } +#pragma unroll + for (int stage = 0; stage < TunableConfig::PREFETCH_STAGES; ++stage) { + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + const int global_offset_Y = static_cast(first_block.block_offset_Y) + stage_offset_Y; + const int global_offset_X = static_cast(first_block.block_offset_X) + stage_offset_X; + if (leading_thread) { + 
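// Initial prefetch: the leading thread arms this stage's mbarrier with the expected
// transaction size (one shared-memory input buffer) and issues the 2D TMA bulk copy of the
// stage's input tile from global memory into that buffer.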
uint64_t *dst = reinterpret_cast(&sIn[stage]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); + uint64_t *barrier = &IN_buff_readable_mbar[stage]; + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); + } + } + } + + while (!job_finished) { + const JobDescriptor current_job = decode_job(shape_rep, use_single_work_grid, num_tensors, + first_logical_dim, last_logical_dim, work_blocks_X, + ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, + last_dims_ptr); + const bool current_job_is_valid = + is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr); + if (!current_job_is_valid) { + if (has_prefetched_current_job) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); + } + break; + } + + const BlockDescriptor current_block = + decode_block(current_job, shape_rep, use_single_work_grid, first_logical_dim, + last_logical_dim, num_tensors, ctaid_X, ctaid_Y, offsets_ptr); + + const size_t rows = current_job.rows; + const size_t cols = current_job.cols; + const size_t block_offset_Y = current_block.block_offset_Y; + const size_t block_offset_X = current_block.block_offset_X; + const size_t block_offset_Y_tr = block_offset_X; + const size_t block_offset_X_tr = block_offset_Y; + + const size_t chunk_rows = rows - block_offset_Y; + const size_t chunk_cols = cols - block_offset_X; + + const size_t scales_block_offset_Y_rowwise = current_block.block_id_Y * TunableConfig::CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = current_block.block_id_X * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_tr = current_block.block_id_X * TunableConfig::CHUNK_DIM_X; + const size_t scales_block_offset_X_tr = current_block.block_id_Y * SCALES_PER_CHUNK_Y; + + nvfp4_scale_t *const scales_rowwise = scales_ptr + current_block.tensor_base / SCALE_DIM; + nvfp4_scale_t *const scales_colwise = + RETURN_TRANSPOSE ? 
(scales_t_ptr + current_block.tensor_base / SCALE_DIM) : nullptr; + const size_t scale_stride = + DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(SCALE_DIM)), static_cast(4)); + const size_t scale_stride_t = + DIVUP_TO_MULTIPLE(DIVUP(rows, static_cast(SCALE_DIM)), static_cast(4)); + + const CUtensorMap &tensor_map_input = g_tensor_maps_input[current_job.tensor_id]; + const CUtensorMap &tensor_map_output = g_tensor_maps_output[current_job.tensor_id]; + const CUtensorMap &tensor_map_output_t = g_tensor_maps_output_t[current_job.tensor_id]; + + if (leading_thread) { + fence_acquire_tensormap(&tensor_map_input); + fence_acquire_tensormap(&tensor_map_output); + if constexpr (RETURN_TRANSPOSE) { + fence_acquire_tensormap(&tensor_map_output_t); + } + } + + bool prefetched_next_job = false; +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + bool allow_next_job_prefetch = true; + JobDescriptor prefetch_job = current_job; + BlockDescriptor prefetch_block = current_block; + + if (stage == STAGES - TunableConfig::PREFETCH_STAGES) { + if constexpr (TunableConfig::PERSISTENT) { + if (static_next_block_id < total_work_blocks) { + ctaid_X = static_cast(static_next_block_id % work_blocks_X); + ctaid_Y = static_cast(static_next_block_id / work_blocks_X); + static_next_block_id += static_block_stride; + } else { + ctaid_X = 0; + ctaid_Y = static_cast(work_blocks_Y); + allow_next_job_prefetch = false; + } + } else { + ctaid_X = -1; + ctaid_Y = -1; + } + if constexpr (!TunableConfig::PERSISTENT) { + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; + } + } + } + + if ((stage >= STAGES - TunableConfig::PREFETCH_STAGES) && allow_next_job_prefetch && + !job_finished) { + prefetch_job = decode_job(shape_rep, use_single_work_grid, num_tensors, first_logical_dim, + last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, + first_dims_ptr, last_dims_ptr); + allow_next_job_prefetch = + is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); + if (allow_next_job_prefetch) { + prefetch_block = decode_block(prefetch_job, shape_rep, use_single_work_grid, + first_logical_dim, last_logical_dim, num_tensors, ctaid_X, + ctaid_Y, offsets_ptr); + } + } + + if ((stage < STAGES - TunableConfig::PREFETCH_STAGES) || + (allow_next_job_prefetch && !job_finished)) { + const int next_prefetch_buff = (buff_in + TunableConfig::PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + TunableConfig::PREFETCH_STAGES) % STAGES; + const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; + const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X; + const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y; + const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X; + + if (stage >= STAGES - TunableConfig::PREFETCH_STAGES) { + prefetched_next_job = true; + } + + const int global_offset_Y = + static_cast(prefetch_block.block_offset_Y) + next_prefetch_stage_offset_Y; + const int global_offset_X = + static_cast(prefetch_block.block_offset_X) + next_prefetch_stage_offset_X; + + const CUtensorMap &prefetch_tensor_map_input = g_tensor_maps_input[prefetch_job.tensor_id]; + if (leading_thread && stage == STAGES - TunableConfig::PREFETCH_STAGES) { + fence_acquire_tensormap(&prefetch_tensor_map_input); + } + + if (leading_thread) { + uint64_t *dst = 
reinterpret_cast(&sIn[next_prefetch_buff]); + const uint64_t *src = reinterpret_cast(&prefetch_tensor_map_input); + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); + } + ptx::fence_proxy_async_shared_cta(); + } + + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); + + rowwise_scaling( + sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, + rng, random_uint4, rnd_idx); + if constexpr (RETURN_TRANSPOSE) { + colwise_scaling( + sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, stage_Y, stage_X, buff_in, + buff_out_tr, rng, random_uint4, rnd_idx); + } + + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (leading_thread) { + const int global_offset_Y = static_cast(block_offset_Y) + stage_offset_Y; + const int global_offset_X = static_cast(block_offset_X) + stage_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, + global_offset_Y, reinterpret_cast(&sOut[buff_out])); + + if constexpr (RETURN_TRANSPOSE) { + const int global_offset_Y_tr = static_cast(block_offset_Y_tr) + stage_offset_X; + const int global_offset_X_tr = static_cast(block_offset_X_tr) + stage_offset_Y; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, + global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr])); + } + ptx::cp_async_bulk_commit_group(); + } + + buff_in = (buff_in + 1) % BUFFS_NUM_IN; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; + } + has_prefetched_current_job = prefetched_next_job; + + { + using RowwiseScalesVec = Vec; + const int rowwise_count = min(SCALES_PER_CHUNK_X, static_cast(chunk_cols / SCALE_DIM)); + for (size_t row = threadIdx.x; row < TunableConfig::CHUNK_DIM_Y; row += THREADS_NUM) { + const size_t row_global = scales_block_offset_Y_rowwise + row; + if (row_global < rows) { + RowwiseScalesVec &scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t scale_idx_global = row_global * scale_stride + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&scales_rowwise[scale_idx_global], 0, rowwise_count); + } + } + + if constexpr (RETURN_TRANSPOSE) { + using ColwiseScalesVec = Vec; + const int colwise_count = min(SCALES_PER_CHUNK_Y, static_cast(chunk_rows / SCALE_DIM)); + for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X; row_tr += THREADS_NUM) { + const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; + if (row_tr_global < cols) { + ColwiseScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); + const size_t scale_idx_global = row_tr_global * scale_stride_t + scales_block_offset_X_tr; + scales_vec.store_to_elts(&scales_colwise[scale_idx_global], 0, colwise_count); + } + } + } + + if (!job_finished) { + __syncthreads(); + } + } + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + } +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif +} + +#endif // FP4_TYPE_SUPPORTED +} // namespace group_quantize_transpose_tuned_kernel + +inline void group_quantize_transpose_tuned_1D(const GroupedTensor 
*input, const Tensor *noop, + GroupedTensor *output, + const QuantizationConfig *quant_config, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace group_quantize_transpose_tuned_kernel; + using namespace ptx; + + const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + const bool return_transpose = output->has_columnwise_data(); + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of input and output tensors must be same."); + NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "Grouped NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->dtype()), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); + if (return_transpose) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated."); + } + + ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + const bool use_single_work_grid = + (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || + shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + const size_t elts_total = first_logical_dim * last_logical_dim; + const size_t num_tensors = input->num_tensors; + + NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, + "Number of tensors in a group is larger than the MAX number of supported " + "descriptors (64)."); + if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { + NVTE_CHECK(first_logical_dim % 128 == 0, + "First logical dimension of a grouped tensor must be divisible by 128."); + } + NVTE_CHECK(first_logical_dim % 32 == 0, + "Number of tensor rows must be a multiple of 32."); + NVTE_CHECK(last_logical_dim % 32 == 0, + "Number of tensor cols must be a multiple of 32."); + + size_t work_blocks_X = 0; + size_t work_blocks_Y = 0; + if (use_single_work_grid) { + work_blocks_Y = DIVUP(first_logical_dim, static_cast(TunableConfig::CHUNK_DIM_Y)); + work_blocks_X = DIVUP(last_logical_dim, static_cast(TunableConfig::CHUNK_DIM_X)); + } else { + work_blocks_Y = 1; + work_blocks_X = DIVUP( + elts_total, + static_cast(TunableConfig::CHUNK_DIM_Y * TunableConfig::CHUNK_DIM_X)); + } + + size_t launch_blocks_X = work_blocks_X; + size_t launch_blocks_Y = work_blocks_Y; + if constexpr (TunableConfig::PERSISTENT) { + const size_t sm_num = static_cast(transformer_engine::cuda::sm_count()); + const size_t static_grid_size = + sm_num * static_cast(TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM); + NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero."); + launch_blocks_X = 
static_grid_size; + launch_blocks_Y = 1; + } + + const dim3 grid(launch_blocks_X, launch_blocks_Y); + const int block_size = THREADS_NUM; + + const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(output->last_dims.dptr); + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_t_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, sizeof(IType) * 8); + create_2D_tensor_map(tensor_map_output, output->data, first_logical_dim, last_logical_dim, + BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, last_logical_dim, + first_logical_dim, BUFF_DIM_X, BUFF_DIM_Y, first_logical_dim, 0, 4); + } + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + const int in_mem = buff_size_aligned_in; + const int out_data_mem = buff_size_aligned_out; + const int out_data_transpose_mem = return_transpose ? buff_size_aligned_out_t : 0; + const int out_scales_mem = buff_size_scales; + const int out_scales_transpose_mem = return_transpose ? buff_size_scales_transpose : 0; + const int out_mem = out_data_mem + out_data_transpose_mem; + const int dshmem_size = + in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; + + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + const void *const output_dptr = output->data.dptr; + const void *const output_t_dptr = return_transpose ? 
output->columnwise_data.dptr : nullptr;
+
+  update_tma_descriptors<<>>(
+      tensor_map_input, tensor_map_output, tensor_map_output_transpose, input_dptr, output_dptr,
+      output_t_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr,
+      first_dims_ptr, last_dims_ptr, true, return_transpose);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+
+  TRANSFORMER_ENGINE_SWITCH_CONDITION(
+      use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
+      TRANSFORMER_ENGINE_SWITCH_CONDITION(
+          use_fast_math, USE_FAST_MATH,
+          TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
+            auto kernel =
+                group_quantize_transpose_nvfp4_tuned_1D_kernel;
+
+            NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
+            kernel<<>>(
+                tensor_map_input, tensor_map_output, tensor_map_output_transpose, shape_rep,
+                num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr,
+                last_dims_ptr, scales_ptr, scales_t_ptr, noop_ptr, amax_rowwise_ptr,
+                amax_colwise_ptr, work_blocks_X, work_blocks_Y, rng_state);
+            NVTE_CHECK_CUDA(cudaGetLastError());
+          }););like);
+#else
+  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
+#endif
+}
+
+} // namespace nvfp4
+} // namespace dispatch
+} // namespace transformer_engine
+
+#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_

From 178a7c487b866854ed87f235127ece9284eed7be Mon Sep 17 00:00:00 2001
From: Oleg Goncharov
Date: Fri, 6 Mar 2026 14:14:23 +0000
Subject: [PATCH 27/31] Refactoring

Signed-off-by: Oleg Goncharov
---
 ...roup_quantize_transpose_nvfp4_tuned_1D.cuh | 131 +++++++++---------
 1 file changed, 63 insertions(+), 68 deletions(-)

diff --git a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
index 905c834558..923ce85a4a 100644
--- a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
+++ b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
@@ -45,7 +45,13 @@ struct TunableConfig {
   static constexpr int STATIC_PERSISTENT_BLOCKS_PER_SM = 4;
 };
 
-static_assert(!TunableConfig::PERSISTENT || (TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0),
+constexpr int CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y;
+constexpr int CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X;
+constexpr int PREFETCH_STAGES = TunableConfig::PREFETCH_STAGES;
+constexpr bool PERSISTENT = TunableConfig::PERSISTENT;
+constexpr int STATIC_PERSISTENT_BLOCKS_PER_SM = TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM;
+
+static_assert(!PERSISTENT || (STATIC_PERSISTENT_BLOCKS_PER_SM > 0),
               "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero in persistent mode.");
 
 constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64;
@@ -64,9 +70,9 @@ static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0")
 static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) &&
               "Unbalanced threads workload\0");
 
-static_assert((TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0) &&
+static_assert((CHUNK_DIM_Y % TILE_DIM_Y == 0) &&
               "Chunk size Y must be evenly divisible by the tile size Y\0");
-static_assert((TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0) &&
+static_assert((CHUNK_DIM_X % TILE_DIM_X == 0) &&
               "Chunk size X must be evenly divisible by the tile size X\0");
 
 static_assert((TILE_DIM_Y % SCALE_DIM == 0) &&
@@ -74,13 +80,13 @@
static_assert((TILE_DIM_Y % SCALE_DIM == 0) && static_assert((TILE_DIM_X % SCALE_DIM == 0) && "Tile size X must be evenly divisible by the scale dim\0"); -constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y; -constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X; +constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X; constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; -constexpr int SCALES_PER_CHUNK_Y = TunableConfig::CHUNK_DIM_Y / SCALE_DIM; -constexpr int SCALES_PER_CHUNK_X = TunableConfig::CHUNK_DIM_X / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; @@ -89,7 +95,7 @@ constexpr int STAGES_Y = TILES_Y; constexpr int STAGES_X = TILES_X; constexpr int STAGES = STAGES_Y * STAGES_X; -constexpr int BUFFS_NUM = TunableConfig::PREFETCH_STAGES + 1; +constexpr int BUFFS_NUM = PREFETCH_STAGES + 1; constexpr int BUFFS_NUM_IN = BUFFS_NUM; constexpr int BUFFS_NUM_OUT = BUFFS_NUM; constexpr int BUFFS_NUM_OUT_TR = 2; @@ -135,7 +141,7 @@ constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_TR_DIM_X / ITERATIONS_TR / STAGES; static_assert(BUFF_DIM_Y >= SCALE_DIM && "Number of buffer rows must be greater or equal to the size of the columwise " "scaling block\0"); -static_assert(TunableConfig::CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && "Number of buffer rows must be greater or equal to the number of rowwise " "processing threads in Y dimension\0"); @@ -152,8 +158,8 @@ using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; -using ScalesType2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; -using ScalesTypeTr2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; +using ScalesType2D = nvfp4_scale_t[CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>; template @@ -366,7 +372,7 @@ __device__ __forceinline__ size_t get_current_tensor_id( const size_t block_Y, const size_t first_logical_dim, const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr) { if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { - const size_t current_row = block_Y * TunableConfig::CHUNK_DIM_Y; + const size_t current_row = block_Y * CHUNK_DIM_Y; const size_t rows_per_tensor = first_logical_dim / num_tensors; return current_row / rows_per_tensor; } @@ -464,11 +470,10 @@ __device__ __forceinline__ JobDescriptor decode_job( const int64_t *const __restrict__ last_dims_ptr) { JobDescriptor job{}; job.block_id = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - job.block_global_offset = - use_single_work_grid - ? (static_cast(ctaid_Y) * TunableConfig::CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * TunableConfig::CHUNK_DIM_X) - : (job.block_id * TunableConfig::CHUNK_DIM_Y * TunableConfig::CHUNK_DIM_X); + job.block_global_offset = use_single_work_grid + ? 
(static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (job.block_id * CHUNK_DIM_Y * CHUNK_DIM_X); job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, static_cast(ctaid_Y), first_logical_dim, last_logical_dim, offsets_ptr); @@ -504,28 +509,28 @@ __device__ __forceinline__ BlockDescriptor decode_block( block.tensor_base = get_tensor_base_offset(job.tensor_id, shape_rep, first_logical_dim, last_logical_dim, num_tensors, offsets_ptr); - const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(TunableConfig::CHUNK_DIM_X)); + const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(CHUNK_DIM_X)); if (use_single_work_grid) { block.block_id_X = static_cast(ctaid_X); if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { const size_t rows_per_tensor = first_logical_dim / num_tensors; const size_t blocks_Y_per_tensor = - DIVUP(rows_per_tensor, static_cast(TunableConfig::CHUNK_DIM_Y)); + DIVUP(rows_per_tensor, static_cast(CHUNK_DIM_Y)); block.block_id_Y = static_cast(ctaid_Y) - job.tensor_id * blocks_Y_per_tensor; } else { const size_t tensor_base_row = block.tensor_base / job.cols; block.block_id_Y = - static_cast(ctaid_Y) - tensor_base_row / static_cast(TunableConfig::CHUNK_DIM_Y); + static_cast(ctaid_Y) - tensor_base_row / static_cast(CHUNK_DIM_Y); } } else { const size_t block_id_in_current_tensor = - job.block_id - block.tensor_base / (TunableConfig::CHUNK_DIM_Y * TunableConfig::CHUNK_DIM_X); + job.block_id - block.tensor_base / (CHUNK_DIM_Y * CHUNK_DIM_X); block.block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; block.block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; } - block.block_offset_Y = block.block_id_Y * TunableConfig::CHUNK_DIM_Y; - block.block_offset_X = block.block_id_X * TunableConfig::CHUNK_DIM_X; + block.block_offset_Y = block.block_id_Y * CHUNK_DIM_Y; + block.block_offset_X = block.block_id_X * CHUNK_DIM_X; return block; } @@ -683,7 +688,7 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu constexpr int out_mem_rowwise_data = buff_size_aligned_out; constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? 
buff_size_aligned_out_t : 0; constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE( - TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); extern __shared__ unsigned char dynamic_shmem[]; unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); @@ -730,7 +735,7 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu int32_t ctaid_Y = static_cast(blockIdx.y); size_t static_next_block_id = 0; size_t static_block_stride = 0; - if constexpr (TunableConfig::PERSISTENT) { + if constexpr (PERSISTENT) { if (launch_block_id >= total_work_blocks) { return; } @@ -763,7 +768,7 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu fence_acquire_tensormap(&tensor_map_input); } #pragma unroll - for (int stage = 0; stage < TunableConfig::PREFETCH_STAGES; ++stage) { + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { const int stage_Y = stage / STAGES_X; const int stage_X = stage % STAGES_X; const int stage_offset_Y = stage_Y * TILE_DIM_Y; @@ -793,7 +798,7 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; - ptx::cp_async_bulk_wait_group_read(); + ptx::cp_async_bulk_wait_group_read(); } break; } @@ -812,9 +817,9 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu const size_t chunk_rows = rows - block_offset_Y; const size_t chunk_cols = cols - block_offset_X; - const size_t scales_block_offset_Y_rowwise = current_block.block_id_Y * TunableConfig::CHUNK_DIM_Y; + const size_t scales_block_offset_Y_rowwise = current_block.block_id_Y * CHUNK_DIM_Y; const size_t scales_block_offset_X_rowwise = current_block.block_id_X * SCALES_PER_CHUNK_X; - const size_t scales_block_offset_Y_tr = current_block.block_id_X * TunableConfig::CHUNK_DIM_X; + const size_t scales_block_offset_Y_tr = current_block.block_id_X * CHUNK_DIM_X; const size_t scales_block_offset_X_tr = current_block.block_id_Y * SCALES_PER_CHUNK_Y; nvfp4_scale_t *const scales_rowwise = scales_ptr + current_block.tensor_base / SCALE_DIM; @@ -849,8 +854,8 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu JobDescriptor prefetch_job = current_job; BlockDescriptor prefetch_block = current_block; - if (stage == STAGES - TunableConfig::PREFETCH_STAGES) { - if constexpr (TunableConfig::PERSISTENT) { + if (stage == STAGES - PREFETCH_STAGES) { + if constexpr (PERSISTENT) { if (static_next_block_id < total_work_blocks) { ctaid_X = static_cast(static_next_block_id % work_blocks_X); ctaid_Y = static_cast(static_next_block_id / work_blocks_X); @@ -864,14 +869,14 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu ctaid_X = -1; ctaid_Y = -1; } - if constexpr (!TunableConfig::PERSISTENT) { + if constexpr (!PERSISTENT) { if (ctaid_X == -1 && ctaid_Y == -1) { job_finished = true; } } } - if ((stage >= STAGES - TunableConfig::PREFETCH_STAGES) && allow_next_job_prefetch && + if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { prefetch_job = decode_job(shape_rep, use_single_work_grid, num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, @@ -885,16 +890,16 @@ __global__ void __launch_bounds__(THREADS_NUM) 
group_quantize_transpose_nvfp4_tu } } - if ((stage < STAGES - TunableConfig::PREFETCH_STAGES) || + if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { - const int next_prefetch_buff = (buff_in + TunableConfig::PREFETCH_STAGES) % BUFFS_NUM; - const int next_prefetch_stage = (stage + TunableConfig::PREFETCH_STAGES) % STAGES; + const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X; const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y; const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X; - if (stage >= STAGES - TunableConfig::PREFETCH_STAGES) { + if (stage >= STAGES - PREFETCH_STAGES) { prefetched_next_job = true; } @@ -904,7 +909,7 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu static_cast(prefetch_block.block_offset_X) + next_prefetch_stage_offset_X; const CUtensorMap &prefetch_tensor_map_input = g_tensor_maps_input[prefetch_job.tensor_id]; - if (leading_thread && stage == STAGES - TunableConfig::PREFETCH_STAGES) { + if (leading_thread && stage == STAGES - PREFETCH_STAGES) { fence_acquire_tensormap(&prefetch_tensor_map_input); } @@ -922,7 +927,7 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; - ptx::cp_async_bulk_wait_group_read(); + ptx::cp_async_bulk_wait_group_read(); rowwise_scaling( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, @@ -962,7 +967,7 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu { using RowwiseScalesVec = Vec; const int rowwise_count = min(SCALES_PER_CHUNK_X, static_cast(chunk_cols / SCALE_DIM)); - for (size_t row = threadIdx.x; row < TunableConfig::CHUNK_DIM_Y; row += THREADS_NUM) { + for (size_t row = threadIdx.x; row < CHUNK_DIM_Y; row += THREADS_NUM) { const size_t row_global = scales_block_offset_Y_rowwise + row; if (row_global < rows) { RowwiseScalesVec &scales_vec = *reinterpret_cast(sSFrowwise[row]); @@ -974,7 +979,7 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu if constexpr (RETURN_TRANSPOSE) { using ColwiseScalesVec = Vec; const int colwise_count = min(SCALES_PER_CHUNK_Y, static_cast(chunk_rows / SCALE_DIM)); - for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X; row_tr += THREADS_NUM) { + for (size_t row_tr = threadIdx.x; row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) { const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; if (row_tr_global < cols) { ColwiseScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); @@ -1068,21 +1073,19 @@ inline void group_quantize_transpose_tuned_1D(const GroupedTensor *input, const size_t work_blocks_X = 0; size_t work_blocks_Y = 0; if (use_single_work_grid) { - work_blocks_Y = DIVUP(first_logical_dim, static_cast(TunableConfig::CHUNK_DIM_Y)); - work_blocks_X = DIVUP(last_logical_dim, static_cast(TunableConfig::CHUNK_DIM_X)); + work_blocks_Y = DIVUP(first_logical_dim, static_cast(CHUNK_DIM_Y)); + work_blocks_X = DIVUP(last_logical_dim, static_cast(CHUNK_DIM_X)); } else { work_blocks_Y = 1; - work_blocks_X = DIVUP( - elts_total, - 
static_cast(TunableConfig::CHUNK_DIM_Y * TunableConfig::CHUNK_DIM_X)); + work_blocks_X = DIVUP(elts_total, static_cast(CHUNK_DIM_Y * CHUNK_DIM_X)); } size_t launch_blocks_X = work_blocks_X; size_t launch_blocks_Y = work_blocks_Y; - if constexpr (TunableConfig::PERSISTENT) { + if constexpr (PERSISTENT) { const size_t sm_num = static_cast(transformer_engine::cuda::sm_count()); const size_t static_grid_size = - sm_num * static_cast(TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM); + sm_num * static_cast(STATIC_PERSISTENT_BLOCKS_PER_SM); NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero."); launch_blocks_X = static_grid_size; launch_blocks_Y = 1; @@ -1130,16 +1133,11 @@ inline void group_quantize_transpose_tuned_1D(const GroupedTensor *input, const constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; - constexpr int buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_aligned_out = - DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_aligned_out_t = - DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( - TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); - constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE( - TunableConfig::CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales = DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); const int in_mem = buff_size_aligned_in; const int out_data_mem = buff_size_aligned_out; @@ -1147,8 +1145,7 @@ inline void group_quantize_transpose_tuned_1D(const GroupedTensor *input, const const int out_scales_mem = buff_size_scales; const int out_scales_transpose_mem = return_transpose ? 
buff_size_scales_transpose : 0; const int out_mem = out_data_mem + out_data_transpose_mem; - const int dshmem_size = - in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; + const int dshmem_size = in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; const IType *const input_dptr = reinterpret_cast(input->data.dptr); const void *const output_dptr = output->data.dptr; @@ -1160,14 +1157,12 @@ inline void group_quantize_transpose_tuned_1D(const GroupedTensor *input, const first_dims_ptr, last_dims_ptr, true, return_transpose); NVTE_CHECK_CUDA(cudaGetLastError()); - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_fast_math, USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = - group_quantize_transpose_nvfp4_tuned_1D_kernel; + TRANSFORMER_ENGINE_SWITCH_CONDITION(use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION(use_fast_math, USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, + { + auto kernel = group_quantize_transpose_nvfp4_tuned_1D_kernel + ; NVTE_CHECK_CUDA(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); From b035b43d8c19168ce570bada33f4d742ce1b2eba Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 6 Mar 2026 14:33:41 +0000 Subject: [PATCH 28/31] Added the kernel to the quantization dispatcher Signed-off-by: Oleg Goncharov --- .../common/cast/dispatch/quantize.cuh | 22 +++++++++++++++++++ ...roup_quantize_transpose_nvfp4_tuned_1D.cuh | 14 ++++++------ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index f7823b4c58..29886d3e74 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,6 +21,7 @@ #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" +#include "../nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh" #include "../nvfp4/quantize_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" @@ -412,6 +413,16 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor workspace_tensor, stream); break; } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "2D quantization is not supported for group quantize."); + NVTE_CHECK(input_tensor->dtype() == DType::kBFloat16, + "Optimized grouped NVFP4 kernel supports only BF16 input."); + nvfp4::group_quantize_transpose(input_tensor, noop_tensor, output_tensor, + &quant_config_cpp, stream); + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); } @@ -453,6 +464,17 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe stream); break; } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "2D quantization is not supported for group quantize."); + NVTE_CHECK(grad_tensor->dtype() == DType::kBFloat16, + "Optimized grouped NVFP4 kernel supports only BF16 input."); + 
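// The backward path dispatches to the same tuned grouped NVFP4 quantize-transpose kernel as the
// forward path; the incoming gradient is simply treated as a plain BF16 input tensor here.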
nvfp4::group_quantize_transpose(grad_tensor, noop_tensor, output_tensor, + &quant_config_cpp, stream); + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); } diff --git a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh index 923ce85a4a..1ebae56eb1 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh @@ -1009,10 +1009,9 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu #endif // FP4_TYPE_SUPPORTED } // namespace group_quantize_transpose_tuned_kernel -inline void group_quantize_transpose_tuned_1D(const GroupedTensor *input, const Tensor *noop, - GroupedTensor *output, - const QuantizationConfig *quant_config, - cudaStream_t stream) { +inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *noop, + GroupedTensor *output, const QuantizationConfig *quant_config, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED using namespace group_quantize_transpose_tuned_kernel; using namespace ptx; @@ -1027,6 +1026,8 @@ inline void group_quantize_transpose_tuned_1D(const GroupedTensor *input, const NVTE_CHECK(input->num_tensors == output->num_tensors, "Number of input and output tensors must be same."); NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(input->dtype() == DType::kBFloat16, + "Optimized grouped NVFP4 kernel supports only BF16 input."); NVTE_CHECK(output->has_data(), "Grouped NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->dtype()), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); @@ -1049,9 +1050,8 @@ inline void group_quantize_transpose_tuned_1D(const GroupedTensor *input, const shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; } - const bool use_single_work_grid = - (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || - shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + const bool use_single_work_grid = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS + || shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); const size_t first_logical_dim = input->logical_shape.data[0]; const size_t last_logical_dim = input->logical_shape.data[1]; From 9d727574e502e982cea930e0c7a26b8c4e725132 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 6 Mar 2026 14:41:06 +0000 Subject: [PATCH 29/31] Isolated only the Group Quantize NVFP4 for compilation Signed-off-by: Oleg Goncharov --- .../common/cast/dispatch/quantize.cuh | 680 +++++++++--------- ...roup_quantize_transpose_nvfp4_tuned_1D.cuh | 2 +- 2 files changed, 341 insertions(+), 341 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 29886d3e74..a30a29f6be 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -31,282 +31,282 @@ namespace dispatch { template void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *input_tensor = convertNVTETensorCheck(input); - Tensor *output_tensor = convertNVTETensorCheck(output); - - // Quantization config - 
QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const Tensor *dummy_input_tensor = nullptr; - Tensor *dummy_dbias_tensor = nullptr; - Tensor *dummy_workspace_tensor = nullptr; - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_ACT) { - cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - } - } else if (output_tensor->has_data()) { - fp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - const Tensor *dummy_input_tensor = nullptr; - Tensor *dummy_dbias_tensor = nullptr; - Tensor *dummy_workspace_tensor = nullptr; - mxfp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4::quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4::quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? 
output_tensor->amax - : output_tensor->columnwise_amax; - quantize_transpose_vector_blockwise_fp4( - /*input=*/input_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *input_tensor = convertNVTETensorCheck(input); + // Tensor *output_tensor = convertNVTETensorCheck(output); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // // Dispatch to quantization kernel depending on data format + // switch (output_tensor->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + 
// const Tensor *dummy_input_tensor = nullptr; + // Tensor *dummy_dbias_tensor = nullptr; + // Tensor *dummy_workspace_tensor = nullptr; + // if (output_tensor->has_columnwise_data()) { + // NVTE_CHECK(output_tensor->has_data(), + // "Quantizing in only the columnwise direction not supported yet!"); + // if constexpr (!IS_ACT) { + // cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + // } else { + // cast_transpose_fused( + // *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // } + // } else if (output_tensor->has_data()) { + // fp8::quantize( + // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // const Tensor *dummy_input_tensor = nullptr; + // Tensor *dummy_dbias_tensor = nullptr; + // Tensor *dummy_workspace_tensor = nullptr; + // mxfp8::quantize( + // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*input_tensor, "input"); + // CheckOutputTensor(*output_tensor, "output", false); + + // // Choose kernel + // int32_t rows = input_tensor->flat_first_dim(); + // int32_t cols = input_tensor->flat_last_dim(); + // auto dtype = input_tensor->dtype(); + // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + // (cols % 32 == 0) && output_tensor->has_data(); + + // // Launch NVFP4 quantize kernel + // if (use_optimized_kernel) { + // if (quant_config_cpp.nvfp4_2d_quantization) { + // nvfp4::quantize_transpose( + // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } else { + // nvfp4::quantize_transpose( + // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } + // } else { + // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + // : output_tensor->columnwise_amax; + // quantize_transpose_vector_blockwise_fp4( + // /*input=*/input_tensor->data, /*global_amax=*/global_amax, + // /*scale_inv=*/output_tensor->scale_inv, + // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + // /*swizzled_scale=*/false, + // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + // /*rng_state=*/quant_config_cpp.rng_state, + // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + // } + // break; + // } + // case NVTE_BLOCK_SCALING_2D: { + // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. 
+ // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // quantize_transpose_square_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, + // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + // /*noop_tensor=*/noop_tensor->data, stream); + // break; + // } + // case NVTE_BLOCK_SCALING_1D: { + // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // if (output_tensor->has_data()) { + // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // } + // if (output_tensor->has_columnwise_data()) { + // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // } + // quantize_transpose_vector_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + // } } template void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *grad_tensor = convertNVTETensorCheck(grad); - const Tensor *input_tensor = convertNVTETensor(input); - - Tensor *output_tensor = convertNVTETensorCheck(output); - Tensor *dbias_tensor = convertNVTETensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT) { - cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); - } - } else if (output_tensor->has_data()) { - fp8::quantize( - 
*grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*grad_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = grad_tensor->flat_first_dim(); - int32_t cols = grad_tensor->flat_last_dim(); - auto dtype = grad_tensor->dtype(); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4::quantize_transpose( - *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4::quantize_transpose( - *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - quantize_transpose_vector_blockwise_fp4( - /*input=*/grad_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. 
- NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *grad_tensor = convertNVTETensorCheck(grad); + // const Tensor *input_tensor = convertNVTETensor(input); + + // Tensor *output_tensor = convertNVTETensorCheck(output); + // Tensor *dbias_tensor = convertNVTETensor(dbias); + // Tensor *workspace_tensor = convertNVTETensor(workspace); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // // Dispatch to quantization kernel depending on data format + // switch (output_tensor->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // if (output_tensor->has_columnwise_data()) { + // NVTE_CHECK(output_tensor->has_data(), + // "Quantizing in only the columnwise direction not supported yet!"); + // if constexpr (!IS_DBIAS && !IS_DACT) { + // cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + // } else { + // cast_transpose_fused( + // *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + // } + // } else if (output_tensor->has_data()) { + // fp8::quantize( + // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8::quantize( + // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*grad_tensor, "input"); + // CheckOutputTensor(*output_tensor, "output", false); + + // // Choose kernel + // int32_t rows = grad_tensor->flat_first_dim(); + // int32_t cols = grad_tensor->flat_last_dim(); + // auto 
dtype = grad_tensor->dtype(); + // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + // (cols % 32 == 0) && output_tensor->has_data(); + + // // Launch NVFP4 quantize kernel + // if (use_optimized_kernel) { + // if (quant_config_cpp.nvfp4_2d_quantization) { + // nvfp4::quantize_transpose( + // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } else { + // nvfp4::quantize_transpose( + // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } + // } else { + // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + // : output_tensor->columnwise_amax; + // quantize_transpose_vector_blockwise_fp4( + // /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + // /*scale_inv=*/output_tensor->scale_inv, + // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + // /*swizzled_scale=*/false, + // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + // /*rng_state=*/quant_config_cpp.rng_state, + // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + // } + // break; + // } + // case NVTE_BLOCK_SCALING_2D: { + // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // quantize_transpose_square_blockwise( + // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, + // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + // /*noop_tensor=*/noop_tensor->data, stream); + // break; + // } + // case NVTE_BLOCK_SCALING_1D: { + // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // if (output_tensor->has_data()) { + // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // } + // if (output_tensor->has_columnwise_data()) { + // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // } + // quantize_transpose_vector_blockwise( + // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + // } } // Host-aware and not graph-safe: group quantization with split section info from the host. 
@@ -315,64 +315,64 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *input_tensor = convertNVTETensorCheck(input); - std::vector output_tensors; - for (size_t i = 0; i < num_tensors; ++i) { - output_tensors.push_back(convertNVTETensorCheck(outputs[i])); - } - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - // Take the scaling mode of the first output tensor - auto scaling_mode = output_tensors[0]->scaling_mode; - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - // Skip checking output tensor list - // output list here is allowed to have empty tensor - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - - NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "2D quantization is not supported for group quantize."); - - // Launch NVFP4 group quantize kernel - nvfp4::group_quantize_transpose( - *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, - &quant_config_cpp, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *input_tensor = convertNVTETensorCheck(input); + // std::vector output_tensors; + // for (size_t i = 0; i < num_tensors; ++i) { + // output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + // } + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // // Take the scaling mode of the first output tensor + // auto scaling_mode = output_tensors[0]->scaling_mode; + + // // Dispatch to quantization kernel depending on data format + // switch (scaling_mode) { + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*input_tensor, "input"); + // // Skip checking output tensor list + // // output list here is allowed to have empty tensor + + 
// // Choose kernel + // int32_t rows = input_tensor->flat_first_dim(); + // int32_t cols = input_tensor->flat_last_dim(); + // auto dtype = input_tensor->dtype(); + + // NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + // "2D quantization is not supported for group quantize."); + + // // Launch NVFP4 group quantize kernel + // nvfp4::group_quantize_transpose( + // *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + // &quant_config_cpp, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + // } } template @@ -407,12 +407,12 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: { - mxfp8::group_quantize( - input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8::group_quantize( + // input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, + // workspace_tensor, stream); + // break; + // } case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, @@ -458,12 +458,12 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe // Dispatch to quantization kernel depending on data format switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: { - mxfp8::group_quantize( - grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8::group_quantize( + // grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // break; + // } case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK((!IS_DBIAS && !IS_DACT), "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); diff --git a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh index 1ebae56eb1..dd3744b6b4 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh @@ -21,7 +21,7 @@ #include "../../../util/math.h" #include "../../../util/ptx.cuh" #include "../../../utils.cuh" -#include "../core/common.cuh" +#include "../../core/common.cuh" #include "../core_nvfp4.cuh" namespace transformer_engine { From da8da89d57b47d86df073f22a9e5c92bc30b077f Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 6 Mar 2026 16:23:29 +0000 Subject: [PATCH 30/31] Fixed test suite and bug in scaling factors padding Signed-off-by: Oleg Goncharov --- .../test_cast_nvfp4_transpose_grouped.cu | 316 +++++++++++++----- .../common/cast/dispatch/quantize.cuh | 21 +- ...roup_quantize_transpose_nvfp4_tuned_1D.cuh | 96 ++++-- 3 files changed, 308 insertions(+), 125 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu b/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu index def327b8c9..5fb570347b 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu @@ -14,6 +14,7 @@ #include "../test_common.h" #include "transformer_engine/transformer_engine.h" #include +#include 
#include #include @@ -72,14 +73,14 @@ float compute_global_encode_scaling_factor_FP4(const float global_amax, const bo // 1D Scaling: Original implementation with 1x16 blocks template -void quantize_nvfp4_1d(const InputType* const input, - fp4e2m1x2* const output, - fp8e4m3* const scales, - const size_t rows, - const size_t cols, - const size_t scales_stride, - const float global_amax, - const bool use_fast_math) { +void quantize_nvfp4(const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax, + const bool use_fast_math) { // Compute a global encoding/decoding scaling factor for all S_dec_b const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); @@ -151,18 +152,6 @@ void quantize_nvfp4_1d(const InputType* const input, } } -template -void quantize_nvfp4(const InputType* const input, - fp4e2m1x2* const output, - fp8e4m3* const scales, - const size_t rows, - const size_t cols, - const size_t scales_stride, - const float global_amax, - const bool use_fast_math) { - quantize_nvfp4_1d(input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); -} - template void compute_ref(const InputType* input, fp4e2m1x2* output, @@ -178,8 +167,8 @@ void compute_ref(const InputType* input, { std::vector input_t = create_transpose(input, rows, cols); - quantize_nvfp4(input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); - quantize_nvfp4(input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_fast_math); + quantize_nvfp4(input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); + quantize_nvfp4(input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_fast_math); } void compare_nvfp4_tensors(const std::string& name, @@ -376,13 +365,44 @@ void performTest(const ShapeRepresentation shape_rep, const double atol = 1.0E-6; const double rtol = 1.0E-6; - QuantizationConfigWrapper quant_config; - quant_config.set_use_fast_math(use_fast_math); - quant_config.set_stochastic_rounding(false); - quant_config.set_nvfp4_2d_quantization(false); + std::vector rowwise_scales_stride(num_tensors, 0); + std::vector colwise_scales_stride(num_tensors, 0); + std::vector rowwise_unpadded_blocks_X(num_tensors, 0); + std::vector colwise_unpadded_blocks_X(num_tensors, 0); + std::vector rowwise_scale_offsets(num_tensors, 0); + std::vector colwise_scale_offsets(num_tensors, 0); + + size_t rowwise_scales_num = 0; + size_t colwise_scales_num = 0; + + for (size_t t = 0; t < num_tensors; ++t) { + const size_t rows = first_dims[t]; + const size_t cols = last_dims[t]; + rowwise_unpadded_blocks_X[t] = divide_round_up(cols, static_cast(16)); + colwise_unpadded_blocks_X[t] = divide_round_up(rows, static_cast(16)); + + rowwise_scales_stride[t] = + round_up_to_nearest_multiple(rowwise_unpadded_blocks_X[t], static_cast(4)); + colwise_scales_stride[t] = + round_up_to_nearest_multiple(colwise_unpadded_blocks_X[t], static_cast(4)); + + rowwise_scale_offsets[t] = rowwise_scales_num; + colwise_scale_offsets[t] = colwise_scales_num; + + rowwise_scales_num += rows * rowwise_scales_stride[t]; + colwise_scales_num += cols * colwise_scales_stride[t]; + } + + std::vector out_data_rowwise_h(total_elts / 2); + std::vector out_data_colwise_h(total_elts / 2); + std::vector out_scales_rowwise_h(rowwise_scales_num); + std::vector out_scales_colwise_h(colwise_scales_num); + + std::vector 
out_data_rowwise_ref(total_elts / 2); + std::vector out_data_colwise_ref(total_elts / 2); + std::vector> out_scales_rowwise_ref(num_tensors); + std::vector> out_scales_colwise_ref(num_tensors); - // Grouped NVFP4 kernel is not available yet. - // Validate grouped metadata/configuration by quantizing each tensor independently. for (size_t t = 0; t < num_tensors; ++t) { const size_t rows = first_dims[t]; const size_t cols = last_dims[t]; @@ -390,76 +410,208 @@ void performTest(const ShapeRepresentation shape_rep, const size_t tensor_numel = rows * cols; ASSERT_EQ(offsets[t + 1] - offsets[t], tensor_numel); ASSERT_LE(tensor_offset + tensor_numel, total_elts); + ASSERT_EQ(tensor_numel % 2, 0U); - const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); - const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); - - const size_t unpadded_blocks_Y = scale_dims[0]; - const size_t unpadded_blocks_X = scale_dims[1]; - const size_t blocks_Y = scale_dims[2]; - const size_t blocks_X = scale_dims[3]; - const size_t scales_stride = blocks_X; - - const size_t unpadded_blocks_Y_t = scale_dims_t[0]; - const size_t unpadded_blocks_X_t = scale_dims_t[1]; - const size_t blocks_Y_t = scale_dims_t[2]; - const size_t blocks_X_t = scale_dims_t[3]; - const size_t scales_stride_t = blocks_X_t; - - std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); - std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); - std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); - std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); - - float amax = 0.0f; - for (size_t idx = 0; idx < tensor_numel; ++idx) { - amax = fmaxf(amax, fabsf(static_cast(grouped_input[tensor_offset + idx]))); - } - - Tensor input("input_tensor_" + std::to_string(t), std::vector{rows, cols}, itype); - std::copy(grouped_input.begin() + tensor_offset, - grouped_input.begin() + tensor_offset + tensor_numel, - input.rowwise_cpu_dptr()); - input.from_cpu(); - - Tensor output("output_tensor_" + std::to_string(t), std::vector{rows, cols}, otype, - true, true, NVTE_NVFP4_1D_SCALING); - output.set_scale(amax); + std::unique_ptr ref_output = + std::make_unique(tensor_numel / 2); + std::unique_ptr ref_output_t = + std::make_unique(tensor_numel / 2); + std::unique_ptr ref_scales = + std::make_unique(rows * rowwise_scales_stride[t]); + std::unique_ptr ref_scales_t = + std::make_unique(cols * colwise_scales_stride[t]); compute_ref(grouped_input.data() + tensor_offset, ref_output.get(), ref_output_t.get(), ref_scales.get(), ref_scales_t.get(), - output.scale(), + 0.0f, rows, cols, - scales_stride, - scales_stride_t, + rowwise_scales_stride[t], + colwise_scales_stride[t], use_fast_math); - nvte_quantize_v2(input.data(), output.data(), quant_config, 0); + std::memcpy(out_data_rowwise_ref.data() + tensor_offset / 2, ref_output.get(), + (tensor_numel / 2) * sizeof(fp4e2m1x2)); + std::memcpy(out_data_colwise_ref.data() + tensor_offset / 2, ref_output_t.get(), + (tensor_numel / 2) * sizeof(fp4e2m1x2)); + + out_scales_rowwise_ref[t].assign(ref_scales.get(), + ref_scales.get() + rows * rowwise_scales_stride[t]); + out_scales_colwise_ref[t].assign(ref_scales_t.get(), + ref_scales_t.get() + cols * colwise_scales_stride[t]); + } - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + const size_t in_data_size = total_elts * sizeof(InputType); + const size_t out_data_size = (total_elts * typeToNumBits(otype)) / 8; + const size_t 
rowwise_scales_size = rowwise_scales_num * sizeof(fp8e4m3); + const size_t colwise_scales_size = colwise_scales_num * sizeof(fp8e4m3); - compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), - static_cast(rows), static_cast(cols), atol, rtol, true, false); + std::vector first_dims_h(num_tensors, 0); + std::vector last_dims_h(num_tensors, 0); + std::vector offsets_h(num_tensors + 1, 0); + for (size_t t = 0; t < num_tensors; ++t) { + first_dims_h[t] = static_cast(first_dims[t]); + last_dims_h[t] = static_cast(last_dims[t]); + } + for (size_t t = 0; t < num_tensors + 1; ++t) { + offsets_h[t] = static_cast(offsets[t]); + } + + InputType* in_data_d = nullptr; + fp4e2m1* out_data_rowwise_d = nullptr; + fp4e2m1* out_data_colwise_d = nullptr; + fp8e4m3* out_scales_rowwise_d = nullptr; + fp8e4m3* out_scales_colwise_d = nullptr; + int64_t* first_dims_d = nullptr; + int64_t* last_dims_d = nullptr; + int64_t* offsets_d = nullptr; + + cudaMalloc((void**)&in_data_d, in_data_size); + cudaMalloc((void**)&out_data_rowwise_d, out_data_size); + cudaMalloc((void**)&out_data_colwise_d, out_data_size); + cudaMalloc((void**)&out_scales_rowwise_d, rowwise_scales_size); + cudaMalloc((void**)&out_scales_colwise_d, colwise_scales_size); + + cudaMalloc((void**)&first_dims_d, num_tensors * sizeof(int64_t)); + cudaMalloc((void**)&last_dims_d, num_tensors * sizeof(int64_t)); + cudaMalloc((void**)&offsets_d, (num_tensors + 1) * sizeof(int64_t)); + + cudaMemcpy(in_data_d, grouped_input.data(), in_data_size, cudaMemcpyHostToDevice); + cudaMemcpy(first_dims_d, first_dims_h.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice); + cudaMemcpy(last_dims_d, last_dims_h.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice); + cudaMemcpy(offsets_d, offsets_h.data(), (num_tensors + 1) * sizeof(int64_t), cudaMemcpyHostToDevice); + + cudaMemset(out_data_rowwise_d, 0, out_data_size); + cudaMemset(out_data_colwise_d, 0, out_data_size); + cudaMemset(out_scales_rowwise_d, 0, rowwise_scales_size); + cudaMemset(out_scales_colwise_d, 0, colwise_scales_size); + + NVTEShape logical_shape_ = nvte_make_shape(logical_shape.data(), logical_shape.size()); + + NVTEShape first_dims_shape_; + NVTEShape last_dims_shape_; + NVTEShape offsets_shape_; + first_dims_shape_.ndim = 1; + last_dims_shape_.ndim = 1; + offsets_shape_.ndim = 1; + first_dims_shape_.data[0] = num_tensors; + last_dims_shape_.data[0] = num_tensors; + offsets_shape_.data[0] = num_tensors + 1; + + NVTEGroupedTensor in_group_tensor = + nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor out_group_tensor = + nvte_create_grouped_tensor(NVTE_NVFP4_1D_SCALING, num_tensors, logical_shape_); + + NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &in_data_tensor, sizeof(in_data_tensor)); + + NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, NVTEDType::kNVTEFloat4E2M1, logical_shape_}; + NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, NVTEDType::kNVTEFloat4E2M1, logical_shape_}; + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &out_data_rowwise_tensor, sizeof(out_data_rowwise_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, + &out_data_colwise_tensor, sizeof(out_data_colwise_tensor)); + + std::vector rowwise_scales_shape = 
{rowwise_scales_num}; + std::vector colwise_scales_shape = {colwise_scales_num}; + NVTEShape rowwise_scales_shape_ = + nvte_make_shape(rowwise_scales_shape.data(), rowwise_scales_shape.size()); + NVTEShape colwise_scales_shape_ = + nvte_make_shape(colwise_scales_shape.data(), colwise_scales_shape.size()); + NVTEBasicTensor out_scales_rowwise_tensor = { + out_scales_rowwise_d, NVTEDType::kNVTEFloat8E4M3, rowwise_scales_shape_}; + NVTEBasicTensor out_scales_colwise_tensor = { + out_scales_colwise_d, NVTEDType::kNVTEFloat8E4M3, colwise_scales_shape_}; + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, + &out_scales_rowwise_tensor, sizeof(out_scales_rowwise_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, + &out_scales_colwise_tensor, sizeof(out_scales_colwise_tensor)); + + if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + } + + if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_}; + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, + &last_dims_tensor, sizeof(last_dims_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, + &last_dims_tensor, sizeof(last_dims_tensor)); + } + + if (shape_rep != SAME_BOTH_DIMS) { + NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; + nvte_set_grouped_tensor_param(in_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + } + + nvte_group_quantize(in_group_tensor, out_group_tensor, 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, + cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, colwise_scales_size, + cudaMemcpyDeviceToHost); + + for (size_t t = 0; t < num_tensors; ++t) { + const size_t rows = first_dims[t]; + const size_t cols = last_dims[t]; + const size_t tensor_offset = offsets[t]; + + const fp4e2m1* test_output = out_data_rowwise_h.data() + tensor_offset / 2; + const fp4e2m1* ref_output = out_data_rowwise_ref.data() + tensor_offset / 2; + const fp4e2m1* test_output_t = out_data_colwise_h.data() + tensor_offset / 2; + const fp4e2m1* ref_output_t = out_data_colwise_ref.data() + tensor_offset / 2; + + compare_nvfp4_tensors("output_" + std::to_string(t), test_output, ref_output, + static_cast(rows), static_cast(cols), atol, rtol); + compare_nvfp4_tensors("output_t_" + std::to_string(t), test_output_t, ref_output_t, + 
static_cast(cols), static_cast(rows), atol, rtol); size_t scale_mismatches_num = 0; - compare_scaling_factors("scales_" + std::to_string(t), - output.rowwise_cpu_scale_inv_ptr(), - ref_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - scale_mismatches_num); - - compare_scaling_factors("scales_t_" + std::to_string(t), - output.columnwise_cpu_scale_inv_ptr(), - ref_scales_t.get(), - unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, - scale_mismatches_num); + compare_scaling_factors( + "scales_" + std::to_string(t), + out_scales_rowwise_h.data() + rowwise_scale_offsets[t], + out_scales_rowwise_ref[t].data(), + rows, rowwise_unpadded_blocks_X[t], rowwise_scales_stride[t], scale_mismatches_num); + + compare_scaling_factors( + "scales_t_" + std::to_string(t), + out_scales_colwise_h.data() + colwise_scale_offsets[t], + out_scales_colwise_ref[t].data(), + cols, colwise_unpadded_blocks_X[t], colwise_scales_stride[t], scale_mismatches_num); } + + nvte_destroy_grouped_tensor(in_group_tensor); + nvte_destroy_grouped_tensor(out_group_tensor); + + cudaFree(in_data_d); + cudaFree(out_data_rowwise_d); + cudaFree(out_data_colwise_d); + cudaFree(out_scales_rowwise_d); + cudaFree(out_scales_colwise_d); + cudaFree(first_dims_d); + cudaFree(last_dims_d); + cudaFree(offsets_d); } // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index a30a29f6be..0d5a506209 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -415,10 +415,13 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // } case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "2D quantization is not supported for group quantize."); - NVTE_CHECK(input_tensor->dtype() == DType::kBFloat16, - "Optimized grouped NVFP4 kernel supports only BF16 input."); + + const bool is_bf16_input_type = input_tensor->dtype() == DType::kBFloat16; + NVTE_CHECK(is_bf16_input_type, "Optimized grouped NVFP4 kernel supports only BF16 input."); + + const bool is_2D_quantization = quant_config_cpp.nvfp4_2d_quantization; + NVTE_CHECK(!is_2D_quantization, "2D quantization is not supported for group quantize."); + nvfp4::group_quantize_transpose(input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); break; @@ -467,10 +470,12 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK((!IS_DBIAS && !IS_DACT), "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); - NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "2D quantization is not supported for group quantize."); - NVTE_CHECK(grad_tensor->dtype() == DType::kBFloat16, - "Optimized grouped NVFP4 kernel supports only BF16 input."); + const bool is_bf16_input_type = grad_tensor->dtype() == DType::kBFloat16; + NVTE_CHECK(is_bf16_input_type, "Optimized grouped NVFP4 kernel supports only BF16 input."); + + const bool is_2D_quantization = quant_config_cpp.nvfp4_2d_quantization; + NVTE_CHECK(!is_2D_quantization, "2D quantization is not supported for group quantize."); + nvfp4::group_quantize_transpose(grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); break; diff --git 
a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
index dd3744b6b4..153335e6d4 100644
--- a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
+++ b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
@@ -38,18 +38,20 @@ using namespace dispatch::common;
 #if FP4_TYPE_SUPPORTED

 struct TunableConfig {
-  static constexpr int CHUNK_DIM_Y = 128;
-  static constexpr int CHUNK_DIM_X = 128;
+  static constexpr size_t CHUNK_DIM_Y = 128;
+  static constexpr size_t CHUNK_DIM_X = 128;
   static constexpr int PREFETCH_STAGES = 1;
   static constexpr bool PERSISTENT = true;
-  static constexpr int STATIC_PERSISTENT_BLOCKS_PER_SM = 4;
+  static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4;
 };

-constexpr int CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y;
-constexpr int CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X;
+constexpr size_t CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y;
+constexpr size_t CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X;
 constexpr int PREFETCH_STAGES = TunableConfig::PREFETCH_STAGES;
 constexpr bool PERSISTENT = TunableConfig::PERSISTENT;
-constexpr int STATIC_PERSISTENT_BLOCKS_PER_SM = TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM;
+constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM;
+
+constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X;

 static_assert(!PERSISTENT || (STATIC_PERSISTENT_BLOCKS_PER_SM > 0),
               "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero in persistent mode.");
@@ -446,6 +448,30 @@ __device__ __forceinline__ size_t get_tensor_base_offset(
   return static_cast(offsets_ptr[tensor_id]);
 }

+__device__ __forceinline__ size_t get_nvfp4_scale_stride(const size_t block_scaled_dim) {
+  return DIVUP_TO_MULTIPLE(DIVUP(block_scaled_dim, static_cast(SCALE_DIM)),
+                           static_cast(4));
+}
+
+__device__ __forceinline__ size_t get_grouped_scale_base_offset(
+    const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim,
+    const size_t last_logical_dim, const size_t num_tensors,
+    const int64_t *const __restrict__ first_dims_ptr,
+    const int64_t *const __restrict__ last_dims_ptr, const bool rowwise) {
+  size_t scale_base = 0;
+  for (size_t t = 0; t < tensor_id; ++t) {
+    const size_t rows =
+        get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
+    const size_t cols = get_tensor_cols_num(t, shape_rep, last_logical_dim, last_dims_ptr);
+
+    const size_t scale_rows = rowwise ? rows : cols;
+    const size_t stride_dim = rowwise ? cols : rows;
+    const size_t scale_stride = get_nvfp4_scale_stride(stride_dim);
+    scale_base += scale_rows * scale_stride;
+  }
+  return scale_base;
+}
+
 struct JobDescriptor {
   size_t block_id = 0;
   size_t block_global_offset = 0;
@@ -471,9 +497,8 @@ __device__ __forceinline__ JobDescriptor decode_job(
   JobDescriptor job{};
   job.block_id = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X);
   job.block_global_offset = use_single_work_grid
-                                ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim +
-                                   static_cast(ctaid_X) * CHUNK_DIM_X)
-                                : (job.block_id * CHUNK_DIM_Y * CHUNK_DIM_X);
+                                ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X)
+                                : (job.block_id * ELTS_PER_CHUNK);
   job.tensor_id =
       get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, static_cast(ctaid_Y),
                             first_logical_dim, last_logical_dim, offsets_ptr);
@@ -492,7 +517,6 @@ __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job,
     return is_valid;
   }

-  const size_t tensor_start_offset = static_cast(offsets_ptr[job.tensor_id]);
   const size_t tensor_end_offset = static_cast(offsets_ptr[job.tensor_id + 1]);
   if (job.block_global_offset >= tensor_end_offset) {
     return false;
@@ -519,12 +543,10 @@ __device__ __forceinline__ BlockDescriptor decode_block(
       block.block_id_Y = static_cast(ctaid_Y) - job.tensor_id * blocks_Y_per_tensor;
     } else {
       const size_t tensor_base_row = block.tensor_base / job.cols;
-      block.block_id_Y =
-          static_cast(ctaid_Y) - tensor_base_row / static_cast(CHUNK_DIM_Y);
+      block.block_id_Y = static_cast(ctaid_Y) - tensor_base_row / CHUNK_DIM_Y;
     }
   } else {
-    const size_t block_id_in_current_tensor =
-        job.block_id - block.tensor_base / (CHUNK_DIM_Y * CHUNK_DIM_X);
+    const size_t block_id_in_current_tensor = job.block_id - block.tensor_base / ELTS_PER_CHUNK;
     block.block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor;
     block.block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor;
   }
@@ -822,13 +844,20 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
     const size_t scales_block_offset_Y_tr = current_block.block_id_X * CHUNK_DIM_X;
     const size_t scales_block_offset_X_tr = current_block.block_id_Y * SCALES_PER_CHUNK_Y;

-    nvfp4_scale_t *const scales_rowwise = scales_ptr + current_block.tensor_base / SCALE_DIM;
+    const size_t scale_stride = get_nvfp4_scale_stride(cols);
+    const size_t scale_stride_t = get_nvfp4_scale_stride(rows);
+
+    const size_t rowwise_scale_base =
+        get_grouped_scale_base_offset(current_job.tensor_id, shape_rep, first_logical_dim,
+                                      last_logical_dim, num_tensors, first_dims_ptr, last_dims_ptr,
+                                      true);
+    const size_t colwise_scale_base =
+        get_grouped_scale_base_offset(current_job.tensor_id, shape_rep, first_logical_dim,
+                                      last_logical_dim, num_tensors, first_dims_ptr, last_dims_ptr,
+                                      false);
+    nvfp4_scale_t *const scales_rowwise = scales_ptr + rowwise_scale_base;
     nvfp4_scale_t *const scales_colwise =
-        RETURN_TRANSPOSE ? (scales_t_ptr + current_block.tensor_base / SCALE_DIM) : nullptr;
-    const size_t scale_stride =
-        DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(SCALE_DIM)), static_cast(4));
-    const size_t scale_stride_t =
-        DIVUP_TO_MULTIPLE(DIVUP(rows, static_cast(SCALE_DIM)), static_cast(4));
+        RETURN_TRANSPOSE ? (scales_t_ptr + colwise_scale_base) : nullptr;

     const CUtensorMap &tensor_map_input = g_tensor_maps_input[current_job.tensor_id];
     const CUtensorMap &tensor_map_output = g_tensor_maps_output[current_job.tensor_id];
@@ -1065,27 +1094,22 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
     NVTE_CHECK(first_logical_dim % 128 == 0,
                "First logical dimension of a grouped tensor must be divisible by 128.");
   }
-  NVTE_CHECK(first_logical_dim % 32 == 0,
-             "Number of tensor rows must be a multiple of 32.");
-  NVTE_CHECK(last_logical_dim % 32 == 0,
-             "Number of tensor cols must be a multiple of 32.");

   size_t work_blocks_X = 0;
   size_t work_blocks_Y = 0;
   if (use_single_work_grid) {
-    work_blocks_Y = DIVUP(first_logical_dim, static_cast(CHUNK_DIM_Y));
-    work_blocks_X = DIVUP(last_logical_dim, static_cast(CHUNK_DIM_X));
+    work_blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y);
+    work_blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X);
   } else {
     work_blocks_Y = 1;
-    work_blocks_X = DIVUP(elts_total, static_cast(CHUNK_DIM_Y * CHUNK_DIM_X));
+    work_blocks_X = DIVUP(elts_total, ELTS_PER_CHUNK);
   }

   size_t launch_blocks_X = work_blocks_X;
   size_t launch_blocks_Y = work_blocks_Y;
   if constexpr (PERSISTENT) {
     const size_t sm_num = static_cast(transformer_engine::cuda::sm_count());
-    const size_t static_grid_size =
-        sm_num * static_cast(STATIC_PERSISTENT_BLOCKS_PER_SM);
+    const size_t static_grid_size = sm_num * STATIC_PERSISTENT_BLOCKS_PER_SM;
     NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero.");
     launch_blocks_X = static_grid_size;
     launch_blocks_Y = 1;
@@ -1122,13 +1146,15 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
   alignas(64) CUtensorMap tensor_map_output{};
   alignas(64) CUtensorMap tensor_map_output_transpose{};

-  create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim,
-                       BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, sizeof(IType) * 8);
-  create_2D_tensor_map(tensor_map_output, output->data, first_logical_dim, last_logical_dim,
-                       BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, 4);
+  const size_t dummy_first_logical_dim = 32;
+  const size_t dummy_last_logical_dim = 32;
+  create_2D_tensor_map(tensor_map_input, input->data, dummy_first_logical_dim, dummy_last_logical_dim,
+                       BUFF_DIM_Y, BUFF_DIM_X, dummy_last_logical_dim, 0, sizeof(IType) * 8);
+  create_2D_tensor_map(tensor_map_output, output->data, dummy_first_logical_dim, dummy_last_logical_dim,
+                       BUFF_DIM_Y, BUFF_DIM_X, dummy_last_logical_dim, 0, 4);
   if (return_transpose) {
-    create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, last_logical_dim,
-                         first_logical_dim, BUFF_DIM_X, BUFF_DIM_Y, first_logical_dim, 0, 4);
+    create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, dummy_last_logical_dim,
+                         dummy_first_logical_dim, BUFF_DIM_X, BUFF_DIM_Y, dummy_first_logical_dim, 0, 4);
   }

   constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X;

From 46fdb930174d75b762029acf0b72c30ee4fe6fdf Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 6 Mar 2026 16:27:52 +0000
Subject: [PATCH 31/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../common/cast/dispatch/quantize.cuh         |  12 +-
 ...roup_quantize_transpose_nvfp4_tuned_1D.cuh | 135 +++++++++---------
 2 files changed, 76 insertions(+), 71 deletions(-)

diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh
index 0d5a506209..f60ba007cf 100644
--- a/transformer_engine/common/cast/dispatch/quantize.cuh
+++ b/transformer_engine/common/cast/dispatch/quantize.cuh
@@ -21,9 +21,9 @@
 #include "../mxfp8/group_quantize_mxfp8.cuh"
 #include "../mxfp8/quantize_mxfp8.cuh"
 #include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
-#include "../nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh"
 #include "../nvfp4/quantize_nvfp4.cuh"
 #include "../nvfp4/quantize_transpose_nvfp4.cuh"
+#include "../nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh"

 namespace transformer_engine {
 namespace dispatch {
@@ -422,8 +422,8 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor
       const bool is_2D_quantization = quant_config_cpp.nvfp4_2d_quantization;
       NVTE_CHECK(!is_2D_quantization, "2D quantization is not supported for group quantize.");

-      nvfp4::group_quantize_transpose(input_tensor, noop_tensor, output_tensor,
-                                      &quant_config_cpp, stream);
+      nvfp4::group_quantize_transpose(input_tensor, noop_tensor, output_tensor, &quant_config_cpp,
+                                      stream);
       break;
     }
     default:
@@ -475,9 +475,9 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe
       const bool is_2D_quantization = quant_config_cpp.nvfp4_2d_quantization;
       NVTE_CHECK(!is_2D_quantization, "2D quantization is not supported for group quantize.");
-
-      nvfp4::group_quantize_transpose(grad_tensor, noop_tensor, output_tensor,
-                                      &quant_config_cpp, stream);
+
+      nvfp4::group_quantize_transpose(grad_tensor, noop_tensor, output_tensor, &quant_config_cpp,
+                                      stream);
       break;
     }
     default:
diff --git a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
index 153335e6d4..25df7b4690 100644
--- a/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
+++ b/transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh
@@ -437,10 +437,10 @@ __device__ __forceinline__ size_t get_tensor_cols_num(
   return cols_num;
 }

-__device__ __forceinline__ size_t get_tensor_base_offset(
-    const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim,
-    const size_t last_logical_dim, const size_t num_tensors,
-    const int64_t *const __restrict__ offsets_ptr) {
+__device__ __forceinline__ size_t
+get_tensor_base_offset(const size_t tensor_id, const ShapeRepresentation shape_rep,
+                       const size_t first_logical_dim, const size_t last_logical_dim,
+                       const size_t num_tensors, const int64_t *const __restrict__ offsets_ptr) {
   if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
     const size_t rows_per_tensor = first_logical_dim / num_tensors;
     return tensor_id * rows_per_tensor * last_logical_dim;
@@ -497,11 +497,11 @@ __device__ __forceinline__ JobDescriptor decode_job(
   JobDescriptor job{};
   job.block_id = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X);
   job.block_global_offset = use_single_work_grid
-                                ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X)
-                                : (job.block_id * ELTS_PER_CHUNK);
-  job.tensor_id =
-      get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, static_cast(ctaid_Y),
-                            first_logical_dim, last_logical_dim, offsets_ptr);
+                                ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X)
+                                : (job.block_id * ELTS_PER_CHUNK);
+  job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset,
+                                        static_cast(ctaid_Y), first_logical_dim,
+                                        last_logical_dim, offsets_ptr);
   job.rows =
       get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
   job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
@@ -538,8 +538,7 @@ __device__ __forceinline__ BlockDescriptor decode_block(
     block.block_id_X = static_cast(ctaid_X);
     if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
       const size_t rows_per_tensor = first_logical_dim / num_tensors;
-      const size_t blocks_Y_per_tensor =
-          DIVUP(rows_per_tensor, static_cast(CHUNK_DIM_Y));
+      const size_t blocks_Y_per_tensor = DIVUP(rows_per_tensor, static_cast(CHUNK_DIM_Y));
       block.block_id_Y = static_cast(ctaid_Y) - job.tensor_id * blocks_Y_per_tensor;
     } else {
       const size_t tensor_base_row = block.tensor_base / job.cols;
@@ -639,10 +638,10 @@ __global__ void update_tma_descriptors(
   }

   if (rowwise) {
-    const uintptr_t global_data_ptr = get_pointer_with_offset_bits(
-        reinterpret_cast(output_data_ptr), offset_elts, 4);
-    modify_base_tensor_map(base_tensor_map_output, &g_tensor_maps_output[tensor_id], global_data_ptr,
-                           rows, cols, 4);
+    const uintptr_t global_data_ptr =
+        get_pointer_with_offset_bits(reinterpret_cast(output_data_ptr), offset_elts, 4);
+    modify_base_tensor_map(base_tensor_map_output, &g_tensor_maps_output[tensor_id],
+                           global_data_ptr, rows, cols, 4);
   }

   if (colwise) {
@@ -682,8 +681,8 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
     return;
   }

-  const size_t launch_block_id =
-      static_cast(blockIdx.y) * static_cast(gridDim.x) + static_cast(blockIdx.x);
+  const size_t launch_block_id = static_cast(blockIdx.y) * static_cast(gridDim.x) +
+                                 static_cast(blockIdx.x);
   const size_t rng_sequence = threadIdx.x + launch_block_id * THREADS_NUM;
   const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
   const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
@@ -693,9 +692,8 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
   int rnd_idx = 0;

   const bool leading_thread = (threadIdx.x == 0);
-  const bool use_single_work_grid =
-      (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
-       shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);
+  const bool use_single_work_grid = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
+                                     shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);

   constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
   constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems;
@@ -775,10 +773,9 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
   bool has_prefetched_current_job = true;

   {
-    const JobDescriptor first_job = decode_job(shape_rep, use_single_work_grid, num_tensors,
-                                               first_logical_dim, last_logical_dim, work_blocks_X,
-                                               ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr,
-                                               last_dims_ptr);
+    const JobDescriptor first_job = decode_job(
+        shape_rep, use_single_work_grid, num_tensors, first_logical_dim, last_logical_dim,
+        work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr);
     if (!is_job_valid(first_job, shape_rep, total_work_blocks, offsets_ptr)) {
       return;
     }
@@ -809,10 +806,9 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
   }

   while (!job_finished) {
-    const JobDescriptor current_job = decode_job(shape_rep, use_single_work_grid, num_tensors,
-                                                 first_logical_dim, last_logical_dim, work_blocks_X,
-                                                 ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr,
-                                                 last_dims_ptr);
+    const JobDescriptor current_job = decode_job(
+        shape_rep, use_single_work_grid, num_tensors, first_logical_dim, last_logical_dim,
+        work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr);
     const bool current_job_is_valid =
         is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr);
     if (!current_job_is_valid) {
@@ -847,14 +843,12 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
     const size_t scale_stride = get_nvfp4_scale_stride(cols);
     const size_t scale_stride_t = get_nvfp4_scale_stride(rows);

-    const size_t rowwise_scale_base =
-        get_grouped_scale_base_offset(current_job.tensor_id, shape_rep, first_logical_dim,
-                                      last_logical_dim, num_tensors, first_dims_ptr, last_dims_ptr,
-                                      true);
-    const size_t colwise_scale_base =
-        get_grouped_scale_base_offset(current_job.tensor_id, shape_rep, first_logical_dim,
-                                      last_logical_dim, num_tensors, first_dims_ptr, last_dims_ptr,
-                                      false);
+    const size_t rowwise_scale_base = get_grouped_scale_base_offset(
+        current_job.tensor_id, shape_rep, first_logical_dim, last_logical_dim, num_tensors,
+        first_dims_ptr, last_dims_ptr, true);
+    const size_t colwise_scale_base = get_grouped_scale_base_offset(
+        current_job.tensor_id, shape_rep, first_logical_dim, last_logical_dim, num_tensors,
+        first_dims_ptr, last_dims_ptr, false);
     nvfp4_scale_t *const scales_rowwise = scales_ptr + rowwise_scale_base;
     nvfp4_scale_t *const scales_colwise =
         RETURN_TRANSPOSE ? (scales_t_ptr + colwise_scale_base) : nullptr;
@@ -905,22 +899,20 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
         }
       }

-      if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch &&
-          !job_finished) {
+      if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) {
        prefetch_job = decode_job(shape_rep, use_single_work_grid, num_tensors, first_logical_dim,
                                  last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr,
                                  first_dims_ptr, last_dims_ptr);
        allow_next_job_prefetch =
            is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr);
        if (allow_next_job_prefetch) {
-          prefetch_block = decode_block(prefetch_job, shape_rep, use_single_work_grid,
-                                        first_logical_dim, last_logical_dim, num_tensors, ctaid_X,
-                                        ctaid_Y, offsets_ptr);
+          prefetch_block =
+              decode_block(prefetch_job, shape_rep, use_single_work_grid, first_logical_dim,
+                           last_logical_dim, num_tensors, ctaid_X, ctaid_Y, offsets_ptr);
        }
      }

-      if ((stage < STAGES - PREFETCH_STAGES) ||
-          (allow_next_job_prefetch && !job_finished)) {
+      if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) {
        const int next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM;
        const int next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES;
        const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X;
@@ -1011,8 +1003,10 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
      for (size_t row_tr = threadIdx.x; row_tr < CHUNK_DIM_X; row_tr += THREADS_NUM) {
        const size_t row_tr_global = scales_block_offset_Y_tr + row_tr;
        if (row_tr_global < cols) {
-          ColwiseScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]);
-          const size_t scale_idx_global = row_tr_global * scale_stride_t + scales_block_offset_X_tr;
+          ColwiseScalesVec &scales_vec =
+              *reinterpret_cast(sSFcolwise[row_tr]);
+          const size_t scale_idx_global =
+              row_tr_global * scale_stride_t + scales_block_offset_X_tr;
          scales_vec.store_to_elts(&scales_colwise[scale_idx_global], 0, colwise_count);
        }
      }
@@ -1079,8 +1073,8 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
     shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS;
   }

-  const bool use_single_work_grid = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS
-                                     || shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);
+  const bool use_single_work_grid = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
+                                     shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);

   const size_t first_logical_dim = input->logical_shape.data[0];
   const size_t last_logical_dim = input->logical_shape.data[1];
@@ -1148,22 +1142,30 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
   const size_t dummy_first_logical_dim = 32;
   const size_t dummy_last_logical_dim = 32;
-  create_2D_tensor_map(tensor_map_input, input->data, dummy_first_logical_dim, dummy_last_logical_dim,
-                       BUFF_DIM_Y, BUFF_DIM_X, dummy_last_logical_dim, 0, sizeof(IType) * 8);
-  create_2D_tensor_map(tensor_map_output, output->data, dummy_first_logical_dim, dummy_last_logical_dim,
-                       BUFF_DIM_Y, BUFF_DIM_X, dummy_last_logical_dim, 0, 4);
+  create_2D_tensor_map(tensor_map_input, input->data, dummy_first_logical_dim,
+                       dummy_last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, dummy_last_logical_dim, 0,
+                       sizeof(IType) * 8);
+  create_2D_tensor_map(tensor_map_output, output->data, dummy_first_logical_dim,
+                       dummy_last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, dummy_last_logical_dim, 0,
+                       4);
   if (return_transpose) {
-    create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, dummy_last_logical_dim,
-                         dummy_first_logical_dim, BUFF_DIM_X, BUFF_DIM_Y, dummy_first_logical_dim, 0, 4);
+    create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data,
+                         dummy_last_logical_dim, dummy_first_logical_dim, BUFF_DIM_X, BUFF_DIM_Y,
+                         dummy_first_logical_dim, 0, 4);
   }

   constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
   constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems;

-  constexpr int buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
-  constexpr int buff_size_aligned_out = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT);
-  constexpr int buff_size_aligned_out_t = DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT);
-  constexpr int buff_size_scales = DIVUP_TO_MULTIPLE(CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
-  constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE(CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_aligned_in =
+      DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_aligned_out =
+      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_aligned_out_t =
+      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_scales = DIVUP_TO_MULTIPLE(
+      CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
+  constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE(
+      CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);

   const int in_mem = buff_size_aligned_in;
   const int out_data_mem = buff_size_aligned_out;
@@ -1171,7 +1173,8 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
   const int out_scales_mem = buff_size_scales;
   const int out_scales_transpose_mem = return_transpose ? buff_size_scales_transpose : 0;
   const int out_mem = out_data_mem + out_data_transpose_mem;
-  const int dshmem_size = in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT;
+  const int dshmem_size =
+      in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT;

   const IType *const input_dptr = reinterpret_cast(input->data.dptr);
   const void *const output_dptr = output->data.dptr;
                                            first_dims_ptr, last_dims_ptr, true, return_transpose);
   NVTE_CHECK_CUDA(cudaGetLastError());

-  TRANSFORMER_ENGINE_SWITCH_CONDITION(use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
-    TRANSFORMER_ENGINE_SWITCH_CONDITION(use_fast_math, USE_FAST_MATH,
-      TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE,
-        {
-          auto kernel = group_quantize_transpose_nvfp4_tuned_1D_kernel
-              ;
+  TRANSFORMER_ENGINE_SWITCH_CONDITION(
+      use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
+      TRANSFORMER_ENGINE_SWITCH_CONDITION(
+          use_fast_math, USE_FAST_MATH,
+          TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
+            auto kernel =
+                group_quantize_transpose_nvfp4_tuned_1D_kernel;
             NVTE_CHECK_CUDA(cudaFuncSetAttribute(
                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));