diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 6557c83773..e54ceebaa3 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -58,8 +58,7 @@ void compute_ref(const ProcessingMethod processing_method, const size_t rows, const size_t cols, const size_t scales_stride_rowwise, - const size_t scales_stride_colwise, - const bool is_single_tensor) + const size_t scales_stride_colwise) { const size_t tile_size_Y = 32; const size_t tile_size_X = 32; @@ -169,10 +168,8 @@ void compute_ref(const ProcessingMethod processing_method, } } - if (is_single_tensor) { - for (size_t j = 0; j < cols; ++j) { - output_dbias[j] = static_cast(output_dbias_fp32[j]); - } + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); } } @@ -250,12 +247,16 @@ void performTest(const ProcessingMethod processing_method, DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; + const bool compute_dbias = (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT); + const size_t rows = logical_shape_vec[0]; const size_t cols = logical_shape_vec[1]; size_t elts_num = 0; size_t rowwise_sfs_num = 0; size_t colwise_sfs_num = 0; + size_t sum_of_last_dims = 0; std::vector rowwise_scales_first_dim(num_tensors, 0); std::vector rowwise_scales_last_dim(num_tensors, 0); @@ -263,6 +264,7 @@ void performTest(const ProcessingMethod processing_method, std::vector colwise_scales_first_dim(num_tensors, 0); std::vector colwise_scales_last_dim(num_tensors, 0); std::vector colwise_scales_offset(num_tensors + 1, 0); + std::vector dbias_offsets(num_tensors + 1, 0); for (size_t t = 0; t < num_tensors; ++t) { const size_t M = first_dims_h[t]; @@ -285,13 +287,13 @@ void performTest(const ProcessingMethod processing_method, rowwise_sfs_num += rowwise_sfs; colwise_sfs_num += colwise_sfs; + sum_of_last_dims += K; 
rowwise_scales_offset[t+1] = rowwise_sfs_num; colwise_scales_offset[t+1] = colwise_sfs_num; + dbias_offsets[t+1] = sum_of_last_dims; } - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); - std::vector scales_rowwise_shape = {rowwise_sfs_num}; std::vector scales_colwise_shape = {colwise_sfs_num}; @@ -311,7 +313,7 @@ void performTest(const ProcessingMethod processing_method, std::vector out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0); std::vector out_scales_colwise_ref(colwise ? colwise_sfs_num : 0); - std::vector ref_output_dbias(is_single_tensor ? cols : 0); + std::vector ref_output_dbias(sum_of_last_dims, static_cast(0.0f)); for (size_t i = 0; i < elts_num; ++i) { const float val = dis(gen); @@ -336,6 +338,7 @@ void performTest(const ProcessingMethod processing_method, const size_t in_data_size = elts_num * sizeof(InputType); const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t dbias_data_size = sum_of_last_dims * sizeof(InputType); const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0); const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0); @@ -343,15 +346,16 @@ void performTest(const ProcessingMethod processing_method, const size_t last_dims_size = num_tensors * sizeof(size_t); const size_t offsets_size = (num_tensors + 1) * sizeof(size_t); - InputType* grad_data_d; - InputType* in_data_d; - OutputType* out_data_rowwise_d; - OutputType* out_data_colwise_d; - fp8e8m0* out_scales_rowwise_d; - fp8e8m0* out_scales_colwise_d; - size_t* first_dims_d; - size_t* last_dims_d; - size_t* offsets_d; + InputType* grad_data_d = nullptr; + InputType* in_data_d = nullptr; + InputType* dbias_out_data_d = nullptr; + OutputType* out_data_rowwise_d = nullptr; + OutputType* out_data_colwise_d = nullptr; + fp8e8m0* out_scales_rowwise_d = nullptr; + fp8e8m0* out_scales_colwise_d = nullptr; + size_t* first_dims_d = nullptr; + size_t* last_dims_d = nullptr; + size_t* offsets_d = 
nullptr; cudaMalloc((void**)&grad_data_d, in_data_size); cudaMalloc((void**)&in_data_d, in_data_size); @@ -367,6 +371,10 @@ void performTest(const ProcessingMethod processing_method, NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + std::vector dbias_logical_shape_vec = {num_tensors, cols}; + NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), + dbias_logical_shape_vec.size()); + NVTEShape first_dims_shape_; NVTEShape last_dims_shape_; NVTEShape offsets_shape_; @@ -382,6 +390,7 @@ void performTest(const ProcessingMethod processing_method, NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor output_dbias_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, dbias_logical_shape_); NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; @@ -453,52 +462,41 @@ void performTest(const ProcessingMethod processing_method, &out_scales_colwise_tensor, sizeof(out_scales_colwise_tensor)); } - Tensor output_dbias("output_dbias", std::vector{ cols }, itype); + if (compute_dbias) { + cudaMalloc((void**)&dbias_out_data_d, dbias_data_size); + cudaMemset(dbias_out_data_d, 0, dbias_data_size); + NVTEBasicTensor output_dbias_data_tensor = {dbias_out_data_d, static_cast(itype), dbias_logical_shape_}; + nvte_set_grouped_tensor_param(output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &output_dbias_data_tensor, sizeof(output_dbias_data_tensor)); + } // Reference (CPU) - if (is_single_tensor) { - - const size_t unpadded_rowwise_blocks_X 
= divide_round_up(cols, 32); - const size_t unpadded_colwise_blocks_X = cols; + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; - const size_t scales_stride_rowwise = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4); - const size_t scales_stride_colwise = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128); + const size_t scales_stride_rowwise = rowwise_scales_last_dim[t]; + const size_t scales_stride_colwise = colwise_scales_last_dim[t]; + const size_t data_offset = offsets_h[t]; + const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; + const size_t colwise_sfs_offset = colwise_scales_offset[t]; + const size_t dbias_offset = dbias_offsets[t]; + + const InputType* const grad_ptr = grad_data.data() + data_offset; + const InputType* const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; + InputType* const ref_output_dbias_ptr = ref_output_dbias.data() + dbias_offset; compute_ref( - processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(), - out_data_rowwise_ref.data(), out_data_colwise_ref.data(), - out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(), - ref_output_dbias.data(), rows, cols, + processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + ref_output_dbias_ptr, M, K, scales_stride_rowwise, - scales_stride_colwise, - is_single_tensor); - } else { - for (size_t t = 0; t < num_tensors; ++t) { - const size_t M = first_dims_h[t]; - const size_t K = last_dims_h[t]; - - const size_t scales_stride_rowwise = 
rowwise_scales_last_dim[t]; - const size_t scales_stride_colwise = colwise_scales_last_dim[t]; - const size_t data_offset = offsets_h[t]; - const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; - const size_t colwise_sfs_offset = colwise_scales_offset[t]; - - const InputType* const grad_ptr = grad_data.data() + data_offset; - const InputType* const in_ptr = in_data.data() + data_offset; - OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; - OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; - fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; - fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; - - compute_ref( - processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, - out_data_rowwise_ptr, out_data_colwise_ptr, - out_scales_rowwise_ptr, out_scales_colwise_ptr, - ref_output_dbias.data(), M, K, - scales_stride_rowwise, - scales_stride_colwise, - is_single_tensor); - } + scales_stride_colwise); } // GPU @@ -509,9 +507,9 @@ void performTest(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS: { - nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); break; } case ProcessingMethod::CAST_DBIAS_DACT: { @@ -522,10 +520,10 @@ void performTest(const ProcessingMethod processing_method, else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } nvte_group_quantize_dbias_dact(grad_group_tensor, 
in_group_tensor, out_group_tensor, - output_dbias.data(), workspace.data(), 0); + output_dbias_tensor, workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, - output_dbias.data(), workspace.data(), 0); + output_dbias_tensor, workspace.data(), 0); break; } case ProcessingMethod::CAST_ACT: { @@ -556,6 +554,11 @@ void performTest(const ProcessingMethod processing_method, const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; + // Compare only allocated contiguous output range. + // In graph-safe mode logical shape may include trailing garbage beyond offsets_h.back(). + const size_t compare_rows = 1; + const size_t compare_cols = elts_num; + if (rowwise) { cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost); @@ -568,7 +571,8 @@ void performTest(const ProcessingMethod processing_method, const size_t mismatches_elts = 32 * mismatches_scales; compare_scaled_elts("rowwise_output", out_data_rowwise_ref.data(), - out_data_rowwise_h.data(), rows, cols, true, mismatches_elts); + out_data_rowwise_h.data(), compare_rows, compare_cols, + true, mismatches_elts); } if (colwise) { @@ -583,12 +587,14 @@ void performTest(const ProcessingMethod processing_method, const size_t mismatches_elts = 32 * mismatches_scales; compare_scaled_elts("colwise_output", out_data_colwise_ref.data(), - out_data_colwise_h.data(), rows, cols, false, mismatches_elts); + out_data_colwise_h.data(), compare_rows, compare_cols, + false, mismatches_elts); } - if (processing_method == ProcessingMethod::CAST_DBIAS - || processing_method == ProcessingMethod::CAST_DBIAS_DACT) - { + if (compute_dbias) { + Tensor output_dbias("output_dbias", std::vector{ sum_of_last_dims }, itype); + 
cudaMemcpy(output_dbias.rowwise_dptr(), dbias_out_data_d, dbias_data_size, cudaMemcpyDeviceToDevice); + auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -601,6 +607,7 @@ void performTest(const ProcessingMethod processing_method, cudaFree(grad_data_d); cudaFree(in_data_d); + cudaFree(dbias_out_data_d); cudaFree(first_dims_d); cudaFree(last_dims_d); cudaFree(offsets_d); @@ -648,8 +655,10 @@ std::vector> input_config = { {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - {VARYING_FIRST_DIM, 2, 384,160, 128,256}, + {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, @@ -714,26 +723,31 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { } } offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t]; - // Skips tests if tensor shape is not as required by the kernel - if (first_dims[t] % 128 != 0) { + // Skip tests when the tensor shape is incompatible with the kernel. + // The TMA engine requires strides to be 16-byte aligned. + if ((first_dims[t] % 128 != 0) || (last_dims[t] % 16 != 0)) { GTEST_SKIP(); } - if (!is_single_tensor && (last_dims[t] % 128 != 0)) { + // If a grouped tensor has a varying last dimension, it must be a multiple of 128. + // Otherwise, computing the grid size adds runtime overhead in the non-persistent kernel, + // since the relevant tensor metadata resides in device memory. 
+ constexpr size_t CHUNK_DIM_X = 128; + if (!is_single_tensor && (last_dims[t] % CHUNK_DIM_X != 0)) { GTEST_SKIP(); } } - // Skips DBias tests if last dimension of tensors variates + // Skip dBias tests when tensors in the group have different last dimensions. if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) && !is_single_tensor) { GTEST_SKIP(); } - // Skips non Act tests if the Activation type is not an identity + // Skip non-activation tests when the activation type is not Identity. if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) && activation != ActivationKind::Identity) { GTEST_SKIP(); } - // Skips Act tests if the Activation is an identity + // Skip activation tests when the activation type is Identity. if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT || processing_method == ProcessingMethod::CAST_DACT || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index d209ea8d47..ea864813bf 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -32,7 +32,7 @@ void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dgelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { 
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); using namespace transformer_engine; @@ -110,7 +110,7 @@ void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dqgelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index b6f758caf6..fc9122b7ec 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -32,7 +32,7 @@ void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_drelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_drelu); using namespace transformer_engine; @@ -110,7 +110,7 @@ void nvte_group_dsrelu(const NVTEGroupedTensor grad, const 
NVTEGroupedTensor inp NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsrelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 77d5b6867f..12478af4cf 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -32,7 +32,7 @@ void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsilu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 57404ae8a5..4f9ddb4fc5 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -70,7 +70,7 @@ void 
nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d } void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 0997b01f7e..9c16666db0 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -22,6 +22,14 @@ namespace transformer_engine { namespace dispatch { namespace common { + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { const size_t N = product(t->data.shape); const bool isFullTile = (N % elems_per_block == 0); @@ -78,6 +86,57 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } stg_vec.store_to(thread_out_base); } + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const offsets_ptr, const int64_t *const first_dims_ptr, + const int64_t *const last_dims_ptr, OType *const dbias_output, + const float *dbias_partial, const size_t chunk_dim_Y) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const size_t tensor_id = blockIdx.y; + const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? 
(first_logical_dim / num_tensors) + : static_cast(first_dims_ptr[tensor_id]); + + const size_t rows = tensor_rows / chunk_dim_Y; + const size_t cols = last_logical_dim; + + const size_t dbias_in_offset_Y = + (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (tensor_id * (tensor_rows / chunk_dim_Y)) + : (static_cast(offsets_ptr[tensor_id]) / cols / chunk_dim_Y); + + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec; + OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} } // namespace kernel template @@ -96,6 +155,32 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const data_tensor_offsets_ptr, + const int64_t *const data_tensor_first_dims_ptr, + const int64_t *const data_tensor_last_dims_ptr, GroupedTensor *dbias, + const float *workspace_ptr, const size_t chunk_dim_Y, + cudaStream_t stream) { + using namespace kernel; + constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(last_logical_dim % reduce_dbias_nvec == 0, "Unsupported shape."); + + const size_t blocks_X = DIVUP(last_logical_dim, 
THREADS_PER_BLOCK * reduce_dbias_nvec); + const size_t blocks_Y = num_tensors; + const dim3 grid(blocks_X, blocks_Y); + + group_reduce_dbias_kernel<<>>( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, data_tensor_offsets_ptr, + data_tensor_first_dims_ptr, data_tensor_last_dims_ptr, + reinterpret_cast(dbias->data.dptr), workspace_ptr, chunk_dim_Y); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace common } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 98a3fb8cba..f7823b4c58 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -382,13 +382,13 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); const NVTEGroupedTensor activation = nullptr; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); - Tensor *dbias_tensor = convertNVTETensor(dbias); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); Tensor *workspace_tensor = convertNVTETensor(workspace); // Quantization config @@ -419,8 +419,9 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor template void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { using 
namespace detail; NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); @@ -428,7 +429,7 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - Tensor *dbias_tensor = convertNVTETensor(dbias); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); Tensor *workspace_tensor = convertNVTETensor(workspace); // Quantization config diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 6447fc4542..6e9bd3dc5e 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -17,6 +17,7 @@ #include #include "../../common.h" +#include "../../util/cuda_runtime.h" #include "../../util/math.h" #include "../../util/ptx.cuh" #include "../../utils.cuh" @@ -28,29 +29,41 @@ namespace dispatch { namespace mxfp8 { namespace group_quantize_kernel { +using namespace dispatch::common; + constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; __device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; -enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, - VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, - VARYING_BOTH_DIMS = 3 +struct TunableConfig { + static constexpr size_t CHUNK_DIM_Y = 128; + static constexpr size_t CHUNK_DIM_X = 128; + static constexpr size_t THREADS_PER_CHUNK = 128; + static constexpr size_t PREFETCH_STAGES = 1; + // 
true -> static persistent grid-stride scheduler + // false -> non-persistent one-job-per-CTA execution + static constexpr bool PERSISTENT = true; + // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). + static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4; }; +constexpr bool PERSISTENT = TunableConfig::PERSISTENT; +static_assert(!PERSISTENT || (TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0), + "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero in persistent mode."); + constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; -constexpr size_t BUFFS_NUM = 2; +constexpr size_t PREFETCH_STAGES = TunableConfig::PREFETCH_STAGES; +constexpr size_t BUFFS_NUM = PREFETCH_STAGES + 1; constexpr size_t PACK_SIZE = 4; constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y; +constexpr size_t CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK = TunableConfig::THREADS_PER_CHUNK; constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; @@ -111,6 +124,9 @@ __device__ __forceinline__ size_t get_tensor_rows_num( rows_num = static_cast(first_dims_ptr[tensor_id]); break; } + if (rows_num % 128 != 0) { + NVTE_DEVICE_ERROR("First dimension of each tensor in a group must be divisible by 128."); + } return rows_num; } @@ -126,11 +142,95 @@ __device__ __forceinline__ size_t get_tensor_cols_num( case ShapeRepresentation::VARYING_LAST_DIM: case ShapeRepresentation::VARYING_BOTH_DIMS: cols_num = static_cast(last_dims_ptr[tensor_id]); + if (cols_num % 128 != 0) { + NVTE_DEVICE_ERROR( + "For non-single tensors, the last dimension of each tensor in a group " + "must be divisible by 128."); + } break; } return cols_num; } +// Logical work-item decoded from CTA coordinates. 
+struct JobDescriptor { + size_t block_id = 0; + size_t block_global_offset = 0; + size_t tensor_id = 0; + size_t rows = 0; + size_t cols = 0; +}; + +// Tensor-local coordinates for a work-item. +struct BlockDescriptor { + size_t tensor_base = 0; + size_t block_id_in_current_tensor = 0; + size_t block_id_Y = 0; + size_t block_id_X = 0; + size_t block_offset_Y = 0; + size_t block_offset_X = 0; +}; + +__device__ __forceinline__ JobDescriptor decode_job( + const ShapeRepresentation shape_rep, const bool is_single_tensor, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X, + const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr) { + JobDescriptor job{}; + job.block_id = ctaid_Y * work_blocks_X + ctaid_X; + job.block_global_offset = is_single_tensor + ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) + : (job.block_id * ELTS_PER_CHUNK); + job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + job.rows = + get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + return job; +} + +__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, + const ShapeRepresentation shape_rep, + const size_t total_work_blocks, + const int64_t *const __restrict__ offsets_ptr) { + bool is_valid = (job.block_id < total_work_blocks) && (job.rows != 0) && (job.cols != 0); + if (!is_valid || shape_rep == SAME_BOTH_DIMS) { + return is_valid; + } + + const size_t tensor_start_offset = static_cast(offsets_ptr[job.tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[job.tensor_id + 1]); + if (job.block_global_offset >= 
tensor_end_offset) { + return false; + } + + const size_t tensor_offset_from_start = job.block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / job.cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % job.cols; + if (block_offset_Y_in_tensor >= job.rows || block_offset_X_in_tensor >= job.cols) { + return false; + } + + return true; +} + +__device__ __forceinline__ BlockDescriptor +decode_block(const JobDescriptor &job, const bool is_single_tensor, + const int64_t *const __restrict__ offsets_ptr) { + BlockDescriptor block{}; + block.tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); + const size_t CHUNK_DIM_X_ = CHUNK_DIM_X; + const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_); + block.block_id_in_current_tensor = + is_single_tensor ? job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK); + block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor; + block.block_id_X = block.block_id_in_current_tensor % blocks_X_num_in_current_tensor; + block.block_offset_Y = block.block_id_Y * CHUNK_DIM_Y; + block.block_offset_X = block.block_id_X * CHUNK_DIM_X; + return block; +} + // Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, CUtensorMap *global_tensor_map, @@ -144,7 +244,7 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te if constexpr (is_blackwell) { const size_t global_stride_bytes = global_dim_X * data_type_size_bytes; if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) { - NVTE_DEVICE_ERROR("Shape not supported, as data stride must be 16B aligned."); + NVTE_DEVICE_ERROR("Shape not supported. 
Data stride must be 16B aligned.");
     }
     if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) {
       NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned");
     }
@@ -230,125 +330,322 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso
 #endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
 }
 
+// Issue TMA global->shared transfer for one stage of input (and optional activation input).
+// Single-producer pattern: only the CTA's leading thread arms the mbarrier (expect-tx with the
+// stage's byte count) and issues the bulk tensor copies; completion is observed by waiting on
+// `barrier`. When IS_DACT is set, the activation-gradient input is fetched into act_in_sh
+// alongside the primary input, sharing the same barrier.
+template
+__device__ __forceinline__ void prefetch_input_stage(
+    IType *in_sh, IType *act_in_sh, const CUtensorMap &tensor_map_input,
+    const CUtensorMap &tensor_map_act_input, const size_t global_offset_X,
+    const size_t global_offset_Y, const size_t buff_offset, const size_t shmem_buff_size,
+    uint64_t *barrier, const bool leading_thread) {
+  if (leading_thread) {
+    ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size);
+    ptx::cp_async_bulk_tensor_2d_global_to_shared(
+        reinterpret_cast(&in_sh[buff_offset]),
+        reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y,
+        barrier);
+    if constexpr (IS_DACT) {
+      ptx::cp_async_bulk_tensor_2d_global_to_shared(
+          reinterpret_cast(&act_in_sh[buff_offset]),
+          reinterpret_cast(&tensor_map_act_input), global_offset_X,
+          global_offset_Y, barrier);
+    }
+  }
+}
+
+// Issue TMA shared->global transfer for one stage of outputs. 
+// Only the leading thread issues the stores; rowwise and colwise outputs (whichever scaling
+// directions are compiled in) are committed together as a single bulk-async group so the
+// caller can drain them with a group-level wait.
+template
+__device__ __forceinline__ void store_output_stage(OType *out_rowwise_data_sh,
+                                                   OType *out_colwise_data_sh,
+                                                   const CUtensorMap &tensor_map_output_rowwise,
+                                                   const CUtensorMap &tensor_map_output_colwise,
+                                                   const int global_offset_X,
+                                                   const int global_offset_Y, const int buff_offset,
+                                                   const bool leading_thread) {
+  if (!leading_thread) {
+    return;
+  }
+
+  if constexpr (ROWWISE_SCALING) {
+    ptx::cp_async_bulk_tensor_2d_shared_to_global(
+        reinterpret_cast(&tensor_map_output_rowwise), global_offset_X,
+        global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset]));
+  }
+  if constexpr (COLWISE_SCALING) {
+    ptx::cp_async_bulk_tensor_2d_shared_to_global(
+        reinterpret_cast(&tensor_map_output_colwise), global_offset_X,
+        global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset]));
+  }
+  // Close the bulk-async group containing the TMA stores issued above.
+  ptx::cp_async_bulk_commit_group();
+}
+
 template
-__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
-    const __grid_constant__ CUtensorMap tensor_map_input_static,
-    const __grid_constant__ CUtensorMap tensor_map_act_input_static,
-    const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static,
-    const __grid_constant__ CUtensorMap tensor_map_output_colwise_static,
-    const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim,
-    const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr,
-    const int64_t *const __restrict__ first_dims_ptr,
-    const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr,
-    e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop,
-    float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) {
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+           bool WITH_GEMM_SWIZZLED_SCALES>
+__device__ __forceinline__ float process_colwise_stage(
+    const size_t buff, const int stage, const size_t tid_X_colwise,
+    const size_t scales_offset_Y_colwise, const size_t 
scales_offset_X_colwise, + const size_t scale_stride_colwise, const size_t tensor_base_for_scales, const size_t rows, + const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, + OType *out_colwise_data_sh, e8m0_t *scales_colwise, float &partial_dbias_colwise) { constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + float thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; - using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - if constexpr (NO_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + 
in_compute_colwise[i] = elt; } } - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t tensor_base_row = tensor_base_for_scales / cols; + const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; + const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; + scale_idx = + tensor_scales_offset_colwise_base + + transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + global_scales_offset_X, local_scales_offset_Y, DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } + scales_colwise[scale_idx] = biased_exponent; - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const size_t block_ID = blockIdx.y * gridDim.x + blockIdx.x; - const size_t block_global_offset = - is_single_tensor ? 
(blockIdx.y * CHUNK_DIM_Y * last_logical_dim + blockIdx.x * CHUNK_DIM_X) - : (block_ID * ELTS_PER_CHUNK); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, blockIdx.y, - first_logical_dim, last_logical_dim, offsets_ptr); + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + return thread_amax; +} - const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); - const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); +template +__device__ __forceinline__ float process_rowwise_stage( + const size_t buff, const size_t stage_offset_Y, const size_t thread_offset_Y_rowwise, + const size_t thread_offset_X_rowwise, const int bank_group, + const size_t scales_offset_Y_rowwise, const size_t scales_offset_X_rowwise, + const size_t scale_stride_rowwise, const bool rowwise_scale_is_within_bounds, const size_t cols, + IType *in_sh, IType *act_in_sh, IType *cached_act_sh, OType *out_rowwise_data_sh, + e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; - // grouped tensor can be treated as continuous tensor for MXFP8 - const size_t tensor_base = 
is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); - // For grouped tensors represented as a single logical tensor, scale swizzle must still be - // computed per tensor (expert) and then concatenated along dim-0. - const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) - ? static_cast(offsets_ptr[tensor_id]) - : tensor_base; + const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + float thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + Vec in_IType[WAVES]; - // In graph-safe paged stashing, the logical shape can include trailing garbage. Skip CTAs that - // map outside the current tensor's valid [rows, cols] region. - if (rows == 0 || cols == 0) { - return; - } - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); - if (block_global_offset >= tensor_end_offset) { - return; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } } - const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; - if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { - return; + 
thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } } - } + if constexpr (!std::is_same_v) { + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; - const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap &tensor_map_output_rowwise = - is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap &tensor_map_output_colwise = - is_single_tensor ? 
tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + Vec in; + Vec act_in; - const bool leading_thread = (threadIdx.x == 0); + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } - if (leading_thread && (!is_single_tensor)) { - fence_acquire_tensormap(&tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { - fence_acquire_tensormap(&tensor_map_act_input); - } - if constexpr (ROWWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_rowwise); + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; + } } - if constexpr (COLWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_colwise); + } + + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + stage_scales_offset_Y, stage_scales_offset_X, DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = 
{block_scale_inverse, block_scale_inverse}; + +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } - const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); - const size_t block_id_in_current_tensor = - is_single_tensor ? 
block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); + return thread_amax; +} - const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; - const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, + const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, + e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, + float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr, + const size_t work_blocks_X, const size_t work_blocks_Y) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; - const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } - e8m0_t *const scales_rowwise = - scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); - e8m0_t *const scales_colwise = - scales_colwise_ptr + (is_single_tensor ? 
0 : tensor_base / SCALE_DIM_Y); + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + const bool leading_thread = (threadIdx.x == 0); const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; const size_t tid_X_rowwise = threadIdx.x % THREADS_X; @@ -358,11 +655,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -392,375 +684,332 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } + constexpr size_t shmem_buff_size = (IS_DACT ? 
2 : 1) * buff_size_aligned_in / BUFFS_NUM; float block_amax = 0.0f; -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, leading_thread); - - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], leading_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], leading_thread); - } + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + // Initialize barriers shared by the entire CTA: + // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. + if (leading_thread) { #pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - leading_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread); - } + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); } - ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + const size_t total_work_blocks = work_blocks_X * work_blocks_Y; + const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; + + int IN_buff_readable_parity[BUFFS_NUM] = {0}; + int32_t ctaid_X = static_cast(blockIdx.x); + int32_t ctaid_Y = static_cast(blockIdx.y); + size_t static_next_block_id = 0; + size_t static_block_stride = 0; + // In persistent mode, physical CTAs iterate over a virtual work grid via grid-stride. + if constexpr (PERSISTENT) { + if (launch_block_id >= total_work_blocks) { + return; + } + ctaid_X = static_cast(launch_block_id % work_blocks_X); + ctaid_Y = static_cast(launch_block_id / work_blocks_X); + static_block_stride = gridDim.x * gridDim.y; + static_next_block_id = launch_block_id + static_block_stride; + } + bool job_finished = false; + int buff_in = 0; + bool has_prefetched_current_job = true; + + // Prime the pipeline with stage-0 of the first job assigned to this CTA. 
+ { + const JobDescriptor first_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + if (!is_job_valid(first_job, shape_rep, total_work_blocks, offsets_ptr)) { + return; + } + const BlockDescriptor first_block = decode_block(first_job, is_single_tensor, offsets_ptr); + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[first_job.tensor_id]; + const CUtensorMap &tensor_map_act_input = is_single_tensor + ? tensor_map_act_input_static + : g_tensor_maps_act_input[first_job.tensor_id]; + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + } - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); - - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - // 1. Read/Compute elements. 
Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { #pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const size_t buff = stage; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + const size_t global_offset_Y = first_block.block_offset_Y + stage_offset_Y; + const size_t global_offset_X = first_block.block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + uint64_t *barrier = &IN_buff_readable_mbar[buff]; + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, + global_offset_X, global_offset_Y, buff_offset, + shmem_buff_size, barrier, leading_thread); + } + } - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_colwise[i] = elt; - } + // Main work loop: decode current job, run all 32-row stages, schedule/prefetch next job. 
+ while (!job_finished) { + // Decode CTA assignment into logical tensor coordinates and validate bounds. + const JobDescriptor current_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const bool current_job_is_valid = + is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr); + if (!current_job_is_valid) { + if (has_prefetched_current_job) { + // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); } + break; + } - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - const size_t tensor_base_row = tensor_base_for_scales / cols; - const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; - const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; - const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; - scale_idx = tensor_scales_offset_colwise_base + - gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, - DIVUP(rows, static_cast(128))); - } else { - scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + const size_t tensor_id = current_job.tensor_id; + const size_t rows = current_job.rows; + const size_t cols = current_job.cols; + const BlockDescriptor current_block = decode_block(current_job, is_single_tensor, offsets_ptr); + + const size_t scale_stride_rowwise = 
DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); + const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); + + const size_t tensor_base = current_block.tensor_base; + const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) + ? static_cast(offsets_ptr[tensor_id]) + : tensor_base; + const size_t block_id_Y = current_block.block_id_Y; + const size_t block_id_X = current_block.block_id_X; + const size_t block_offset_Y = current_block.block_offset_Y; + const size_t block_offset_X = current_block.block_offset_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = is_single_tensor + ? 
tensor_map_output_rowwise_static + : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = is_single_tensor + ? tensor_map_output_colwise_static + : g_tensor_maps_output_colwise[tensor_id]; + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); } - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); + } + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); + } + } -// 3. Scale elements + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { #pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; } } - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. 
Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + bool prefetched_next_job = false; + // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } + for (int stage = 0; stage < STAGES; ++stage) { + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + bool allow_next_job_prefetch = true; + JobDescriptor prefetch_job = current_job; + BlockDescriptor prefetch_block = current_block; + + if (stage == STAGES - PREFETCH_STAGES) { + if constexpr (PERSISTENT) { + if (static_next_block_id < total_work_blocks) { + ctaid_X = static_cast(static_next_block_id % work_blocks_X); + ctaid_Y = static_cast(static_next_block_id / work_blocks_X); + static_next_block_id += static_block_stride; } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } + // Next loop iteration exits via current_job_is_valid check. + ctaid_X = 0; + ctaid_Y = static_cast(work_blocks_Y); + allow_next_job_prefetch = false; } + } else { + ctaid_X = -1; + ctaid_Y = -1; } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + if constexpr (!PERSISTENT) { + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; + } } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + } - Vec in; - Vec act_in; + // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. + // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. 
+ if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { + prefetch_job = decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, + last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, + first_dims_ptr, last_dims_ptr); + allow_next_job_prefetch = + is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); + if (allow_next_job_prefetch) { + prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); + } + } - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } + if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { + const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const size_t next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; + if (stage >= STAGES - PREFETCH_STAGES) { + prefetched_next_job = true; + } - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; + const size_t global_offset_Y = prefetch_block.block_offset_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = prefetch_block.block_offset_X; + const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; + + const CUtensorMap &prefetch_tensor_map_input = + is_single_tensor ? tensor_map_input_static + : g_tensor_maps_input[prefetch_job.tensor_id]; + const CUtensorMap &prefetch_tensor_map_act_input = + is_single_tensor ? 
tensor_map_act_input_static + : g_tensor_maps_act_input[prefetch_job.tensor_id]; + + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + if (leading_thread) { + if ((!is_single_tensor) && (stage == STAGES - PREFETCH_STAGES)) { + fence_acquire_tensormap(&prefetch_tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&prefetch_tensor_map_act_input); } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_rowwise[j] = elt; } } + prefetch_input_stage(in_sh, act_in_sh, prefetch_tensor_map_input, + prefetch_tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, + shmem_buff_size, barrier, leading_thread); + ptx::fence_proxy_async_shared_cta(); } - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, - DIVUP(cols, static_cast(128))); - } else { - scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + const size_t buff = buff_in; + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_amax = process_colwise_stage( + buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, + cached_act_sh, out_colwise_data_sh, scales_colwise, 
partial_dbias_colwise); } - scales_rowwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; -// 3. Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + if constexpr (ROWWISE_SCALING) { + thread_amax = process_rowwise_stage( + buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, + scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, + rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, + out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); } - } - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. 
+ ptx::fence_proxy_async_shared_cta(); + __syncthreads(); - // Initiate TMA transfer to copy shared memory to global memory - if (leading_thread) { + // Publish the stage from shared memory into global outputs via TMA. const int global_offset_Y = block_offset_Y + stage_offset_Y; const int global_offset_X = block_offset_X; const int buff_offset = buff * BUFF_DIM; + store_output_stage( + out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, + tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, leading_thread); - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. 
- ptx::cp_async_bulk_commit_group(); + buff_in = (buff_in + 1) % BUFFS_NUM; } - } - - parity ^= 1; + has_prefetched_current_job = prefetched_next_job; - if constexpr (IS_DBIAS) { - if (is_single_tensor) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - 
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + for (int i = 0; i < THREADS_Y; ++i) { + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } - __syncthreads(); -#pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + const int dbias_stride = cols; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; } } - const int dbias_stride = cols; - const int dbias_offset_Y = block_id_Y; - const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; - } } } @@ -774,7 +1023,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel atomicMaxFloat(amax_ptr, block_amax); } - destroy_barriers(mbar, leading_thread); + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } } // namespace group_quantize_kernel @@ -782,8 +1036,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel template void group_quantize(const GroupedTensor *input, const GroupedTensor *activations, - const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { + const Tensor *noop, GroupedTensor *output, GroupedTensor 
*dbias, + Tensor *workspace, cudaStream_t stream) { using namespace group_quantize_kernel; checkCuDriverContext(stream); @@ -834,23 +1088,30 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - size_t blocks_X = 0; - size_t blocks_Y = 0; + size_t work_blocks_X = 0; + size_t work_blocks_Y = 0; if (is_single_tensor) { - blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); - blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + work_blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + work_blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); } else { - NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, + NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); - // Only full tiles supported - NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0, - "Last dimension of a grouped tensor should be divisible by 128."); - blocks_Y = 1; - blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + work_blocks_Y = 1; + work_blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + } + + size_t launch_blocks_X = work_blocks_X; + size_t launch_blocks_Y = work_blocks_Y; + if constexpr (PERSISTENT) { + const size_t sm_num = static_cast(transformer_engine::cuda::sm_count()); + const size_t static_grid_size = sm_num * TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM; + NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero."); + launch_blocks_X = static_grid_size; + launch_blocks_Y = 1; } - const dim3 grid(blocks_X, blocks_Y); + const dim3 grid(launch_blocks_X, launch_blocks_Y); const size_t block_size = THREADS_PER_CHUNK; const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; @@ -858,7 +1119,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations // Logical shape of a tensor with varying all dims is [1, M*K] if (shape_rep != 
ShapeRepresentation::VARYING_BOTH_DIMS) { NVTE_CHECK(first_logical_dim % 128 == 0, - "First dimension of a grouped tensor should be divisible by 128."); + "First logical dimension of a grouped tensor must be divisible by 128."); } const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); @@ -879,18 +1140,20 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); - const size_t dbias_cols = last_logical_dim; if constexpr (IS_DBIAS) { NVTE_CHECK(is_single_tensor, "DBias is only supported for tensors with the const last dimension."); NVTE_CHECK(dbias->data.dtype == input->dtype(), "DBias must have the same type as input_tensor."); - NVTE_CHECK(dbias->data.shape == std::vector{last_logical_dim}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + std::vector expected_shape_dbias_tensor = {num_tensors, last_logical_dim}; + NVTE_CHECK(dbias->data.shape == expected_shape_dbias_tensor, "Wrong shape of DBias."); + + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + const size_t dbias_workspace_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_workspace_cols = last_logical_dim; if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols}; workspace->data.dtype = DType::kFloat32; return; } @@ -1004,10 +1267,13 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + scales_colwise_ptr, 
noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, + work_blocks_Y); if constexpr (IS_DBIAS) { - common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 06f1c65ce2..854f52c203 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -56,6 +56,7 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the GeLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -76,6 +77,7 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the SiLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -96,6 +98,7 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the ReLU activation of the grouped input. 
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -116,6 +119,7 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the Quick GeLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -136,6 +140,7 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the Squared ReLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -158,6 +163,7 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the GeLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. 
* \param[in] input Input grouped tensor for activation. @@ -182,6 +188,7 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the SiLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -206,6 +213,7 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the ReLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -230,6 +238,7 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu /*! \brief Computes the Quick GeLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -254,6 +263,7 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu /*! \brief Computes the Squared ReLU activation gradient of the grouped input. 
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 04712d3003..755052d6dd 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -92,6 +92,7 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea /*! \brief Casts input grouped tensor to MXFP8. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. See file level comments. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor to be cast. * \param[in,out] output Output grouped MXFP8 tensor. @@ -146,6 +147,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d /*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -161,7 +163,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d * \param[in] stream CUDA stream used for the operation. 
*/ void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the GeLU backward along columns. @@ -190,6 +192,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the GeLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -207,7 +210,8 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. @@ -236,6 +240,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the SiLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. 
* * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -253,7 +258,8 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. @@ -282,6 +288,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the ReLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -299,7 +306,8 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. @@ -328,6 +336,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp * Additionally, reduces the result of the Quick GeLU backward along columns. 
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -345,7 +354,8 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Squared ReLU backward along columns. @@ -374,6 +384,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp * Additionally, reduces the result of the Squared ReLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -391,7 +402,8 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Casts input tensor from reduced to higher precision. 
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, @@ -407,11 +419,11 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str /*! \brief Casts multiple input tensors to quantized output tensors. * - * \param[in] inputs List of input tensors to be cast. - * \param[in,out] outputs List of output quantized tensors. - * \param[in] quant_config (Optional) Quantization configurations. - * \param[in] num_tensors Number of input and output tensors. - * \param[in] stream CUDA stream used for the operation. + * \param[in] inputs List of input tensors to be cast. + * \param[in,out] outputs List of output quantized tensors. + * \param[in] quant_config (Optional) Quantization configurations. + * \param[in] num_tensors Number of input and output tensors. + * \param[in] stream CUDA stream used for the operation. */ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, const NVTEQuantizationConfig quant_config, const size_t num_tensors, @@ -420,11 +432,11 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, /*! \brief Casts grouped input tensor to quantized output tensors. * * \param[in] input Input tensor to be cast. - * \param[in,out] outputs Output quantized tensors. - * \param[in] split_sections Split sections of the input tensor. - * \param[in] num_tensors Number of output tensors. + * \param[in,out] outputs Output quantized tensors. + * \param[in] split_sections Split sections of the input tensor. + * \param[in] num_tensors Number of output tensors. * \param[in] quant_config (Optional) Quantization configurations. - * \param[in] stream CUDA stream used for the operation. + * \param[in] stream CUDA stream used for the operation. */ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections, size_t num_tensors,