diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 82655a0ce..bb2ecdca6 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2171,6 +2171,8 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) def test_grouped_linear_accuracy_cutlass( @@ -2178,6 +2180,8 @@ def test_grouped_linear_accuracy_cutlass( num_gemms, bs, model, + recipe, + fp8_model_params, fuse_wgrad_accumulation, delay_wgrad_compute, ): @@ -2187,8 +2191,8 @@ def test_grouped_linear_accuracy_cutlass( num_gemms, bs, model, - None, - False, + recipe, + fp8_model_params, fuse_wgrad_accumulation, False, delay_wgrad_compute, diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8d5537368..d489a177a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -254,7 +254,9 @@ else() list(APPEND transformer_engine_cpp_sources fused_attn_rocm/fused_attn.cpp gemm/rocm_gemm.cu - gemm/ck_grouped_gemm.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp amd_detail/system.cpp) list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp deleted file mode 100644 index def454f86..000000000 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ /dev/null @@ -1,338 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -#include - -#include -#include "../common.h" - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" - -namespace transformer_engine { -namespace grouped_gemm { - -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; - -template struct TETypeToCKType; -template <> struct TETypeToCKType { using type = ck_tile::half_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; - -// Treat TE tensors as generalized 2D matrices by flattening: -// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. -static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, - int64_t& d0, int64_t& d1) { - // Require at least a matrix (rank >= 2). Higher ranks are flattened. - if (t.shape().size() < 2) - return false; - d0 = static_cast(t.flat_first_dim()); - d1 = static_cast(t.flat_last_dim()); - return true; -} - -// Selects epilogue traits based on whether we are accumulating (D += A*B) or not (D = A*B). -// For accumulate=true, the existing D buffer is passed as a MultiD input tensor and combined -// via element_wise::Add. For accumulate=false, no extra input is needed and PassThrough is used. -template -struct EpilogueTraits { - using DsDataType = ck_tile::tuple<>; - using DsLayout = ck_tile::tuple<>; - using ElemOp = ck_tile::element_wise::PassThrough; -}; -template -struct EpilogueTraits { - using DsDataType = ck_tile::tuple; - using DsLayout = ck_tile::tuple; - using ElemOp = ck_tile::element_wise::Add; -}; - -static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; // rowwise data view -} - -// Primus-Turbo-like FP16/BF16 tile configs -// Selection rule: -// if (N % 256 == 0) use 256x256x64 -// else if (N % 128 == 0) use 256x128x64 -// else use 256x128x64 with N padding enabled -struct TileCfg_256x256x64 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x64 : TileCfg_256x256x64 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { - static constexpr bool kPadN = true; -}; - -// This class instantiates CK_Tile's grouped GEMM pipeline. -// See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. -template -struct Runner{ - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; - - static constexpr ck_tile::GemmPipelineScheduler Scheduler = - ck_tile::GemmPipelineScheduler::Intrawave; - - using Problem = ck_tile::UniversalGemmPipelineProblem< - AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using ET = EpilogueTraits; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, typename ET::DsDataType, AccType, - CType, typename ET::DsLayout, CLayout, - typename ET::ElemOp, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC>>; - - using Kernel = ck_tile::GroupedGemmKernel; -}; - -template -static bool run_grouped_impl(const NVTETensor* A_use, - const NVTETensor* B_use, - NVTETensor* D, - int group_num, - bool transA_use, - bool transB_use, - void* workspace, - size_t workspace_bytes, - hipStream_t stream) -{ - using Kernel = typename Runner::Kernel; - - const size_t needed = Kernel::GetWorkSpaceSize(group_num); - if (!workspace || workspace_bytes < needed) { - NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, - ", available bytes=", workspace_bytes, ". Falling back."); - return false; - } - - // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. - using HostArgs = std::conditional_t, - ck_tile::GroupedGemmHostArgs<0>>; - - thread_local std::vector descs; - descs.clear(); - descs.reserve(group_num); - - for (int i = 0; i < group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(A_use[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(B_use[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher)."); - return false; - } - - const int64_t M = transA_use ? Ad1 : Ad0; - const int64_t K = transA_use ? Ad0 : Ad1; - const int64_t N = transB_use ? Bd0 : Bd1; - const int64_t Kb = transB_use ? Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - return false; - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - return false; - } - - // Leading dimensions under the flattened-contiguous interpretation - const ck_tile::index_t stride_A = Ad1; - const ck_tile::index_t stride_B = Bd1; - const ck_tile::index_t stride_E = Dd1; - - if constexpr (Accumulate) { - // MultiD: E = Add(A@B, D1). D1 and E point to the same buffer for in-place accumulation. - descs.emplace_back( - a.dptr, b.dptr, - std::array{d.dptr}, // D1 = existing D contents (read) - d.dptr, // E = same buffer (write) - 1, M, N, K, - stride_A, stride_B, - std::array{stride_E}, - stride_E); - } else { - descs.emplace_back( - a.dptr, b.dptr, - std::array{}, - d.dptr, - 1, M, N, K, - stride_A, stride_B, - std::array{}, - stride_E); - } - } - - const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " - "Falling back."); - return false; - } - - HIP_CHECK_ERROR(hipMemcpyAsync(workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - hipMemcpyHostToDevice, - stream)); - - const ck_tile::stream_config s{stream}; - const dim3 blocks = Kernel::BlockSize(); - - ck_tile::launch_kernel( - s, - ck_tile::make_kernel<1>( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(workspace), - group_num)); - return true; -} - -} // namespace grouped_gemm -} // namespace transformer_engine - -bool ck_tile_grouped_gemm(const NVTETensor* A, - const NVTETensor* B, - NVTETensor* D, - int group_num, - bool transA, - bool transB, - NVTETensor* workspace, - bool accumulate, - hipStream_t stream) -{ - if (group_num <= 0) - return true; - - using namespace transformer_engine; - using namespace transformer_engine::grouped_gemm; - - // Workspace pointer + bytes - void* ws_ptr = nullptr; - size_t ws_bytes = 0; - if (workspace) { - auto* ws_te = convertNVTETensorCheck(*workspace); - ws_ptr = ws_te->data.dptr; - ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); - } - - // Normalize similar to upstream - // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 - // I.e., swap A and B, as well as transa and transb. - const NVTETensor* A_use = B; - const NVTETensor* B_use = A; - const bool transA_use = transB; - const bool transB_use = transA; - - const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); - - // Get N from D[0] (assume uniform N across groups) - int64_t ref_d0 = 0, ref_d1 = 0; - Tensor* D0_te = convertNVTETensorCheck(D[0]); - if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); - return false; - } - const ck_tile::index_t N = static_cast(ref_d1); - - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { - using T = typename TETypeToCKType::type; - - auto run_with_tilecfg = [&](auto tile_tag) -> bool { - using TileCfgSel = decltype(tile_tag); - - TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { - using BLayout = std::conditional_t; - - if (accumulate) { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); - } else { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); - } - }); - }); - }; - - // Select tile config like Primus-Turbo for FP16/BF16: - // N%256 -> 256x256x64 - // N%128 -> 256x128x64 - // else -> 256x128x64 padding - // NOTE: We assume N is uniform across groups. - if ((N % 256) == 0) { - return run_with_tilecfg(TileCfg_256x256x64{}); - } else if ((N % 128) == 0) { - return run_with_tilecfg(TileCfg_256x128x64{}); - } else { - return run_with_tilecfg(TileCfg_256x128x64_padding{}); - } - }); -} diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp new file mode 100644 index 000000000..a61f33b8c --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -0,0 +1,141 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" + +bool ck_tile_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream) { + if (group_num <= 0) { + return true; + } + + using namespace transformer_engine; + using namespace transformer_engine::grouped_gemm; + + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); + } + + // Normalize similar to upstream + // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 + // I.e., swap A and B, as well as transa and transb. + const NVTETensor* A_use = B; + const NVTETensor* B_use = A; + bool transA_use = transB; + bool transB_use = transA; + bool use_b_columnwise_data = false; + + const auto caller_a_dtype = convertNVTETensorCheck(A[0])->dtype(); + const auto caller_b_dtype = convertNVTETensorCheck(B[0])->dtype(); + + const bool caller_a_is_fp8 = + caller_a_dtype == DType::kFloat8E4M3 || caller_a_dtype == DType::kFloat8E5M2; + const bool caller_b_is_fp8 = + caller_b_dtype == DType::kFloat8E4M3 || caller_b_dtype == DType::kFloat8E5M2; + + // Currently the accumulate path is only supported on fp16 + if (accumulate && (caller_a_is_fp8 || caller_b_is_fp8)) + return false; + + // Handle pathological NN case during fp8 dX GEMM by reading W columnwise and re-formulating as NT + if (!transA_use && !transB_use && caller_a_is_fp8 && caller_b_is_fp8) { + auto* B0_te = convertNVTETensorCheck(B_use[0]); + if (B0_te->has_columnwise_data()) { + use_b_columnwise_data = true; + transB_use = true; + } + } + + const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); + const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype(); + + Tensor* D0_te = convertNVTETensorCheck(D[0]); + const auto d_dtype = D0_te->dtype(); + + Tensor* A0_te = convertNVTETensorCheck(A_use[0]); + Tensor* B0_te = convertNVTETensorCheck(B_use[0]); + + int64_t a0 = 0, a1 = 0; + int64_t b0 = 0, b1 = 0; + int64_t d0 = 0, d1 = 0; + + if (!get_flat_2d_dims(*A0_te, a0, a1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A_use[0]"); + return false; + } + + if (use_b_columnwise_data) { + if (B0_te->columnwise_data.shape.size() < 2) { + NVTE_ERROR("ck_tile_grouped_gemm: expected columnwise_data rank>=2 for B_use[0]"); + return false; + } + b0 = static_cast(B0_te->columnwise_data.shape[B0_te->columnwise_data.shape.size() - 2]); + b1 = static_cast(B0_te->columnwise_data.shape[B0_te->columnwise_data.shape.size() - 1]); + } else { + if (!get_flat_2d_dims(*B0_te, b0, b1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B_use[0]"); + return false; + } + } + + if (!get_flat_2d_dims(*D0_te, d0, d1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); + return false; + } + + const int64_t m = transA_use ? a1 : a0; + const int64_t kA = transA_use ? a0 : a1; + + const int64_t kB = transB_use ? b1 : b0; + const int64_t n = transB_use ? b0 : b1; + + if (kA != kB) { + NVTE_ERROR("ck_tile_grouped_gemm: normalized GEMM K mismatch: op(A_use) is ", + m, "x", kA, ", op(B_use) is ", kB, "x", n); + return false; + } + + if (d0 != m || d1 != n) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch for normalized GEMM. " + "D is ", d0, "x", d1, " but expected ", m, "x", n); + return false; + } + + GroupedGemmRunContext ctx = { + A_use, + B_use, + D, + static_cast(n), + group_num, + transA_use, + transB_use, + ws_ptr, + ws_bytes, + stream, + use_b_columnwise_data, + accumulate}; + + if (ck_tile_grouped_gemm_fp16_dispatch(a_dtype, b_dtype, d_dtype, ctx)) { + return true; + } + + if (ck_tile_grouped_gemm_fp8_dispatch(a_dtype, b_dtype, d_dtype, ctx)) { + return true; + } + + return false; +} diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h similarity index 100% rename from transformer_engine/common/gemm/ck_grouped_gemm.h rename to transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h new file mode 100644 index 000000000..8b8b76152 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -0,0 +1,117 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +#include + +#include +#include "../../common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = float; }; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; + +// Selects epilogue traits based on whether we are accumulating (D += A*B) or not (D = A*B). +// For accumulate=true, the existing D buffer is passed as a MultiD input tensor and combined +// via element_wise::Add. For accumulate=false, no extra input is needed and PassThrough is used. +template +struct EpilogueTraits { + using DsDataType = ck_tile::tuple<>; + using DsLayout = ck_tile::tuple<>; + using ElemOp = ck_tile::element_wise::PassThrough; +}; +template +struct EpilogueTraits { + using DsDataType = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + using ElemOp = ck_tile::element_wise::Add; +}; + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; +} + +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; +} + +struct GroupedGemmRunContext { + const NVTETensor* A = nullptr; + const NVTETensor* B = nullptr; + NVTETensor* D = nullptr; + int64_t N = 0; + + int group_num = 0; + bool transA = false; + bool transB = false; + + void* workspace = nullptr; + size_t workspace_bytes = 0; + hipStream_t stream = nullptr; + + bool use_b_columnwise_data = false; + bool accumulate = false; +}; + +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. +static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape().size() < 2) { + return false; + } + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +class RunnerInterface { +public: + virtual ~RunnerInterface() = default; + virtual bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) = 0; +}; + +std::unique_ptr make_fp8_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +std::unique_ptr make_fp16_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp new file mode 100644 index 000000000..6dfb4561a --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -0,0 +1,159 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +template +std::unique_ptr make_fp16_runner_typed(DType d_dtype, const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + if (ctx.N % 256 == 0) { + using TileCfg = TileCfg_256x256x64; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner< + AType, BType, CType, + ALayout, BLayout, CLayout, + TileCfg, true>; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner< + AType, BType, CType, + ALayout, BLayout, CLayout, + TileCfg, false>; + runner = std::make_unique(); + } + } else if (ctx.N % 128 == 0) { + using TileCfg = TileCfg_256x128x64; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner< + AType, BType, CType, + ALayout, BLayout, CLayout, + TileCfg, true>; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner< + AType, BType, CType, + ALayout, BLayout, CLayout, + TileCfg, false>; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x128x64_padding; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner< + AType, BType, CType, + ALayout, BLayout, CLayout, + TileCfg, true>; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner< + AType, BType, CType, + ALayout, BLayout, CLayout, + TileCfg, false>; + runner = std::make_unique(); + } + } + }); + return runner; +} + +std::unique_ptr make_fp16_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + + if (!ctx.transA && !ctx.transB) { + using ALayout = RowMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (!ctx.transA && ctx.transB) { + using ALayout = RowMajor; + using BLayout = ColMajor; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (ctx.transA && !ctx.transB) { + using ALayout = ColMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else { + return nullptr; + } + return nullptr; +} + +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; + + auto runner = make_fp16_runner( + a_dtype, b_dtype, d_dtype, ctx); + + if (!runner) { + return false; + } + + return runner->run(s, ctx); +} + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h new file mode 100644 index 000000000..1133d53bf --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h @@ -0,0 +1,214 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once +#include "ck_grouped_gemm_common.h" + +#include +#include +#include +#include + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +// ------------------------- +// Tile configs: FP16/BF16 +// ------------------------- + +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { + static constexpr bool kPadN = true; +}; + +template +class GroupedGemmRunner : public RunnerInterface { +public: + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using UniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = ck_tile::UniversalGemmPipelineProblem< + AType, BType, AccType, + GemmShape, UniversalTraits, Scheduler>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using ET = EpilogueTraits; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, typename ET::DsDataType, AccType, + CType, typename ET::DsLayout, CLayout, + typename ET::ElemOp, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC>>; + + using Kernel = ck_tile::GroupedGemmKernel; + + // GroupedGemmHostArgs<1> for the MultiD accumulate path, <0> for the overwrite path. + using HostArgs = std::conditional_t, + ck_tile::GroupedGemmHostArgs<0>>; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", ctx.workspace_bytes, ". Falling back."); + return {}; + } + + thread_local std::vector descs; + descs.clear(); + descs.reserve(ctx.group_num); + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + if constexpr(Accumulate) { + descs.emplace_back( + a.dptr, b.dptr, + std::array{d.dptr}, // D1 = existing D contents (read) + d.dptr, // E = same buffer (write) + 1, M, N, K, + stride_A, stride_B, + std::array{stride_E}, + stride_E); + } else { + descs.emplace_back( + a.dptr, b.dptr, + std::array{}, + d.dptr, + 1, M, N, K, + stride_A, stride_B, + std::array{}, + stride_E); + } + } + + return descs; + }; + + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); + return false; + } + + HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + }; +}; + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp new file mode 100644 index 000000000..83c8bfb95 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -0,0 +1,171 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +enum class GPUArch { + GFX942, + GFX950, + UNKNOWN +}; + +GPUArch detect_gpu_arch() { + int device = 0; + HIP_CHECK_ERROR(hipGetDevice(&device)); + + hipDeviceProp_t props{}; + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device)); + + if (props.major == 9 && props.minor == 4) { + return GPUArch::GFX942; + } + if (props.major == 9 && props.minor == 5) { + return GPUArch::GFX950; + } + return GPUArch::UNKNOWN; +} + +template +struct FP8TileCfg; + +template <> +struct FP8TileCfg { + using type = TileCfg_128x128x128_32x32x16_2x2x1; +}; + +template <> +struct FP8TileCfg { + using type = TileCfg_128x128x128_16x16x128_2x2x1; +}; + +template +std::unique_ptr make_fp8_runner_typed(DType d_dtype, + const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CTypeLayout = RowMajor; + using TileCfg = typename FP8TileCfg::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + using Runner = QuantGroupedGemmRunner< + AType, BType, CType, + ALayout, BLayout, CTypeLayout, + TileCfg, ck_tile::memory_operation_enum::set>; + runner = std::make_unique(); + }); + + return runner; +} + +template +std::unique_ptr make_fp8_runner_for_layout(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + switch (a_dtype) { + case DType::kFloat8E4M3: + switch (b_dtype) { + case DType::kFloat8E4M3: + return make_fp8_runner_typed(d_dtype, ctx); + case DType::kFloat8E5M2: + return make_fp8_runner_typed(d_dtype, ctx); + default: + return nullptr; + } + + case DType::kFloat8E5M2: + switch (b_dtype) { + case DType::kFloat8E4M3: + return make_fp8_runner_typed(d_dtype, ctx); + case DType::kFloat8E5M2: + return make_fp8_runner_typed(d_dtype, ctx); + default: + return nullptr; + } + + default: + return nullptr; + } +} + +template +std::unique_ptr make_fp8_runner_impl(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + if (!ctx.transA && !ctx.transB) { + using ALayout = RowMajor; + using BLayout = RowMajor; + return make_fp8_runner_for_layout(a_dtype, b_dtype, d_dtype, ctx); + } + + if (!ctx.transA && ctx.transB) { + using ALayout = RowMajor; + using BLayout = ColMajor; + return make_fp8_runner_for_layout(a_dtype, b_dtype, d_dtype, ctx); + } + + if (ctx.transA && !ctx.transB) { + using ALayout = ColMajor; + using BLayout = RowMajor; + return make_fp8_runner_for_layout(a_dtype, b_dtype, d_dtype, ctx); + } + + return nullptr; +} + +std::unique_ptr make_fp8_runner_gfx942(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + return make_fp8_runner_impl(a_dtype, b_dtype, d_dtype, ctx); +} + +std::unique_ptr make_fp8_runner_gfx950(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + return make_fp8_runner_impl(a_dtype, b_dtype, d_dtype, ctx); +} + +std::unique_ptr make_fp8_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + switch (detect_gpu_arch()) { + case GPUArch::GFX942: + return make_fp8_runner_gfx942(a_dtype, b_dtype, d_dtype, ctx); + case GPUArch::GFX950: + return make_fp8_runner_gfx950(a_dtype, b_dtype, d_dtype, ctx); + default: + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); + return nullptr; + } +} + +bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; + + auto runner = make_fp8_runner(a_dtype, b_dtype, d_dtype, ctx); + if (!runner) { + return false; + } + + return runner->run(s, ctx); +} + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h new file mode 100644 index 000000000..375ab8b6b --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h @@ -0,0 +1,284 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once +#include "ck_grouped_gemm_common.h" + +#include +#include +#include +#include + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +// Use gfx950-specific tile parameters only for gfx950 device compilation. +// Host code and all other architectures use the default config. +#if defined(__HIP_DEVICE_COMPILE__) && defined(__gfx950__) +struct TileCfg_128x128x128_32x32x16_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; +}; +#else +struct TileCfg_128x128x128_32x32x16_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; +#endif + +struct TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; +}; + +template +class QuantGroupedGemmRunner : public RunnerInterface { +public: + static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using AQLayout = RowMajor; + using BQLayout = RowMajor; + + using UniversalTraits = + ck_tile::TileGemmQuantTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + false, false, false, ALayout, BLayout, CLayout, + QuantMode, AQLayout, BQLayout, + false, TileCfg::DoubleSmemBuffer, false>; + + using Problem = ck_tile::GemmRowColTensorQuantPipelineProblem< + AType, BType, AccType, + AccType, GemmShape, UniversalTraits, + false, AccType>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + using HostArgs = ck_tile::QuantGroupedGemmHostArgs; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", ctx.workspace_bytes, ". Falling back."); + return {}; + } + + std::vector descs; + descs.reserve(ctx.group_num); + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& d = data_view(*D_te); + + const transformer_engine::SimpleTensor* b_src = nullptr; + if (ctx.use_b_columnwise_data) { + if (!B_te->has_columnwise_data()) { + NVTE_ERROR("ck_tile_grouped_gemm: ctx.use_b_columnwise_data=true but columnwise_data is absent."); + } + b_src = &B_te->columnwise_data; + } else { + b_src = &B_te->data; + } + + const auto& b = *b_src; + + int64_t Ad0 = 0, Ad1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected A and D to be rank>=2."); + } + + if (b.shape.size() < 2) { + NVTE_ERROR("ck_tile_grouped_gemm: expected chosen B source to be rank>=2."); + } + + int64_t Bd0 = static_cast(b.shape[b.shape.size() - 2]); + int64_t Bd1 = static_cast(b.shape[b.shape.size() - 1]); + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i, + ". op(A)=", M, "x", K, + " op(B)=", Kb, "x", N, + " raw A=", Ad0, "x", Ad1, + " raw B=", Bd0, "x", Bd1, + " use_b_columnwise_data=", static_cast(ctx.use_b_columnwise_data), + " transA=", static_cast(ctx.transA), + " transB=", static_cast(ctx.transB)); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i, + ". D=", Dd0, "x", Dd1, + ", expected=", M, "x", N); + } + + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); + + ck_tile::index_t AQK = 1; + ck_tile::index_t BQK = 1; + ck_tile::index_t stride_AQ = 1; + ck_tile::index_t stride_BQ = 1; + + const auto& aq = scale_inv_view(*A_te); + const auto& bq = scale_inv_view(*B_te); + + descs.emplace_back( + a.dptr, + b.dptr, + d.dptr, + aq.dptr, + bq.dptr, + 1, + M, + N, + K, + AQK, + BQK, + stride_A, + stride_B, + stride_E, + stride_AQ, + stride_BQ); + } + + return descs; + } + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); + return false; + } + + HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + } +}; + +} // namespace grouped_gemm +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 35da3075c..64dd35479 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -32,7 +32,7 @@ #ifndef __HIP_PLATFORM_AMD__ #include "./cutlass_grouped_gemm.cuh" #else -#include "ck_grouped_gemm.h" +#include "ck_grouped_gemm/ck_grouped_gemm.h" #endif #ifndef __HIP_PLATFORM_AMD__ @@ -1140,9 +1140,14 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto A_dt = inputA->data.dtype; auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; - return (A_dt == B_dt) && (A_dt == D_dt) && - (A_dt == transformer_engine::DType::kFloat16 || - A_dt == transformer_engine::DType::kBFloat16); + return ( + (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) + ) || + ( + (A_dt == B_dt) && (A_dt == D_dt) && + (A_dt == transformer_engine::DType::kFloat16 || + A_dt == transformer_engine::DType::kBFloat16) + ); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype);