From b637af145ea68de66c02bc29a92d27baa3b8e2e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Lakatos?= Date: Fri, 19 Dec 2025 13:10:19 +0000 Subject: [PATCH 01/10] device struct implementation --- ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 995 ++++++++++++++++++ .../gpu/grouped_gemm_multi_abd_fixed_nk.hpp | 43 +- .../CMakeLists.txt | 6 +- ...as_gelu_bf16_i8_bf16_km_kn_mn_instance cpp | 111 ++ ...as_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 111 ++ ...as_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | 111 ++ ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp | 17 + ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 515 +++++++++ profiler/src/CMakeLists.txt | 20 + ...rofile_grouped_gemm_multi_abd_fixed_nk.cpp | 206 ++++ test/CMakeLists.txt | 2 +- 11 files changed, 2134 insertions(+), 3 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp create mode 100644 profiler/src/profile_grouped_gemm_multi_abd_fixed_nk.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp new file mode 100644 index 0000000000..546256ac79 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -0,0 +1,995 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Can be shared between multiple device implementations.... 
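+// Dispatch overview for the kernel below: the host flattens all groups into a
+// single 1-D grid, and group g owns the contiguous block range
+// [block_start_, block_end_) stored in its GemmTransKernelArg. gemm_descs_const
+// points to that argument array (copied into the device workspace by the
+// Invoker before launch), and each workgroup binary-searches it for the entry
+// whose range contains its block id before running the gridwise GEMM for that
+// group.
+//
+// Illustrative layout only (hypothetical sizes, not part of this patch):
+//   group 0 -> blocks [0, 12), group 1 -> blocks [12, 20), group 2 -> blocks [20, 44)
+//   block id 17 -> binary search -> group 1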
+template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + // Binary search lookup to find which group this block is part of + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + // NOTE: Local copy of the arg struct since SplitKBatchOffset verifies and modifies K index + // and thus needs a non-const reference. It's also not feasible to store this in global + // memory as different threads would be writing different K values to the same arg struct + auto karg = gemm_desc_ptr[group_id].karg_; + +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_; + + // Tile index first dimension is the K batch + auto tile_index = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + auto splitk_batch_offset = + typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run(static_cast(p_shared), + splitk_batch_offset, + karg, + block_2_ctile_map, + epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = gemm_descs_const; + ignore = group_count; +#endif // end of if(defined(__gfx11__) || defined(__gfx12__)) +} + +template // ??? 
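+// Grouped GEMM device op for WMMA targets (gfx11/gfx12) in which every group
+// shares the same N and K ("fixed NK") while M may differ per group, and the
+// A/B/D operands are tuples of tensors (multi-ABD). It mirrors the existing
+// XDL fixed-NK implementation and fulfils the DeviceGroupedGemmMultiABDFixedNK
+// interface it derives from.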
+struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK + : public DeviceGroupedGemmMultiABDFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // Note: Pass multiple layout but then using only the first one + // This is to replicate xdl functionality but it should be extended + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, + false>; + + using CGridDesc_M_N = + remove_cvref_t( + 1, 1, 1, 1, 1))>; + + // Move OffsettedBlockToCTileMapMLoops and BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops to helper hpp? 
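+// The two helpers below implement the block-id -> C-tile mapping used by the
+// kernel above. BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops turns a
+// group-local 1-D block id into (k_split, m_tile, n_tile) coordinates,
+// swizzling M tiles in chunks of M01 for locality;
+// OffsettedBlockToCTileMapMLoops wraps it and subtracts the group's
+// block_start_ so every group can reuse the same map on a local index.
+//
+// Illustrative only (hypothetical numbers, not part of this patch): with
+// block_start_ = 12, global block id 17 is presented to the underlying map as
+// local index 5 (plus id_off_, default 0).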
+ template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = 
block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + static constexpr index_t DefaultKBatch = 1; + using KernelArgument = typename GridwiseGemm::Argument; + + template + struct GemmTransKernelArgBase + { + KernelArgument_ karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArgBase() = default; + GemmTransKernelArgBase(KernelArgument_&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + using GemmTransKernelArg = GemmTransKernelArgBase; + + static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) + { + index_t k_grain = karg.KBatch * KPerBlock; + index_t K_split = (karg.K + k_grain - 1) / karg.KBatch; + return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + } + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + c_element_op, + DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + Argument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op, + index_t kbatch) + : group_count_{ck::type_convert(gemm_descs.size())}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + grouped_gemm_kernel_args_dev{nullptr}, + gemm_kernel_host_args_{nullptr}, + grid_size_{0}, + k_batch_{kbatch} + { + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + ((NumDTensor == 0 && p_Ds.size() == 0) || group_count_ == ck::type_convert(p_Ds.size())) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size"); + } + + gemm_desc_kernel_arg_.reserve(group_count_); + + const index_t fixed_N = gemm_descs[0].N_; + const index_t fixed_K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(N != fixed_N || K != fixed_K) // M? + { + throw std::runtime_error("wrong! 
N or K are not fixed across GEMM groups"); + } + + a_mtx_mraw_kraw_.emplace_back(M, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + // pointer + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + + static_for<0, NumATensor, 1>{}([&](auto j) { p_as_grid[j] = nullptr; }); + static_for<0, NumBTensor, 1>{}([&](auto j) { p_bs_grid[j] = nullptr; }); + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + + const index_t StrideE = gemm_descs[i].stride_C_; + + if(gemm_descs[i].stride_As_.size() != NumATensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); + } + + static_for<0, NumATensor, 1>{}( + [&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; }); + + if(gemm_descs[i].stride_Bs_.size() != NumBTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); + } + + static_for<0, NumBTensor, 1>{}( + [&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; }); + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + static_for<0, NumDTensor, 1>{}( + [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; }); + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + M, m_padded, N, n_padded, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! 
block_2_etile_map validation failed"); + } + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp_; + + grid_size_ += grid_size_grp_; + + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = KernelArgument(p_as_grid, + p_bs_grid, + p_ds_grid, + type_convert(p_Es[i]), + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + k_batch_, + a_element_op, + b_element_op, + c_element_op, + false); + + gemm_desc_kernel_arg_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + + // group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + group_count_ * gemm_descs[0].M_, + group_count_ * gemm_descs[0].M_, + gemm_descs[0].N_, + gemm_descs[0].N_, + gemm_descs[0].stride_C_); + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + void UpdateKBatch(index_t) {} + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + void* gemm_kernel_host_args_; + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float RunImp(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) + { + using GemmTransKernelArg_ = GemmTransKernelArgBase; + static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); + + bool all_have_kbatch_gt_one = arg.gemm_desc_kernel_arg_[0].karg_.KBatch > 1; + bool all_have_main_k0_block_loop = + CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[0].karg_); + + bool not_all_have_main_k0_block_loop_same = false; + bool not_all_have_kbatch_value_same = false; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); ++i) + { + const auto& karg = reinterpret_cast( + arg.gemm_desc_kernel_arg_[i].karg_); + + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + auto kbatch = karg.KBatch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + not_all_have_main_k0_block_loop_same |= + all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(karg); + not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (kbatch > 1); + } + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! 
" << " in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + // If the user provides copy stream and copy event, we assume that they're also + // responsible for providing allocated host memory (eg. pinned) which + // would be used to copy kernel arguments to the device. + if(cpy_stream && cpy_event) + { + if(arg.gemm_kernel_host_args_ == nullptr) + { + std::ostringstream err; + err << "No memory has been allocated for gemm kernel host args " + << "when providing the copy stream and copy event! In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + hip_check_error(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_host_args_, + arg.group_count_ * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + cpy_stream)); + + hip_check_error(hipEventRecord(cpy_event, cpy_stream)); + + hip_check_error(hipEventSynchronize(cpy_event)); + } + else // In this case CK owns memory allocated on host. + { + + hip_check_error( + hipMemcpyAsync(arg.p_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + } + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(all_have_kbatch_gt_one) + { + for(const auto& trans_arg : arg.gemm_desc_kernel_arg_) + { + + const auto& karg = trans_arg.karg_; + hip_check_error(hipMemsetAsync(karg.p_e_grid, + 0, + karg.M * karg.N * sizeof(EDataType), + stream_config.stream_id_)); + } + } + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size()); + }; + + // NOTE: If at least one gemm problem has a main k0 block loop, we include it for all + if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk; + + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk; + + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return RunImp(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by + // vector load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} + // layout, thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 
1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + // For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd + // instruction that supports bf16 and we cannot use splitk because of that + if constexpr(std::is_same::value) + { + supported = supported & (arg.k_batch_ == 1); + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Wmma_Fixed_Nk" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetElementwiseOps(Argument& arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + arg.a_element_op_ = a_element_op; + arg.b_element_op_ = b_element_op; + arg.c_element_op_ = c_element_op; + } + + // polymorphic + void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) const override + { + + SetElementwiseOps( + *dynamic_cast(p_arg), a_element_op, b_element_op, c_element_op); + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return 
p_arg_->gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * sizeof(GroupedGemmMultiABDKernelArgument); + } + + size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* /*p_arg*/, index_t /*kbatch*/) const override + { + throw std::runtime_error("??? figure out later"); + } + + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_desc_kernel_arg_.begin(), + pArg_->gemm_desc_kernel_arg_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp index 6d97ec3a05..2d75f20670 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp @@ -10,7 +10,6 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" namespace ck { namespace tensor_operation { @@ -21,6 +20,8 @@ using Multiply = ck::tensor_operation::element_wise::Multiply; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +#if defined(CK_USE_XDL) // RRR void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( std::vector, @@ -179,6 +180,23 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan PassThrough, Multiply, PassThrough>>>& instances); +#endif + +#if defined(CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); +#endif // CK_USE + // GEMM + Add + Gelu template > op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -246,6 +265,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL return op_ptrs; } @@ -289,6 +309,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -317,6 +338,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL return op_ptrs; } @@ -360,6 +382,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ 
-388,6 +411,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL return op_ptrs; } @@ -431,6 +455,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -459,6 +484,22 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt index 9d9a0e691c..631020dc2d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt @@ -1,13 +1,17 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp + + # device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp + # device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp ) add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance cpp new file mode 100644 index 0000000000..d483a828a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance cpp @@ -0,0 +1,111 @@ +// // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// // SPDX-License-Identifier: MIT + +// #include + +// #include "ck/ck.hpp" +// #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +// #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +// #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +// #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +// #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +// #include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp" + +// namespace ck { +// namespace tensor_operation { +// namespace device { +// namespace instance { + +// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple, +// EDataType, +// AElementOp, +// BElementOp, +// AddFastGelu>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< +// ck::Tuple, +// ck::Tuple, +// AddFastGelu, +// GemmMNKPadding>{}); +// } + +// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple, +// EDataType, +// AElementOp, +// BElementOp, +// Add>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< +// ck::Tuple, +// ck::Tuple, +// Add, +// GemmMNKPadding>{}); +// } + +// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple<>, +// EDataType, +// AElementOp, +// BElementOp, +// PassThrough>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< +// ck::Tuple<>, +// ck::Tuple<>, +// PassThrough, +// GemmMNKPadding>{}); +// } + +// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple<>, +// EDataType, +// AElementOp, +// BElementOp, +// FastGelu>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< +// ck::Tuple<>, +// ck::Tuple<>, +// FastGelu, +// GemmMNKPadding>{}); +// } + +// } // namespace instance +// } // namespace device +// } // namespace tensor_operation +// } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..a4b0d776c7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,111 @@ +// // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// // SPDX-License-Identifier: MIT + +// #include + +// #include "ck/ck.hpp" +// #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +// #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +// #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +// #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +// #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +// #include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp" + +// namespace ck { +// namespace tensor_operation { +// namespace device { +// namespace instance { + +// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple, +// EDataType, +// AElementOp, +// BElementOp, +// AddFastGelu>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< +// ck::Tuple, +// ck::Tuple, +// AddFastGelu, +// GemmMNKPadding>{}); +// } + +// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple, +// EDataType, +// AElementOp, +// BElementOp, +// Add>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< +// ck::Tuple, +// ck::Tuple, +// Add, +// GemmMNKPadding>{}); +// } + +// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple<>, +// EDataType, +// AElementOp, +// BElementOp, +// PassThrough>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< +// ck::Tuple<>, +// ck::Tuple<>, +// PassThrough, +// GemmMNKPadding>{}); +// } + +// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple<>, +// EDataType, +// AElementOp, +// BElementOp, +// FastGelu>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< +// ck::Tuple<>, +// ck::Tuple<>, +// FastGelu, +// GemmMNKPadding>{}); +// } + +// } // namespace instance +// } // namespace device +// } // namespace tensor_operation +// } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..80c13f4ef2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp" // TODO: find a final spot for instances. + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple, +// EDataType, +// AElementOp, +// BElementOp, +// AddFastGelu>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< +// ck::Tuple, +// ck::Tuple, +// AddFastGelu, +// GemmMNKPadding>{}); +// } + +// void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple, +// EDataType, +// AElementOp, +// BElementOp, +// Add>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< +// ck::Tuple, +// ck::Tuple, +// Add, +// GemmMNKPadding>{}); +// } + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + ck::Tuple<>, + ck::Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} + +// void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( +// std::vector, +// ELayout, +// AsDataType, +// BsDataType, +// ck::Tuple<>, +// EDataType, +// AElementOp, +// BElementOp, +// FastGelu>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< +// ck::Tuple<>, +// ck::Tuple<>, +// FastGelu, +// GemmMNKPadding>{}); +// } + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp index 95365c82e7..2843918693 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" 
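+// Note: the WMMA device-op header is pulled in here so that the
+// device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances
+// tuple defined further down can name DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK
+// alongside the existing XDL instance list.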
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -83,6 +84,22 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM|NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization|Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> +// DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< ck::Tuple, ck::Tuple, ck::Tuple, Row, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple, BF16, PassThrough, PassThrough, Add, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp new file mode 100644 index 0000000000..83052acbb3 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp @@ -0,0 +1,515 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +auto reserveVector(std::size_t size) +{ + std::vector vec; + vec.reserve(size); + return vec; +} + +template +bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, + int init_method, + bool /*do_log*/, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + int StrideE, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideDs.size())) + { + throw std::runtime_error("wrong! 
inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + auto generateInputTupleA = [&](std::size_t g) { + if constexpr(NumATensor == 0) + { + return ck::Tuple<>(); + } + else + { + using ALayout = remove_cvref_t{}, AsLayout>>; + return generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ms[g], Ks[g], StrideAs[g], ALayout{})); + }, + Number{}); + } + }; + auto generateInputTupleB = [&](std::size_t g) { + if constexpr(NumBTensor == 0) + { + return ck::Tuple<>(); + } + else + { + using BLayout = remove_cvref_t{}, BsLayout>>; + return generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ks[g], Ns[g], StrideBs[g], BLayout{})); + }, + Number{}); + } + }; + auto generateInputTupleD = [&](std::size_t g) { + if constexpr(NumDTensor == 0) + { + return ck::Tuple<>(); + } + else + { + using DLayout = remove_cvref_t{}, DsLayout>>; + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ms[g], Ns[g], StrideDs[g], DLayout{})); + }, + Number{}); + } + }; + + using InputTupleA = decltype(generateInputTupleA(0)); + using InputTupleB = decltype(generateInputTupleB(0)); + using InputTupleD = decltype(generateInputTupleD(0)); + + auto g_as_m_k = reserveVector(group_count); + auto g_bs_k_n = reserveVector(group_count); + auto g_ds_m_n = reserveVector(group_count); + auto g_e_m_n_host_results = reserveVector>(group_count); + auto g_e_m_n_device_results = reserveVector>(group_count); + // int sum_of_m = 0; + + for(std::size_t g = 0; g < group_count; g++) + { + // sum_of_m += Ms[g]; + + auto as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g)); + auto bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g)); + auto ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g)); + + g_e_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE, ELayout{}))); + g_e_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE, ELayout{}))); + + // if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + // { + // std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + // << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + // << "]:" << c_m_n_device_results[i].mDesc << std::endl; + // } + + // static_for<0, NumATensor, 1>{}( + // [&](auto i) { std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; }); + // static_for<0, NumBTensor, 1>{}( + // [&](auto i) { std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; }); + // static_for<0, NumDTensor, 1>{}( + // [&](auto i) { std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; }); + // std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_m_k(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_k_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_m_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 
5}, num_thread); + }); + + break; + default: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_m_k(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_k_n(i).GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_m_n(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + }); + } + } + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CDEElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector> g_as_device_buf(group_count); + std::vector> g_bs_device_buf(group_count); + std::vector> g_ds_device_buf(group_count); + std::vector g_e_device_buf(group_count); + + std::vector> g_as_device_view(group_count); + std::vector> g_bs_device_view(group_count); + std::vector> g_ds_device_view(group_count); + std::vector g_e_device_view(group_count); + + auto gemm_descs = reserveVector(group_count); + + auto grouped_gemm_kernel_args_host = + reserveVector>( + group_count); + + for(std::size_t g = 0; g < group_count; g++) + { + std::array as_stride; + std::array bs_stride; + std::array ds_stride; + + auto& as_m_k = g_as_m_k[g]; + auto& as_device_buf = g_as_device_buf[g]; + auto& as_device_view = g_as_device_view[g]; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_device_buf[i] = std::make_unique(sizeof(ADataType) * + as_m_k(i).mDesc.GetElementSpaceSize()); + as_device_buf[i]->ToDevice(as_m_k[i].mData.data()); + as_device_view[i] = as_device_buf[i]->GetDeviceBuffer(); + as_stride[i] = StrideAs[g]; + }); + + auto& bs_k_n = g_bs_k_n[g]; + auto& bs_device_buf = g_bs_device_buf[g]; + auto& bs_device_view = g_bs_device_view[g]; + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_device_buf[i] = std::make_unique(sizeof(BDataType) * + bs_k_n(i).mDesc.GetElementSpaceSize()); + bs_device_buf[i]->ToDevice(bs_k_n[i].mData.data()); + bs_device_view[i] = bs_device_buf[i]->GetDeviceBuffer(); + bs_stride[i] = StrideBs[g]; + }); + + auto& ds_m_n = g_ds_m_n[g]; + auto& ds_device_buf = g_ds_device_buf[g]; + auto& ds_device_view = g_ds_device_view[g]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + return std::make_unique(sizeof(DDataType) * + ds_m_n(i).mDesc.GetElementSpaceSize()); + ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data()); + ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer(); + ds_stride[i] = StrideDs[g]; + }); + + gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ + Ms[g], + Ns[g], + Ks[g], + std::vector(as_stride.begin(), as_stride.end()), + std::vector(bs_stride.begin(), as_stride.end()), + std::vector(ds_stride.begin(), as_stride.end()), + StrideE}); + + grouped_gemm_kernel_args_host.push_back({as_device_view, + bs_device_view, + ds_device_view, + g_e_device_buf[g]->GetDeviceBuffer(), + Ms[g], + Ns[g], + Ks[g], + std::move(as_stride), + std::move(bs_stride), + std::move(ds_stride), + StrideE}); + } + + using DeviceOp = tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK; + + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! 
no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + auto p_ds = std::vector>{}; + + // if(do_verification) + // { + // for(std::size_t i = 0; i < gemm_descs.size(); i++) + // { + // using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + // auto ref_gemm = ReferenceGemmInstance{}; + // auto ref_invoker = ref_gemm.MakeInvoker(); + + // auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + // b_k_n[i], + // c_m_n_host_results[i], + // a_element_op, + // b_element_op, + // c_element_op); + + // ref_invoker.Run(ref_argument); + // } + // } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(g_as_device_view, + g_bs_device_view, + g_ds_device_view, + g_e_device_view, + gemm_descs, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + + DeviceMem grouped_gemm_kernel_args_dev( + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_host.data(), + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}; + + if(kbatch > 0) + { + kbatch_list = {kbatch}; + } + + for(std::size_t j = 0; j < kbatch_list.size(); j++) + { + + auto kbatch_curr = kbatch_list[j]; + + gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // for(std::size_t i = 0; i < gemm_descs.size(); i++) + // c_device_buf[i]->SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + // if(do_verification) + // { + // bool instance_pass = true; + // for(std::size_t i = 0; i < gemm_descs.size(); i++) + // { + + // c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + + // if(std::is_same_v && kbatch_curr > 1) + // { + // instance_pass = + // instance_pass && ck::utils::check_err(c_m_n_device_results[i], + // c_m_n_host_results[i], + // "Error: Incorrect results!", + // 0.06); + // } + // else + // { + // instance_pass = + // instance_pass && ck::utils::check_err(c_m_n_device_results[i], + // c_m_n_host_results[i]); + // } + + // if(do_log) + // { + // LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + // << std::endl; + // LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") + // << std::endl; + // LogRangeAsType( + // std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + // << std::endl; + // LogRangeAsType( + // std::cout << "c_host : ", c_m_n_host_results[i].mData, ",") + // << std::endl; + // } + // } + + // std::cout << "Instance: " << gemm_name << " verification " + // << (instance_pass ? 
"SUCCEED" : "FAILED") << std::endl; + + // pass = pass && instance_pass; + // } + + /*float ave_time =*/ invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + // if(time_kernel) + // { + // std::size_t flop = 0, num_btype = 0; + // for(std::size_t i = 0; i < gemm_descs.size(); i++) + // { + // flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + // num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + // sizeof(BDataType) * Ks[i] * Ns[i] + + // sizeof(CDataType) * Ms[i] * Ns[i]; + // } + + // float tflops = static_cast(flop) / 1.E9 / ave_time; + + // float gb_per_sec = num_btype / 1.E6 / ave_time; + // std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + // << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + // << kbatch_curr << std::endl; + + // if(tflops > best_tflops) + // { + // best_gemm_name = gemm_name; + // best_tflops = tflops; + // best_ave_time = ave_time; + // best_gb_per_sec = gb_per_sec; + // best_kbatch = kbatch_curr; + // } + // } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; + } + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 3379fd15d1..ad41bc797d 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -58,6 +58,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp) list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_OPS profile_grouped_gemm_multi_abd_fixed_nk.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) @@ -295,3 +296,22 @@ message(VERBOSE "ckProfiler libs: ${PROFILER_LIBS}") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS}) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) + +## Defining specific operation targets + +macro(define_profiler_target NAME SOURCES LIBS) + add_executable(${NAME} profiler.cpp ${SOURCES}) + target_compile_options(${NAME} PRIVATE -Wno-global-constructors) + + if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) + target_compile_options(${NAME} PRIVATE --offload-compress) + endif() + + target_link_libraries(${NAME} PRIVATE utility getopt::getopt ${LIBS}) + + rocm_install(TARGETS ${NAME} COMPONENT profiler) +endmacro() + +define_profiler_target(ckProfiler_fixed_nk + "profile_grouped_gemm_multi_abd_fixed_nk.cpp" + "device_grouped_gemm_fixed_nk_multi_abd_instance") diff --git a/profiler/src/profile_grouped_gemm_multi_abd_fixed_nk.cpp b/profiler/src/profile_grouped_gemm_multi_abd_fixed_nk.cpp new file mode 100644 index 0000000000..16282a3e7a --- /dev/null +++ b/profiler/src/profile_grouped_gemm_multi_abd_fixed_nk.cpp @@ -0,0 +1,206 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN // 2 +}; + +enum struct GemmDataType +{ + BF16_I8_BF16 // 0 +}; + +#define OP_NAME "grouped_gemm_multi_abd_fixed_nk" +#define OP_DESC "Grouped GEMM Multi ABD Fixed NK" + +namespace { + +std::vector argToIntArray(char* input) +{ + std::vector out; + std::istringstream in(input); + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + return out; +} + +int profile_grouped_gemm_multi_abd_fixed_nk(int argc, char* argv[]) +{ + if(argc < 14) + { + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: bf16@int8; 1: fp16; 2: fp16@fp8; 3: fp16@int8)\n" + << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n" + << " 1: A[m, k] * B[n, k] = C[m, n];\n" + << " 2: A[k, m] * B[k, n] = C[m, n];)\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no; 1: yes)\n" + << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideDs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "arg14: kbatch value (default 1)\n" + << "optional:\n" + << "arg15: number of warm-up cycles (default 1)\n" + << "arg16: number of iterations (default 10)\n" + << std::endl; + + exit(1); + } + + // const auto data_type = static_cast(std::stoi(argv[2])); + // const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + const auto StrideAs = argToIntArray(argv[11]); + const auto StrideBs = argToIntArray(argv[12]); + const auto StrideDs = argToIntArray(argv[13]); + const auto StrideE = StrideDs.at(0); + const int kbatch = argc >= 15 ? 
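+    // Illustrative invocation only (the executable name and concrete argument values below are
+    // assumptions for documentation purposes, not taken from this change):
+    //   ckProfiler grouped_gemm_multi_abd_fixed_nk 0 1 1 1 0 1 256,256 128,128 64,64 64,64 64,64 128,128 2
+    // i.e. data type bf16@int8, layout MK_NK_MN, verification on, integer initialization, no
+    // tensor printing, timing on, two groups with their M/N/K and stride lists, and kbatch = 2.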
std::stoi(argv[14]) : 1; + + int n_warmup = 1; + int n_iter = 10; + if(argc == 17) + { + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); + } + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl, + ck::Tuple, + ck::Tuple<>, + float, + float, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + Row>(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideDs, + StrideE, + kbatch, + n_warmup, + n_iter); + +// #if defined(CK_ENABLE_INT8) +// #if defined(CK_ENABLE_BF16) +// if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::KM_KN_MN) +// { +// ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl, +// ck::Tuple, +// ck::Tuple<>, +// float, +// float, +// ck::Tuple, +// ck::Tuple, +// ck::Tuple<>, +// Row>(do_verification, +// init_method, +// do_log, +// time_kernel, +// Ms, +// Ns, +// Ks, +// StrideAs, +// StrideBs, +// StrideDs, +// StrideE, +// kbatch, +// n_warmup, +// n_iter); +// } +// else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) +// { +// ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl, +// ck::Tuple, +// ck::Tuple, +// float, +// float, +// ck::Tuple, +// ck::Tuple, +// ck::Tuple, +// Row>(do_verification, +// init_method, +// do_log, +// time_kernel, +// Ms, +// Ns, +// Ks, +// StrideAs, +// StrideBs, +// StrideDs, +// StrideE, +// kbatch, +// n_warmup, +// n_iter); +// } +// else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) +// { +// ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl, +// ck::Tuple, +// ck::Tuple, +// float, +// float, +// ck::Tuple, +// ck::Tuple, +// ck::Tuple, +// Row>(do_verification, +// init_method, +// do_log, +// time_kernel, +// Ms, +// Ns, +// Ks, +// StrideAs, +// StrideBs, +// StrideDs, +// StrideE, +// kbatch, +// n_warmup, +// n_iter); +// } +// #endif // CK_ENABLE_BF16 +// #endif // CK_ENABLE_INT8 +// else +// { +// throw std::runtime_error("wrong! 
this GEMM data_type & layout is not implemented"); +// } + return 0; +} + +} // anonymous namespace + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_multi_abd_fixed_nk); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9fee3b5697..488a15a11e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -251,7 +251,7 @@ function(add_gtest_executable TEST_NAME) endfunction() add_compile_options(-Wno-c++20-extensions) -add_subdirectory(ck_tile) +# add_subdirectory(ck_tile) add_subdirectory(magic_number_division) add_subdirectory(space_filling_curve) add_subdirectory(conv_util) From 4b3235335acd13a750842a69a618c814187d2dad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Lakatos?= Date: Fri, 9 Jan 2026 13:55:27 +0000 Subject: [PATCH 02/10] added xdl grouped multi abd fixed nk testing --- .../cpu/reference_gemm_multi_abd.hpp | 204 +++++++++ ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | 8 +- ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | 8 +- .../profiler/profile_gemm_multi_abd_impl.hpp | 68 +-- ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 401 +++++++++--------- test/grouped_gemm/CMakeLists.txt | 6 + .../test_grouped_gemm_multi_abd_fixed_nk.cpp | 304 +++++++++++++ 7 files changed, 744 insertions(+), 255 deletions(-) create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp create mode 100644 test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp new file mode 100644 index 0000000000..aef3c8a45c --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp @@ -0,0 +1,204 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/utility/functional4.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// this function is also defined in CK but because of the way we use it in +// profile_gemm_multi_impl, it requires the arguments to not be const +template +auto concat_tuple_of_refs(ck::Tuple& tx, ck::Tuple& ty) +{ + return ck::unpack2( + [&](auto&&... 
zs) { return ck::Tuple{ck::forward(zs)...}; }, + tx, + ty); +} + +template +struct ReferenceGemmMultiABD : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const AsTensorTuple& as_m_k, + const BsTensorTuple& bs_k_n, + const DsTensorTuple& ds_m_n, + Tensor& e_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : as_m_k_{as_m_k}, + bs_k_n_{bs_k_n}, + ds_m_n_{ds_m_n}, + e_m_n_{e_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const AsTensorTuple& as_m_k_; + const BsTensorTuple& bs_k_n_; + const DsTensorTuple& ds_m_n_; + Tensor& e_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmMultiABD::Argument; + + float Run(const Argument& arg) + { + static constexpr index_t NumATensor = AsTensorTuple::Size(); + static constexpr index_t NumBTensor = BsTensorTuple::Size(); + static constexpr index_t NumDTensor = DsTensorTuple::Size(); + + const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0]; + const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1]; + const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1]; + + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return arg.as_m_k_[Number{}](m, k); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(arg.a_element_op_, data_refs); + } + } + + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return arg.bs_k_n_[Number{}](k, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(arg.b_element_op_, data_refs); + } + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + // compulsory + auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n)); + // optional (if multiple Ds) + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return arg.ds_m_n_[Number{}](m, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(arg.cde_element_op_, data_refs); + } + } + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const AsTensorTuple& as_m_k, + const BsTensorTuple& bs_k_n, + const DsTensorTuple& ds_m_n, + Tensor& e_m_n, + AElementwiseOperation a_element_op, + 
BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmMultiABD" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp index 23e3b7f511..dfc139a43f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp @@ -72,14 +72,14 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances //######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp index 0560f159fc..fb71c260e4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -72,14 +72,14 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances //######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, 
EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp index e94d5bb910..70baf11b5f 100644 --- a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp @@ -17,7 +17,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp" namespace ck { namespace profiler { @@ -180,58 +180,27 @@ bool 
profile_gemm_multi_abd_impl(int do_verification, // run reference if(do_verification) { - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - Tensor c_m_n({M, N}); - using AComputeType = typename std::conditional<(NumATensor > 1), EDataType, remove_cvref_t>>::type; - Tensor a_m_k({M, K}); - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - // result - auto data_refs1 = ck::tie(a_m_k(m, k)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(a_element_op, data_refs); - } - } - using BComputeType = typename std::conditional<(NumBTensor > 1), EDataType, remove_cvref_t>>::type; - Tensor b_k_n({K, N}); - for(int k = 0; k < K; ++k) - { - for(int n = 0; n < N; ++n) - { - // result - auto data_refs1 = ck::tie(b_k_n(k, n)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(b_element_op, data_refs); - } - } + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultiABD; - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -239,21 +208,6 @@ bool profile_gemm_multi_abd_impl(int do_verification, ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - // compulsory - auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n)); - // optional (if multiple Ds) - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return ds_m_n(Number{})(m, n); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(cde_element_op, data_refs); - } - } } std::array as_device_buf; diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp index 83052acbb3..f0269526e2 100644 --- a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp @@ -15,8 +15,8 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/convolution_parameter.hpp" @@ -25,7 +25,6 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/utility/fill.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { namespace profiler { @@ -47,12 +46,12 @@ template -bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, +bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, int init_method, - bool /*do_log*/, + bool do_log, bool time_kernel, const std::vector& Ms, const std::vector& Ns, @@ -60,10 +59,10 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, const std::vector& StrideAs, const std::vector& 
StrideBs, const std::vector& StrideDs, - int StrideE, - int kbatch = 1, - int n_warmup = 1, - int n_iter = 10) + const std::vector& StrideE, + const std::vector& kbatch_list = {1}, + int n_warmup = 1, + int n_iter = 10) { bool pass = true; @@ -83,15 +82,16 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, std::size_t group_count = Ms.size(); - if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && - group_count == StrideBs.size() && group_count == StrideDs.size())) - { - throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n"); - } static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size(); + if(group_count != Ns.size() || group_count != Ks.size() || group_count != StrideAs.size() || + group_count != StrideBs.size() || (NumDTensor > 0 && group_count != StrideDs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideAs/Bs/Ds/E size\n"); + } + auto generateInputTupleA = [&](std::size_t g) { if constexpr(NumATensor == 0) { @@ -99,7 +99,7 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, } else { - using ALayout = remove_cvref_t{}, AsLayout>>; + using ALayout = remove_cvref_t{}, AsLayout>>; return generate_tuple( [&](auto i) { using ADataType = remove_cvref_t>; @@ -116,7 +116,7 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, } else { - using BLayout = remove_cvref_t{}, BsLayout>>; + using BLayout = remove_cvref_t{}, BsLayout>>; return generate_tuple( [&](auto i) { using BDataType = remove_cvref_t>; @@ -133,7 +133,7 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, } else { - using DLayout = remove_cvref_t{}, DsLayout>>; + using DLayout = remove_cvref_t{}, DsLayout>>; return generate_tuple( [&](auto i) { using DDataType = remove_cvref_t>; @@ -144,14 +144,14 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, } }; - using InputTupleA = decltype(generateInputTupleA(0)); - using InputTupleB = decltype(generateInputTupleB(0)); - using InputTupleD = decltype(generateInputTupleD(0)); + using AsTensorTuple = decltype(generateInputTupleA(0)); + using BsTensorTuple = decltype(generateInputTupleB(0)); + using DsTensorTuple = decltype(generateInputTupleD(0)); - auto g_as_m_k = reserveVector(group_count); - auto g_bs_k_n = reserveVector(group_count); - auto g_ds_m_n = reserveVector(group_count); - auto g_e_m_n_host_results = reserveVector>(group_count); + auto g_as_m_k = reserveVector(group_count); + auto g_bs_k_n = reserveVector(group_count); + auto g_ds_m_n = reserveVector(group_count); + auto g_e_m_n_host_results = reserveVector>(group_count); auto g_e_m_n_device_results = reserveVector>(group_count); // int sum_of_m = 0; @@ -159,30 +159,30 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, { // sum_of_m += Ms[g]; - auto as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g)); - auto bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g)); - auto ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g)); + auto& as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g)); + auto& bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g)); + auto& ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g)); g_e_m_n_host_results.push_back( - Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE, ELayout{}))); + Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], 
ELayout{}))); g_e_m_n_device_results.push_back( - Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE, ELayout{}))); - - // if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - // { - // std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" - // << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - // << "]:" << c_m_n_device_results[i].mDesc << std::endl; - // } - - // static_for<0, NumATensor, 1>{}( - // [&](auto i) { std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; }); - // static_for<0, NumBTensor, 1>{}( - // [&](auto i) { std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; }); - // static_for<0, NumDTensor, 1>{}( - // [&](auto i) { std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; }); - // std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; - + Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << g << std::endl; + static_for<0, NumATensor, 1>{}([&](auto i) { + std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; + }); + std::cout << "e_m_n: " << g_e_m_n_device_results[g].mDesc << std::endl; + } + std::size_t num_thread = 1; switch(init_method) { @@ -235,12 +235,13 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, std::vector> g_as_device_view(group_count); std::vector> g_bs_device_view(group_count); std::vector> g_ds_device_view(group_count); - std::vector g_e_device_view(group_count); + std::vector g_e_device_view(group_count); - auto gemm_descs = reserveVector(group_count); + auto g_gemm_descs = reserveVector(group_count); auto grouped_gemm_kernel_args_host = - reserveVector>( + reserveVector>( group_count); for(std::size_t g = 0; g < group_count; g++) @@ -280,34 +281,42 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, auto& ds_device_view = g_ds_device_view[g]; static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - return std::make_unique(sizeof(DDataType) * - ds_m_n(i).mDesc.GetElementSpaceSize()); + using DDataType = remove_cvref_t>; + ds_device_buf[i] = std::make_unique(sizeof(DDataType) * + ds_m_n(i).mDesc.GetElementSpaceSize()); ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data()); ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer(); ds_stride[i] = StrideDs[g]; }); - gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ + g_e_device_buf[g] = std::make_unique( + sizeof(EDataType) * g_e_m_n_host_results[g].mDesc.GetElementSpaceSize()); + g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer(); + + g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ Ms[g], Ns[g], Ks[g], std::vector(as_stride.begin(), as_stride.end()), - std::vector(bs_stride.begin(), as_stride.end()), - std::vector(ds_stride.begin(), as_stride.end()), - StrideE}); - - grouped_gemm_kernel_args_host.push_back({as_device_view, - bs_device_view, - ds_device_view, - g_e_device_buf[g]->GetDeviceBuffer(), - Ms[g], - Ns[g], - Ks[g], - std::move(as_stride), - std::move(bs_stride), - std::move(ds_stride), - StrideE}); + std::vector(bs_stride.begin(), bs_stride.end()), + std::vector(ds_stride.begin(), ds_stride.end()), + 
StrideE[g]}); + + tensor_operation::device:: + GroupedGemmMultiABDKernelArgument + tmp{as_device_view, + bs_device_view, + ds_device_view, + g_e_device_view[g], + Ms[g], + Ns[g], + Ks[g], + as_stride, + bs_stride, + ds_stride, + StrideE[g]}; + + grouped_gemm_kernel_args_host.push_back(std::move(tmp)); } using DeviceOp = tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK>{}; - - // if(do_verification) - // { - // for(std::size_t i = 0; i < gemm_descs.size(); i++) - // { - // using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - - // auto ref_gemm = ReferenceGemmInstance{}; - // auto ref_invoker = ref_gemm.MakeInvoker(); - - // auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], - // b_k_n[i], - // c_m_n_host_results[i], - // a_element_op, - // b_element_op, - // c_element_op); - - // ref_invoker.Run(ref_argument); - // } - // } + if(do_verification) + { + using AComputeType = + typename std::conditional<(NumATensor > 1), + EDataType, + remove_cvref_t>>::type; + + using BComputeType = + typename std::conditional<(NumBTensor > 1), + EDataType, + remove_cvref_t>>::type; + + for(std::size_t i = 0; i < group_count; i++) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultiABD; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(g_as_m_k[i], + g_bs_k_n[i], + g_ds_m_n[i], + g_e_m_n_host_results[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + } // profile device GEMM instances for(auto& gemm_ptr : op_ptrs) { - auto argument_ptr = - gemm_ptr->MakeArgumentPointer(g_as_device_view, - g_bs_device_view, - g_ds_device_view, - g_e_device_view, - gemm_descs, - a_element_op, - b_element_op, - c_element_op); + auto argument_ptr = gemm_ptr->MakeArgumentPointer(g_as_device_view, + g_bs_device_view, + g_ds_device_view, + g_e_device_view, + g_gemm_descs, + a_element_op, + b_element_op, + c_element_op); auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); @@ -396,103 +417,103 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int /*do_verification*/, std::string gemm_name = gemm_ptr->GetTypeString(); - std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}; - - if(kbatch > 0) + for(const auto kbatch_curr : kbatch_list) { - kbatch_list = {kbatch}; - } - - for(std::size_t j = 0; j < kbatch_list.size(); j++) - { - - auto kbatch_curr = kbatch_list[j]; - gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr); if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - // for(std::size_t i = 0; i < gemm_descs.size(); i++) - // c_device_buf[i]->SetZero(); + for(std::size_t g = 0; g < group_count; g++) + { + g_e_device_buf[g]->SetZero(); + } invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter}); - // if(do_verification) - // { - // bool instance_pass = true; - // for(std::size_t i = 0; i < gemm_descs.size(); i++) - // { - - // c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); - - // if(std::is_same_v && kbatch_curr > 1) - // { - // instance_pass = - // instance_pass && ck::utils::check_err(c_m_n_device_results[i], - // c_m_n_host_results[i], - // "Error: Incorrect results!", - // 0.06); - // } - // else - // { - // instance_pass = - // instance_pass && ck::utils::check_err(c_m_n_device_results[i], - // c_m_n_host_results[i]); - // } - - // if(do_log) - // { - // LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") - // << std::endl; - // LogRangeAsType(std::cout << "b: ", 
b_k_n[i].mData, ",") - // << std::endl; - // LogRangeAsType( - // std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") - // << std::endl; - // LogRangeAsType( - // std::cout << "c_host : ", c_m_n_host_results[i].mData, ",") - // << std::endl; - // } - // } - - // std::cout << "Instance: " << gemm_name << " verification " - // << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; - - // pass = pass && instance_pass; - // } - - /*float ave_time =*/ invoker_ptr->Run( + if(do_verification) + { + bool instance_pass = true; + for(std::size_t g = 0; g < group_count; g++) + { + g_e_device_buf[g]->FromDevice(g_e_m_n_device_results[g].mData.data()); + + instance_pass = + instance_pass && ck::utils::check_err(g_e_m_n_device_results[g], + g_e_m_n_host_results[g]); + + if(do_log) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "a[" << g << "] : ", g_as_m_k[g](i).mData, ",") + << std::endl; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "b[" << g << "]: ", g_bs_k_n[g](i).mData, ",") + << std::endl; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "d[" << g << "]: ", g_ds_m_n[g](i).mData, ",") + << std::endl; + }); + LogRangeAsType( + std::cout << "c_device: ", g_e_m_n_device_results[g].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", g_e_m_n_host_results[g].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + + float ave_time = invoker_ptr->Run( argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); - // if(time_kernel) - // { - // std::size_t flop = 0, num_btype = 0; - // for(std::size_t i = 0; i < gemm_descs.size(); i++) - // { - // flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; - - // num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + - // sizeof(BDataType) * Ks[i] * Ns[i] + - // sizeof(CDataType) * Ms[i] * Ns[i]; - // } - - // float tflops = static_cast(flop) / 1.E9 / ave_time; - - // float gb_per_sec = num_btype / 1.E6 / ave_time; - // std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops - // << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " - // << kbatch_curr << std::endl; - - // if(tflops > best_tflops) - // { - // best_gemm_name = gemm_name; - // best_tflops = tflops; - // best_ave_time = ave_time; - // best_gb_per_sec = gb_per_sec; - // best_kbatch = kbatch_curr; - // } - // } + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t g = 0; g < group_count; g++) + { + flop += std::size_t(2) * Ms[g] * Ns[g] * Ks[g]; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + num_btype += sizeof(ADataType) * Ms[g] * Ks[g]; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + num_btype += sizeof(BDataType) * Ks[g] * Ns[g]; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + num_btype += sizeof(DDataType) * Ms[g] * Ns[g]; + }); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; 
+ best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } } else { diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 450950cbd6..bc79c85e59 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -18,6 +18,12 @@ if (CK_USE_XDL OR CK_USE_WMMA) target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance) add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu) endif() + + add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_multi_abd_fixed_nk) + endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp new file mode 100644 index 0000000000..d4ebe229cc --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp @@ -0,0 +1,304 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp" + +#include "gtest/gtest.h" + +static ck::index_t param_mask = 0xffffff; +static ck::index_t instance_index = -1; + +using FP32 = float; +using FP16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using Add = ck::tensor_operation::element_wise::Add; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu> +>; +// clang-format on + +template +class TestGroupedGemmMultiABDFixedNK : public testing::Test +{ + protected: + using AsDataType = 
std::tuple_element_t<0, Tuple>; + using BsDataType = std::tuple_element_t<1, Tuple>; + using DsDataType = std::tuple_element_t<2, Tuple>; + using EDataType = std::tuple_element_t<3, Tuple>; + using AccDataType = float; + using AsLayout = std::tuple_element_t<4, Tuple>; + using BsLayout = std::tuple_element_t<5, Tuple>; + using DsLayout = std::tuple_element_t<6, Tuple>; + using ELayout = std::tuple_element_t<7, Tuple>; + using AElementOp = PassThrough; + using BElementOp = Multiply; + using CDEElementOp = std::tuple_element_t<8, Tuple>; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // integer value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + static constexpr int n_warmup_ = 0; + static constexpr int n_iter_ = 1; + + std::vector k_batches_; + + bool IsSplitKSupported() + { + // gfx11 does not support split-K due to missing atomic add for fp16/bf16 + // Technically, we could still use split-K for fp32, but we currently don't have + // instances for it so we disable it entirely + constexpr bool require_16bit_atomic_add = + std::is_same_v || std::is_same_v; + bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported(); + + // CDE element operators are not supported in combination with split K + constexpr bool has_cde_element_operator = !std::is_same_v; + + return !missing_atomic_add && !has_cde_element_operator; + } + + void SetUp() override + { + if(!IsSplitKSupported()) + { + k_batches_ = {1}; + } + else + { + k_batches_ = {1, 2, 3, 5, 8}; + } + } + + private: + template + void SetStrides(std::vector& strides, + const std::vector& rows, + const std::vector& cols) const + { + if(std::is_same_v) + { + for(const auto c : cols) + { + strides.emplace_back(c); + } + } + else if(std::is_same_v) + { + for(const auto r : rows) + { + strides.emplace_back(r); + } + } + } + + template + void SetTupleStrides(std::vector& strides, + const std::vector& rows, + const std::vector& cols) const + { + if constexpr(Layouts::Size() > 0) + { + // As of now multi ABD implementation supports only tensors with matching layouts. 
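+            // Example (illustrative sizes, not taken from the tests below): for a row-major
+            // tensor with rows = {64, 128} and cols = {256, 512}, SetStrides fills
+            // strides = {256, 512}, i.e. the leading dimension is the column count; a
+            // column-major tensor would instead get strides = {64, 128}.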
+ using Layout = std::remove_cvref_t{}, Layouts>>; + SetStrides(strides, rows, cols); + } + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs = {}, + const std::vector& StrideBs = {}, + const std::vector& StrideDs = {}, + const std::vector& StrideE = {}) + { + std::vector stride_as = StrideAs; + std::vector stride_bs = StrideBs; + std::vector stride_ds = StrideDs; + std::vector stride_e = StrideE; + + if(stride_as.empty()) + { + SetTupleStrides(stride_as, Ms, Ks); + } + if(stride_bs.empty()) + { + SetTupleStrides(stride_bs, Ks, Ns); + } + if(stride_ds.empty()) + { + SetTupleStrides(stride_ds, Ms, Ns); + } + if(stride_e.empty()) + { + SetStrides(stride_e, Ms, Ns); + } + + std::vector k_batches; + for(size_t i = 0; i < k_batches_.size(); i++) + { + if(param_mask & (1 << i)) + { + k_batches.push_back(k_batches_[i]); + } + } + + RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_e); + } + + void RunSingle(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideE) + { + bool pass = ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl( + verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideDs, + StrideE, + k_batches_, + n_warmup_, + n_iter_); + EXPECT_TRUE(pass); + } +}; + +TYPED_TEST_SUITE(TestGroupedGemmMultiABDFixedNK, KernelTypes); + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases) +{ +#ifdef CK_USE_XDL + const std::vector Ms{2, 2}; +#else + const std::vector Ms{2, 1}; +#endif + constexpr int N = 256; + constexpr int K = 128; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases) +{ +#ifdef CK_USE_XDL + const std::vector Ms{3, 3, 3, 3, 3}; +#else + const std::vector Ms{2, 1, 3, 4, 5}; +#endif + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases) +{ +#ifdef CK_USE_XDL + const std::vector Ms{153, 153, 153, 153, 153, 153}; +#else + const std::vector Ms{167, 183, 177, 153, 139, 204}; +#endif + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular) +{ +#ifdef CK_USE_XDL + const std::vector Ms{128, 128, 128}; +#else + const std::vector Ms{64, 128, 256}; +#endif + constexpr int N = 768; + constexpr int K = 320; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + param_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +} From ab6faa7d5ee65e33b6614bfdf4348c03c281e018 Mon Sep 17 00:00:00 2001 From: Zoltan Lakatos Date: Mon, 19 Jan 2026 08:28:01 +0000 Subject: [PATCH 03/10] wmma implementation fixed --- ...device_grouped_gemm_multi_abd_fixed_nk.hpp | 5 + ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 572 +++++++++--------- 
...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 18 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1 + .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 2 +- .../gpu/grouped_gemm_multi_abd_fixed_nk.hpp | 254 ++++++++ .../CMakeLists.txt | 4 +- ...as_gelu_bf16_i8_bf16_km_kn_mn_instance cpp | 111 ---- ...as_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp | 111 ++++ ...as_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 204 +++---- ...as_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | 128 ++-- ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | 19 + ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | 17 + ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp | 13 +- ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 106 ++-- .../test_grouped_gemm_multi_abd_fixed_nk.cpp | 24 +- 16 files changed, 923 insertions(+), 666 deletions(-) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp index 9c185923ca..bb889e8026 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp @@ -5,6 +5,7 @@ #include #include +#include #include "device_grouped_gemm_multi_abd.hpp" @@ -71,6 +72,10 @@ struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD { + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index 546256ac79..aec05c38fd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -32,8 +32,17 @@ namespace device { template __global__ void @@ -41,78 +50,124 @@ __global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { -#if(defined(__gfx11__) || defined(__gfx12__)) - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>()]; + + const index_t KBatch = 1; const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr 
= reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - // Binary search lookup to find which group this block is part of - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && - block_id < gemm_desc_ptr[group_id].block_end_)) && - left <= right) - { - if(block_id < gemm_desc_ptr[group_id].block_start_) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); - } + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; - // NOTE: Local copy of the arg struct since SplitKBatchOffset verifies and modifies K index - // and thus needs a non-const reference. It's also not feasible to store this in global - // memory as different threads would be writing different K values to the same arg struct - auto karg = gemm_desc_ptr[group_id].karg_; + auto karg = gemm_desc_ptr[group_id]; + if(karg.M == 0 || karg.N == 0 || karg.K == 0) + return; + + // using e_data_type = remove_cvref_t>; #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) #endif - const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_; - - // Tile index first dimension is the K batch - auto tile_index = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - auto splitk_batch_offset = - typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run(static_cast(p_shared), - splitk_batch_offset, - karg, - block_2_ctile_map, - epilogue_args); -#if defined(__gfx11__) + { + + typename GridwiseGemm::Problem problem(karg.M, + karg.N, + karg.K, + karg.StrideAs, + karg.StrideBs, + karg.StrideDs, + karg.StrideE, + KBatch); + + const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); + constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); + constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(karg.p_as_grid[i]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(karg.p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, 
BlockStart, id_off); + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run( + p_as_grid_, + p_bs_grid_, + p_ds_grid_, + static_cast(karg.p_e_grid), + p_shared, + problem, + block_2_etile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } } -#endif #else ignore = gemm_descs_const; ignore = group_count; -#endif // end of if(defined(__gfx11__) || defined(__gfx12__)) + ignore = grid_size_grp; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif } template // ??? + typename ComputeTypeA = EDataType, + typename ComputeTypeB = ComputeTypeA, + bool PermuteA = false, + bool PermuteB = false> struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK : public DeviceGroupedGemmMultiABDFixedNK, + typename uniform_sequence_gen::type, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -244,11 +300,6 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK false, false>; - using CGridDesc_M_N = - remove_cvref_t( - 1, 1, 1, 1, 1))>; - - // Move OffsettedBlockToCTileMapMLoops and BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops to helper hpp? template struct OffsettedBlockToCTileMapMLoops { @@ -412,7 +463,8 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK { } }; - using GemmTransKernelArg = GemmTransKernelArgBase; + using GemmTransKernelArg = + GroupedGemmMultiABDKernelArgument; static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) { @@ -467,7 +519,8 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK if(!(group_count_ == ck::type_convert(p_As.size()) && group_count_ == ck::type_convert(p_Bs.size()) && - ((NumDTensor == 0 && p_Ds.size() == 0) || group_count_ == ck::type_convert(p_Ds.size())) && + ((NumDTensor == 0 && p_Ds.size() == 0) || + group_count_ == ck::type_convert(p_Ds.size())) && group_count_ == ck::type_convert(p_Es.size()))) { throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size"); @@ -475,21 +528,25 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK gemm_desc_kernel_arg_.reserve(group_count_); + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); const index_t fixed_N = gemm_descs[0].N_; const index_t fixed_K = gemm_descs[0].K_; - for(std::size_t i = 0; i < gemm_descs.size(); i++) + for(std::size_t g = 0; g < gemm_descs.size(); g++) { - const index_t M = gemm_descs[i].M_; - const index_t N = gemm_descs[i].N_; - const index_t K = gemm_descs[i].K_; + const index_t M = gemm_descs[g].M_; + const index_t N = gemm_descs[g].N_; + const index_t K = gemm_descs[g].K_; - if(N != fixed_N || K != fixed_K) // M? + if(M != sum_of_m || N != fixed_N || K != fixed_K) { - throw std::runtime_error("wrong! N or K are not fixed across GEMM groups"); + throw std::runtime_error("wrong! 
M/N/K is not identical"); } - a_mtx_mraw_kraw_.emplace_back(M, K); + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); b_mtx_nraw_kraw_.emplace_back(N, K); // pointer @@ -497,100 +554,93 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK std::array p_bs_grid; std::array p_ds_grid; - static_for<0, NumATensor, 1>{}([&](auto j) { p_as_grid[j] = nullptr; }); - static_for<0, NumBTensor, 1>{}([&](auto j) { p_bs_grid[j] = nullptr; }); - static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + static_for<0, NumATensor, 1>{}([&](auto i) { p_as_grid[i] = nullptr; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { p_bs_grid[i] = nullptr; }); + static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid[i] = nullptr; }); std::array StrideAs; std::array StrideBs; std::array StrideDs; - const index_t StrideE = gemm_descs[i].stride_C_; + const index_t StrideE = gemm_descs[g].stride_C_; - if(gemm_descs[i].stride_As_.size() != NumATensor) + if(gemm_descs[g].stride_As_.size() != NumATensor) { throw std::runtime_error( "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); } - + static_for<0, NumATensor, 1>{}( - [&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; }); + [&](auto j) { StrideAs[j] = gemm_descs[g].stride_As_[j]; }); - if(gemm_descs[i].stride_Bs_.size() != NumBTensor) + if(gemm_descs[g].stride_Bs_.size() != NumBTensor) { throw std::runtime_error( "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); } static_for<0, NumBTensor, 1>{}( - [&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; }); - - if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + [&](auto j) { StrideBs[j] = gemm_descs[g].stride_Bs_[j]; }); + + if(gemm_descs[g].stride_Ds_.size() != NumDTensor) { throw std::runtime_error( "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); } static_for<0, NumDTensor, 1>{}( - [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; }); - - const index_t m_padded = GridwiseGemm::CalculateMPadded(M); - const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + [&](auto j) { StrideDs[j] = gemm_descs[g].stride_Ds_[j]; }); const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N( - M, m_padded, N, n_padded, StrideE); + AverM, AverM, N, N, StrideE); // block-to-e-tile map const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + if(group_id * grid_size_grp_ != grid_size_) { - throw std::runtime_error("wrong! block_2_etile_map validation failed"); + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); } const index_t block_start = grid_size_; - const index_t block_end = grid_size_ + grid_size_grp_; grid_size_ += grid_size_grp_; + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! 
block_2_etile_map validation failed"); + } + auto grouped_block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); - auto karg = KernelArgument(p_as_grid, - p_bs_grid, - p_ds_grid, - type_convert(p_Es[i]), - M, - N, - K, - StrideAs, - StrideBs, - StrideDs, - StrideE, - k_batch_, - a_element_op, - b_element_op, - c_element_op, - false); - - gemm_desc_kernel_arg_.emplace_back( - std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); - - // group_id++; + auto karg = GemmTransKernelArg({p_as_grid, + p_bs_grid, + p_ds_grid, + type_convert(p_Es[g]), + AverM, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE}); + + gemm_desc_kernel_arg_.emplace_back(std::move(karg)); + + group_id++; } const auto e_grid_desc_sum_m_n = - GridwiseGemm::template MakeDEGridDescriptor_M_N( - group_count_ * gemm_descs[0].M_, - group_count_ * gemm_descs[0].M_, - gemm_descs[0].N_, - gemm_descs[0].N_, - gemm_descs[0].stride_C_); + GridwiseGemm::template MakeDEGridDescriptor_M_N(sum_of_m, + sum_of_m, + gemm_descs[0].N_, + gemm_descs[0].N_, + gemm_descs[0].stride_C_); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; - grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); } @@ -613,186 +663,96 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK index_t grid_size_; index_t grid_size_grp_; index_t barrier_size_grp_; - index_t k_batch_; + index_t sum_of_m; + + index_t k_batch_ = 1; }; // Invoker struct Invoker : public BaseInvoker { - float RunImp(const Argument& arg, - const StreamConfig& stream_config = StreamConfig{}, - hipStream_t cpy_stream = nullptr, - hipEvent_t cpy_event = nullptr) - { - using GemmTransKernelArg_ = GemmTransKernelArgBase; - static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); - - bool all_have_kbatch_gt_one = arg.gemm_desc_kernel_arg_[0].karg_.KBatch > 1; - bool all_have_main_k0_block_loop = - CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[0].karg_); + using Argument = DeviceOp::Argument; - bool not_all_have_main_k0_block_loop_same = false; - bool not_all_have_kbatch_value_same = false; - - for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); ++i) - { - const auto& karg = reinterpret_cast( - arg.gemm_desc_kernel_arg_[i].karg_); - - if(stream_config.log_level_ > 0) - { - karg.Print(); - } - - auto kbatch = karg.KBatch; - - if(!GridwiseGemm::CheckValidity(karg)) - { - std::ostringstream err; - err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - - not_all_have_main_k0_block_loop_same |= - all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(karg); - not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (kbatch > 1); - } - - if(not_all_have_main_k0_block_loop_same) + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + // for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + // { + // if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].karg_.K_) + // != + // has_main_k_block_loop) + // { + // throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + // } + // } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) { - std::ostringstream err; - err << "Not all gemms have same value for main_k0_block_loop! 
in " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__; - // throw std::runtime_error(err.str()); + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); } - if(not_all_have_kbatch_value_same) - { - std::ostringstream err; - err << "Not all gemms have same kbatch value (=1 or >1)! " << " in " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } + float ave_time = 0; - // If the user provides copy stream and copy event, we assume that they're also - // responsible for providing allocated host memory (eg. pinned) which - // would be used to copy kernel arguments to the device. - if(cpy_stream && cpy_event) - { - if(arg.gemm_kernel_host_args_ == nullptr) - { - std::ostringstream err; - err << "No memory has been allocated for gemm kernel host args " - << "when providing the copy stream and copy event! In " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - hip_check_error(hipMemcpyAsync(arg.p_workspace_, - arg.gemm_kernel_host_args_, - arg.group_count_ * sizeof(GemmTransKernelArg_), - hipMemcpyHostToDevice, - cpy_stream)); + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + const auto kernel = kernel_grouped_gemm_wmma_fixed_nk; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; - hip_check_error(hipEventRecord(cpy_event, cpy_stream)); + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; - hip_check_error(hipEventSynchronize(cpy_event)); - } - else // In this case CK owns memory allocated on host. 
+ if(arg.k_batch_ > 1) { - - hip_check_error( - hipMemcpyAsync(arg.p_workspace_, - arg.gemm_desc_kernel_arg_.data(), - arg.gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg_), - hipMemcpyHostToDevice, - stream_config.stream_id_)); - } - - float ave_time = 0; - - const auto Run = [&](const auto& kernel) { - if(all_have_kbatch_gt_one) + if(has_main_k_block_loop) { - for(const auto& trans_arg : arg.gemm_desc_kernel_arg_) - { - - const auto& karg = trans_arg.karg_; - hip_check_error(hipMemsetAsync(karg.p_e_grid, - 0, - karg.M * karg.N * sizeof(EDataType), - stream_config.stream_id_)); - } + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); } - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.p_workspace_), - arg.gemm_desc_kernel_arg_.size()); - }; - - // NOTE: If at least one gemm problem has a main k0 block loop, we include it for all - if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + else { - if(all_have_kbatch_gt_one) - { - const auto kernel = - kernel_grouped_gemm_wmma_fixed_nk; - - Run(kernel); - } - else - { - const auto kernel = - kernel_grouped_gemm_wmma_fixed_nk; - - Run(kernel); - } + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); } } else { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + if(has_main_k_block_loop) { - if(all_have_kbatch_gt_one) - { - const auto kernel = - kernel_grouped_gemm_wmma_fixed_nk; - - Run(kernel); - } - else - { - const auto kernel = - kernel_grouped_gemm_wmma_fixed_nk; - - Run(kernel); - } + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); } } @@ -809,6 +769,11 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) { return false; @@ -830,8 +795,19 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); - supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); - supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + bool isABlockTransferValid = (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + if(!isABlockTransferValid && ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Invalid block transfer for A block.\n"); + } + + bool isBBlockTransferValid = (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + if(!isBBlockTransferValid && ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Invalid block transfer for B block.\n"); + } + + supported &= isABlockTransferValid && isBBlockTransferValid; } } @@ -946,6 +922,14 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); } + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * + sizeof(GroupedGemmMultiABDKernelArgument); + } + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const 
override { auto p_arg_ = dynamic_cast(p_arg); @@ -958,21 +942,23 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); } - size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override { - auto arg = *dynamic_cast(p_arg); + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; - return arg.group_count_ * sizeof(GroupedGemmMultiABDKernelArgument); + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); } - size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } - static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } // polymorphic - void SetKBatch(BaseArgument* /*p_arg*/, index_t /*kbatch*/) const override + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override { - throw std::runtime_error("??? figure out later"); + return SetKBatch(*dynamic_cast(p_arg), k_batch); } void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 897773f768..67e4148177 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -688,6 +688,11 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + // Split-K autodeduction is not supported if(arg.k_batch_ < 1) { @@ -729,19 +734,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeArgument(std::vector>& p_As, - std::vector>& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, - std::vector gemm_descs, - AElementwiseOperation a_element_op = AElementwiseOperation{}, - BElementwiseOperation b_element_op = BElementwiseOperation{}, - CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) - { - return Argument{ - p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; - } - static auto MakeInvoker() { return Invoker{}; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index a1cba118b2..f1595d8f15 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -333,6 +333,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 using typename Base::DsGridPointer; using AsDataType_ = AsDataType; using BsDataType_ = BsDataType; + using EDataType_ = EDataType; struct Problem { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index b7b88d4920..e66ede1afe 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -400,7 +400,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return 
std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); } - __host__ __device__ static auto CalculateMPadded(index_t M) + __host__ __device__ __device__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp index 2d75f20670..7820549ad8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp @@ -183,6 +183,98 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan #endif #if defined(CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// RCR +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( std::vector, ck::Tuple, @@ -195,6 +287,59 @@ void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_insta PassThrough, Multiply, PassThrough>>>& instances); + +// CRR +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& 
instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); #endif // CK_USE @@ -267,6 +412,37 @@ struct DeviceOperationInstanceFactory< } #endif // CK_USE_XDL +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA + return op_ptrs; } }; @@ -340,6 +516,37 @@ struct DeviceOperationInstanceFactory< } #endif // CK_USE_XDL +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA + return op_ptrs; } }; @@ -413,6 +620,37 @@ struct DeviceOperationInstanceFactory< } #endif // CK_USE_XDL +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA + return op_ptrs; } }; @@ -491,6 +729,22 @@ struct DeviceOperationInstanceFactory< is_same_v> && is_same_v> && is_same_v) { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + op_ptrs); + } + if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt index 631020dc2d..fc60f48727 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt @@ -9,9 +9,9 @@ list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp - # device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp - # device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp ) add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance cpp deleted file mode 100644 index d483a828a3..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance cpp +++ /dev/null @@ -1,111 +0,0 @@ -// // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// // SPDX-License-Identifier: MIT - -// #include - -// #include "ck/ck.hpp" -// #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -// #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -// #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -// #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" -// #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -// #include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp" - -// namespace ck { -// namespace tensor_operation { -// namespace device { -// namespace instance { - -// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple, -// EDataType, -// AElementOp, -// BElementOp, -// AddFastGelu>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< -// ck::Tuple, -// ck::Tuple, -// AddFastGelu, -// GemmMNKPadding>{}); -// } - -// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple, -// EDataType, -// AElementOp, -// BElementOp, -// Add>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< -// ck::Tuple, -// ck::Tuple, -// Add, -// GemmMNKPadding>{}); -// } - -// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple<>, -// EDataType, -// AElementOp, -// BElementOp, -// PassThrough>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< -// ck::Tuple<>, -// ck::Tuple<>, -// PassThrough, -// GemmMNKPadding>{}); -// } - -// void 
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple<>, -// EDataType, -// AElementOp, -// BElementOp, -// FastGelu>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< -// ck::Tuple<>, -// ck::Tuple<>, -// FastGelu, -// GemmMNKPadding>{}); -// } - -// } // namespace instance -// } // namespace device -// } // namespace tensor_operation -// } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..b53e53fee8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + ck::Tuple, + ck::Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + ck::Tuple, + ck::Tuple, + Add, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + ck::Tuple<>, + ck::Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + ck::Tuple<>, + ck::Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace 
device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp index a4b0d776c7..203c545121 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -1,111 +1,111 @@ -// // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// // SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT -// #include +#include -// #include "ck/ck.hpp" -// #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -// #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -// #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -// #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" -// #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -// #include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp" +#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp" -// namespace ck { -// namespace tensor_operation { -// namespace device { -// namespace instance { +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { -// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple, -// EDataType, -// AElementOp, -// BElementOp, -// AddFastGelu>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< -// ck::Tuple, -// ck::Tuple, -// AddFastGelu, -// GemmMNKPadding>{}); -// } +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + ck::Tuple, + ck::Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} -// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple, -// EDataType, -// AElementOp, -// BElementOp, -// Add>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< -// ck::Tuple, -// ck::Tuple, -// Add, -// GemmMNKPadding>{}); -// } +void 
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + ck::Tuple, + ck::Tuple, + Add, + GemmMNKPadding>{}); +} -// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple<>, -// EDataType, -// AElementOp, -// BElementOp, -// PassThrough>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< -// ck::Tuple<>, -// ck::Tuple<>, -// PassThrough, -// GemmMNKPadding>{}); -// } +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + ck::Tuple<>, + ck::Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} -// void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple<>, -// EDataType, -// AElementOp, -// BElementOp, -// FastGelu>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< -// ck::Tuple<>, -// ck::Tuple<>, -// FastGelu, -// GemmMNKPadding>{}); -// } +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + ck::Tuple<>, + ck::Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} -// } // namespace instance -// } // namespace device -// } // namespace tensor_operation -// } // namespace ck +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp index 80c13f4ef2..dac83dd6c7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -10,56 +10,56 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp" // TODO: find a final spot for instances. 
+#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -// void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple, -// EDataType, -// AElementOp, -// BElementOp, -// AddFastGelu>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< -// ck::Tuple, -// ck::Tuple, -// AddFastGelu, -// GemmMNKPadding>{}); -// } +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + ck::Tuple, + ck::Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} -// void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple, -// EDataType, -// AElementOp, -// BElementOp, -// Add>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< -// ck::Tuple, -// ck::Tuple, -// Add, -// GemmMNKPadding>{}); -// } +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + ck::Tuple, + ck::Tuple, + Add, + GemmMNKPadding>{}); +} void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( std::vector{}); } -// void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( -// std::vector, -// ELayout, -// AsDataType, -// BsDataType, -// ck::Tuple<>, -// EDataType, -// AElementOp, -// BElementOp, -// FastGelu>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< -// ck::Tuple<>, -// ck::Tuple<>, -// FastGelu, -// GemmMNKPadding>{}); -// } +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + ck::Tuple<>, + ck::Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp index dfc139a43f..b3b6c0b058 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -83,6 +84,24 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = + std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, 
DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp index fb71c260e4..65f87ca5c6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -83,6 +84,22 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, 
AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp index 2843918693..398d930bc6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -91,12 +91,13 @@ template using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple< // clang-format off - //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM|NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization|Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 
16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> -// DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< ck::Tuple, ck::Tuple, ck::Tuple, Row, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple, BF16, PassThrough, PassThrough, Add, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> + //######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | + //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> // clang-format on >; diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp index f0269526e2..625a8cc311 100644 --- a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "ck/ck.hpp" #include "ck/utility/env.hpp" @@ -80,7 +81,8 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, } }; - 
std::size_t group_count = Ms.size(); + const std::size_t group_count = Ms.size(); + const int sum_of_m = std::accumulate(Ms.begin(), Ms.end(), 0); static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); @@ -153,12 +155,9 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, auto g_ds_m_n = reserveVector(group_count); auto g_e_m_n_host_results = reserveVector>(group_count); auto g_e_m_n_device_results = reserveVector>(group_count); - // int sum_of_m = 0; for(std::size_t g = 0; g < group_count; g++) { - // sum_of_m += Ms[g]; - auto& as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g)); auto& bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g)); auto& ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g)); @@ -222,9 +221,9 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, } } - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CDEElementOp{}; + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; using DeviceMemPtr = std::unique_ptr; std::vector> g_as_device_buf(group_count); @@ -256,9 +255,9 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, static_for<0, NumATensor, 1>{}([&](auto i) { using ADataType = remove_cvref_t>; - as_device_buf[i] = std::make_unique(sizeof(ADataType) * - as_m_k(i).mDesc.GetElementSpaceSize()); - as_device_buf[i]->ToDevice(as_m_k[i].mData.data()); + as_device_buf[i] = std::make_unique(sizeof(ADataType) * sum_of_m * Ks[g]); + as_device_buf[i]->ToDevice(as_m_k[i].mData.data(), + as_m_k[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); as_device_view[i] = as_device_buf[i]->GetDeviceBuffer(); as_stride[i] = StrideAs[g]; }); @@ -282,19 +281,18 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, static_for<0, NumDTensor, 1>{}([&](auto i) { using DDataType = remove_cvref_t>; - ds_device_buf[i] = std::make_unique(sizeof(DDataType) * - ds_m_n(i).mDesc.GetElementSpaceSize()); - ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data()); + ds_device_buf[i] = std::make_unique(sizeof(DDataType) * sum_of_m * Ns[g]); + ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data(), + ds_m_n[i].mDesc.GetElementSpaceSize() * sizeof(DDataType)); ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer(); ds_stride[i] = StrideDs[g]; }); - g_e_device_buf[g] = std::make_unique( - sizeof(EDataType) * g_e_m_n_host_results[g].mDesc.GetElementSpaceSize()); + g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * sum_of_m * Ns[g]); g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer(); g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ - Ms[g], + sum_of_m, Ns[g], Ks[g], std::vector(as_stride.begin(), as_stride.end()), @@ -304,19 +302,19 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, tensor_operation::device:: GroupedGemmMultiABDKernelArgument - tmp{as_device_view, - bs_device_view, - ds_device_view, - g_e_device_view[g], - Ms[g], - Ns[g], - Ks[g], - as_stride, - bs_stride, - ds_stride, - StrideE[g]}; - - grouped_gemm_kernel_args_host.push_back(std::move(tmp)); + kernelArg{as_device_view, + bs_device_view, + ds_device_view, + g_e_device_view[g], + Ms[g], + Ns[g], + Ks[g], + as_stride, + bs_stride, + ds_stride, + StrideE[g]}; + + grouped_gemm_kernel_args_host.push_back(std::move(kernelArg)); } using DeviceOp = 
tensor_operation::device::DeviceGroupedGemmMultiABDFixedNKMakeArgumentPointer(g_as_device_view, - g_bs_device_view, - g_ds_device_view, - g_e_device_view, - g_gemm_descs, - a_element_op, - b_element_op, - c_element_op); + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + g_as_device_view, g_bs_device_view, g_ds_device_view, g_e_device_view, g_gemm_descs); - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + if (!gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Gemm incompatible with runtime set parameters. Skipping..." << std::endl; + } - DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + continue; + } - DeviceMem grouped_gemm_kernel_args_dev( - gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + DeviceMem gemm_workspace_dev(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace_dev.GetDeviceBuffer()); + DeviceMem grouped_gemm_kernel_args_dev(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), grouped_gemm_kernel_args_host.data(), gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), hipMemcpyHostToDevice)); - gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + gemm_ptr->SetElementwiseOps(argument_ptr.get(), a_element_op, b_element_op, cde_element_op); - gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), - grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); std::string gemm_name = gemm_ptr->GetTypeString(); @@ -428,15 +427,17 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, g_e_device_buf[g]->SetZero(); } - invoker_ptr->Run(argument_ptr.get(), - StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); if(do_verification) { bool instance_pass = true; for(std::size_t g = 0; g < group_count; g++) { - g_e_device_buf[g]->FromDevice(g_e_m_n_device_results[g].mData.data()); + g_e_device_buf[g]->FromDevice( + g_e_m_n_device_results[g].mData.data(), + g_e_m_n_device_results[g].mDesc.GetElementSize() * sizeof(EDataType)); instance_pass = instance_pass && ck::utils::check_err(g_e_m_n_device_results[g], @@ -446,7 +447,7 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, { static_for<0, NumATensor, 1>{}([&](auto i) { LogRangeAsType( - std::cout << "a[" << g << "] : ", g_as_m_k[g](i).mData, ",") + std::cout << "a[" << g << "]: ", g_as_m_k[g](i).mData, ",") << std::endl; }); static_for<0, NumBTensor, 1>{}([&](auto i) { @@ -460,10 +461,10 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, << std::endl; }); LogRangeAsType( - std::cout << "c_device: ", g_e_m_n_device_results[g].mData, ",") + std::cout << "e_device: ", g_e_m_n_device_results[g].mData, ",") << std::endl; LogRangeAsType( - std::cout << "c_host : ", g_e_m_n_host_results[g].mData, ",") + std::cout << "e_host : ", g_e_m_n_host_results[g].mData, ",") << std::endl; } } @@ -474,9 +475,6 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, pass = pass && instance_pass; } - float ave_time = invoker_ptr->Run( - argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); - if(time_kernel) { std::size_t 
flop = 0, num_btype = 0; diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp index d4ebe229cc..c5a55bac22 100644 --- a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp +++ b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp @@ -224,13 +224,9 @@ TYPED_TEST_SUITE(TestGroupedGemmMultiABDFixedNK, KernelTypes); TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases) { -#ifdef CK_USE_XDL - const std::vector Ms{2, 2}; -#else - const std::vector Ms{2, 1}; -#endif - constexpr int N = 256; - constexpr int K = 128; + const std::vector Ms{3, 4}; + constexpr int N = 8; + constexpr int K = 64; const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); @@ -240,11 +236,7 @@ TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases) TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases) { -#ifdef CK_USE_XDL - const std::vector Ms{3, 3, 3, 3, 3}; -#else - const std::vector Ms{2, 1, 3, 4, 5}; -#endif + const std::vector Ms{3, 5, 16, 7, 8}; constexpr int N = 768; constexpr int K = 544; @@ -256,11 +248,7 @@ TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases) TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases) { -#ifdef CK_USE_XDL - const std::vector Ms{153, 153, 153, 153, 153, 153}; -#else const std::vector Ms{167, 183, 177, 153, 139, 204}; -#endif constexpr int N = 768; constexpr int K = 544; @@ -272,11 +260,7 @@ TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases) TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular) { -#ifdef CK_USE_XDL - const std::vector Ms{128, 128, 128}; -#else const std::vector Ms{64, 128, 256}; -#endif constexpr int N = 768; constexpr int K = 320; From cf0ebfb098d017b0d9ebf8dc4d47ef20d2db76bd Mon Sep 17 00:00:00 2001 From: Zoltan Lakatos Date: Mon, 19 Jan 2026 14:20:26 +0000 Subject: [PATCH 04/10] avoid unnecessary device mem allocation and code cleanups --- ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 72 ++---- ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 15 +- ...rofile_grouped_gemm_multi_abd_fixed_nk.cpp | 206 ------------------ 3 files changed, 21 insertions(+), 272 deletions(-) delete mode 100644 profiler/src/profile_grouped_gemm_multi_abd_fixed_nk.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index aec05c38fd..ccc443df83 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -28,7 +28,6 @@ namespace ck { namespace tensor_operation { namespace device { -// Can be shared between multiple device implementations.... template ; + // TODO: Block to tile mappings could potentially moved out to avoid code duplications between + // different device implementations. 
+ template struct OffsettedBlockToCTileMapMLoops { @@ -444,25 +446,6 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK static constexpr index_t DefaultKBatch = 1; using KernelArgument = typename GridwiseGemm::Argument; - template - struct GemmTransKernelArgBase - { - KernelArgument_ karg_; - GroupedGemmBlock2ETileMap block_2_ctile_map_; - index_t block_start_, block_end_; - - GemmTransKernelArgBase() = default; - GemmTransKernelArgBase(KernelArgument_&& karg, - GroupedGemmBlock2ETileMap&& b2c_map, - index_t block_start, - index_t block_end) - : karg_{karg}, - block_2_ctile_map_{b2c_map}, - block_start_{block_start}, - block_end_{block_end} - { - } - }; using GemmTransKernelArg = GroupedGemmMultiABDKernelArgument; @@ -647,7 +630,6 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK void UpdateKBatch(index_t) {} - // private: index_t group_count_; AElementwiseOperation a_element_op_; @@ -675,18 +657,6 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - bool has_main_k_block_loop = true; - - // for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) - // { - // if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].karg_.K_) - // != - // has_main_k_block_loop) - // { - // throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); - // } - // } - if(arg.grouped_gemm_kernel_args_dev == nullptr) { throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); @@ -694,10 +664,10 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK float ave_time = 0; - auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + auto launch_kernel = [&](auto e_global_memory_operation_) { const auto kernel = kernel_grouped_gemm_wmma_fixed_nk 1) { - if(has_main_k_block_loop) - { - ave_time = - launch_kernel(integral_constant{}, - integral_constant{}); - } - else - { - ave_time = - launch_kernel(integral_constant{}, - integral_constant{}); - } + ave_time = launch_kernel(integral_constant{}); } else { - if(has_main_k_block_loop) - { - ave_time = launch_kernel(integral_constant{}, - integral_constant{}); - } - else - { - ave_time = launch_kernel(integral_constant{}, - integral_constant{}); - } + ave_time = launch_kernel(integral_constant{}); } return ave_time; @@ -811,6 +761,14 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK } } + for(index_t i = 0; i < arg.group_count_; i++) + { + if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) != true) + { + supported = false; + } + } + // For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd // instruction that supports bf16 and we cannot use splitk because of that if constexpr(std::is_same::value) diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp index 625a8cc311..304e523ca3 100644 --- a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp @@ -255,9 +255,8 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, static_for<0, NumATensor, 1>{}([&](auto i) { using ADataType = remove_cvref_t>; - as_device_buf[i] = std::make_unique(sizeof(ADataType) * sum_of_m * Ks[g]); - as_device_buf[i]->ToDevice(as_m_k[i].mData.data(), - as_m_k[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + as_device_buf[i] = 
std::make_unique(sizeof(ADataType) * Ms[g] * Ks[g]); + as_device_buf[i]->ToDevice(as_m_k[i].mData.data()); as_device_view[i] = as_device_buf[i]->GetDeviceBuffer(); as_stride[i] = StrideAs[g]; }); @@ -268,8 +267,7 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, static_for<0, NumBTensor, 1>{}([&](auto i) { using BDataType = remove_cvref_t>; - bs_device_buf[i] = std::make_unique(sizeof(BDataType) * - bs_k_n(i).mDesc.GetElementSpaceSize()); + bs_device_buf[i] = std::make_unique(sizeof(BDataType) * Ks[g] * Ns[g]); bs_device_buf[i]->ToDevice(bs_k_n[i].mData.data()); bs_device_view[i] = bs_device_buf[i]->GetDeviceBuffer(); bs_stride[i] = StrideBs[g]; @@ -281,14 +279,13 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, static_for<0, NumDTensor, 1>{}([&](auto i) { using DDataType = remove_cvref_t>; - ds_device_buf[i] = std::make_unique(sizeof(DDataType) * sum_of_m * Ns[g]); - ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data(), - ds_m_n[i].mDesc.GetElementSpaceSize() * sizeof(DDataType)); + ds_device_buf[i] = std::make_unique(sizeof(DDataType) * Ms[g] * Ns[g]); + ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data()); ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer(); ds_stride[i] = StrideDs[g]; }); - g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * sum_of_m * Ns[g]); + g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * Ms[g] * Ns[g]); g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer(); g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ diff --git a/profiler/src/profile_grouped_gemm_multi_abd_fixed_nk.cpp b/profiler/src/profile_grouped_gemm_multi_abd_fixed_nk.cpp deleted file mode 100644 index 16282a3e7a..0000000000 --- a/profiler/src/profile_grouped_gemm_multi_abd_fixed_nk.cpp +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include -#include -#include -#include - -#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp" -#include "profiler_operation_registry.hpp" - -enum struct GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN // 2 -}; - -enum struct GemmDataType -{ - BF16_I8_BF16 // 0 -}; - -#define OP_NAME "grouped_gemm_multi_abd_fixed_nk" -#define OP_DESC "Grouped GEMM Multi ABD Fixed NK" - -namespace { - -std::vector argToIntArray(char* input) -{ - std::vector out; - std::istringstream in(input); - std::string item; - - while(std::getline(in, item, ',')) - { - out.push_back(std::stoi(item)); - } - return out; -} - -int profile_grouped_gemm_multi_abd_fixed_nk(int argc, char* argv[]) -{ - if(argc < 14) - { - std::cout - << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" - << "arg2: data type (0: bf16@int8; 1: fp16; 2: fp16@fp8; 3: fp16@int8)\n" - << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n" - << " 1: A[m, k] * B[n, k] = C[m, n];\n" - << " 2: A[k, m] * B[k, n] = C[m, n];)\n" - << "arg4: verification (0: no; 1: yes)\n" - << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" - << "arg6: print tensor value (0: no; 1: yes)\n" - << "arg7: time kernel (0=n0, 1=yes)\n" - << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " - "64,64 64,64 128,128)\n" - << "arg15: kbatch value (default 1)\n" - << "optional:\n" - << "arg16: number of warm-up cycles (default 1)\n" - << "arg17: number of iterations (default 10)\n" - << std::endl; - - exit(1); - } - - // const auto data_type = static_cast(std::stoi(argv[2])); - // const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const bool time_kernel = std::stoi(argv[7]); - - const auto Ms = argToIntArray(argv[8]); - const auto Ns = argToIntArray(argv[9]); - const auto Ks = argToIntArray(argv[10]); - - const auto StrideAs = argToIntArray(argv[11]); - const auto StrideBs = argToIntArray(argv[12]); - const auto StrideDs = argToIntArray(argv[13]); - const auto StrideE = StrideDs.at(0); - const int kbatch = argc >= 15 ? 
std::stoi(argv[14]) : 1; - - int n_warmup = 1; - int n_iter = 10; - if(argc == 17) - { - n_warmup = std::stoi(argv[15]); - n_iter = std::stoi(argv[16]); - } - - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - - ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl, - ck::Tuple, - ck::Tuple<>, - float, - float, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - Row>(do_verification, - init_method, - do_log, - time_kernel, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideDs, - StrideE, - kbatch, - n_warmup, - n_iter); - -// #if defined(CK_ENABLE_INT8) -// #if defined(CK_ENABLE_BF16) -// if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::KM_KN_MN) -// { -// ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl, -// ck::Tuple, -// ck::Tuple<>, -// float, -// float, -// ck::Tuple, -// ck::Tuple, -// ck::Tuple<>, -// Row>(do_verification, -// init_method, -// do_log, -// time_kernel, -// Ms, -// Ns, -// Ks, -// StrideAs, -// StrideBs, -// StrideDs, -// StrideE, -// kbatch, -// n_warmup, -// n_iter); -// } -// else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) -// { -// ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl, -// ck::Tuple, -// ck::Tuple, -// float, -// float, -// ck::Tuple, -// ck::Tuple, -// ck::Tuple, -// Row>(do_verification, -// init_method, -// do_log, -// time_kernel, -// Ms, -// Ns, -// Ks, -// StrideAs, -// StrideBs, -// StrideDs, -// StrideE, -// kbatch, -// n_warmup, -// n_iter); -// } -// else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) -// { -// ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl, -// ck::Tuple, -// ck::Tuple, -// float, -// float, -// ck::Tuple, -// ck::Tuple, -// ck::Tuple, -// Row>(do_verification, -// init_method, -// do_log, -// time_kernel, -// Ms, -// Ns, -// Ks, -// StrideAs, -// StrideBs, -// StrideDs, -// StrideE, -// kbatch, -// n_warmup, -// n_iter); -// } -// #endif // CK_ENABLE_BF16 -// #endif // CK_ENABLE_INT8 -// else -// { -// throw std::runtime_error("wrong! 
this GEMM data_type & layout is not implemented"); -// } - return 0; -} - -} // anonymous namespace - -REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_multi_abd_fixed_nk); From 97db9d6565ff21ff691a72a3e8d6b7acf2d395d1 Mon Sep 17 00:00:00 2001 From: Zoltan Lakatos Date: Mon, 19 Jan 2026 15:45:25 +0000 Subject: [PATCH 05/10] cleanup instances definitions --- ...as_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp | 135 +++++++++++------- ...as_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 135 +++++++++++------- ...as_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | 133 ++++++++++------- ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | 17 --- ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | 15 -- ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp | 16 --- 6 files changed, 251 insertions(+), 200 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp index b53e53fee8..f24f357bde 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp @@ -7,100 +7,133 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple< + // clang-format off + //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, AddFastGelu>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - ck::Tuple, - ck::Tuple, + Tuple, + Tuple, AddFastGelu, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, Add>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - ck::Tuple, - ck::Tuple, + Tuple, + Tuple, Add, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple<>, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, PassThrough>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - ck::Tuple<>, - ck::Tuple<>, + Tuple<>, + Tuple<>, PassThrough, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple<>, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + 
PassThrough, + Multiply, FastGelu>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - ck::Tuple<>, - ck::Tuple<>, + Tuple<>, + Tuple<>, FastGelu, GemmMNKPadding>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp index 203c545121..ef18fca156 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -7,100 +7,133 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | + 
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, AddFastGelu>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - ck::Tuple, - ck::Tuple, + Tuple, + Tuple, AddFastGelu, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, Add>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - ck::Tuple, - ck::Tuple, + Tuple, + Tuple, Add, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple<>, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, PassThrough>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - ck::Tuple<>, - ck::Tuple<>, + Tuple<>, + Tuple<>, PassThrough, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple<>, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, FastGelu>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - ck::Tuple<>, - ck::Tuple<>, + Tuple<>, + Tuple<>, FastGelu, GemmMNKPadding>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp index dac83dd6c7..e3139214fe 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -9,98 +9,131 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | + //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> + // clang-format on + >; + void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, AddFastGelu>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - ck::Tuple, - ck::Tuple, + Tuple, + Tuple, AddFastGelu, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, Add>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - ck::Tuple, - ck::Tuple, + Tuple, + Tuple, Add, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple<>, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, PassThrough>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - ck::Tuple<>, - ck::Tuple<>, + Tuple<>, + Tuple<>, PassThrough, GemmMNKPadding>{}); } void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( - std::vector, - ELayout, - AsDataType, - BsDataType, - ck::Tuple<>, - EDataType, - AElementOp, - BElementOp, + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, FastGelu>>>& instances) { add_device_operation_instances( instances, device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - ck::Tuple<>, - ck::Tuple<>, + Tuple<>, + Tuple<>, FastGelu, GemmMNKPadding>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp index b3b6c0b058..a9ce988322 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp @@ -85,23 +85,6 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances // clang-format on >; -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = - std::tuple< - // clang-format off - //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| 
BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //#######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl| - //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4> - // clang-format on - >; - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp index 65f87ca5c6..7bbfe30a4a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -85,21 +85,6 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances // clang-format on >; -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| 
NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //#######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | - //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> - // clang-format on - >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp index 398d930bc6..df2468a2e1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -85,22 +85,6 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances // clang-format on >; -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple< - // clang-format off - //######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | - //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> - // clang-format on - >; - } // namespace instance } // namespace device } // namespace tensor_operation From 9a10230aedeeb10b1f7b4caf90ba8ca1a62e8659 Mon Sep 17 00:00:00 2001 From: Zoltan Lakatos Date: Tue, 20 Jan 2026 13:56:59 +0000 Subject: [PATCH 06/10] wmma examples added --- .../59_grouped_gemm_multi_ABD/CMakeLists.txt | 8 + ...m_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp | 407 ++++++++++++++++++ ...gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp | 401 +++++++++++++++++ ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 22 +- ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 13 + ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 30 +- profiler/src/CMakeLists.txt | 20 - 7 files changed, 851 insertions(+), 50 deletions(-) create mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp create mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt index 4155e0a344..d7ff58705c 100644 --- a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -8,3 +8,11 @@ add_example_dependencies(example_grouped_gemm_xdl_multi_abd 
example_grouped_gemm add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp) add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) + +add_custom_target(example_grouped_gemm_wmma_multi_abd) + +add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16 grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16) + +add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp) +add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8) \ No newline at end of file diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp new file mode 100644 index 0000000000..5aaff10f13 --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp @@ -0,0 +1,407 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = 
ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; + +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; + + std::vector> a0_tensors; + std::vector> b_tensors; + std::vector> b0_tensors; + std::vector> b1_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a0_tensors.reserve(group_count); + b_tensors.reserve(group_count); + b0_tensors.reserve(group_count); + b1_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + 
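+        // Per-group host-side setup: accumulate sum_of_m for the fixed-NK (stacked-M) layout,
+        // build the host tensors for A0, B0/B1, D0 and E, and fill them according to init_method.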
sum_of_m += problem_size.Ms[i]; + + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b1_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{}))); + + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() + + sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_2{0, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back(std::make_unique( + sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(), + b0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); + + 
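+        // B1 carries the bf16 values that scale the int8 B0 weights; the Multiply BElementOp
+        // fuses b = b0 * b1 inside the kernel, so the combined B is only materialized on the
+        // host for verification.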
b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(), + b1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B1DataType)); + + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back( + {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, + std::array{b0_tensors_device[i]->GetDeviceBuffer(), + b1_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i], 0}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(&argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n)); + } + } + + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + c_host_tensors[i], + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + 
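+    // Default problem: 16 groups with randomized M in [32, 64) and fixed N = 1024, K = 512.
+    // The four CLI arguments (verification, init method, timing, k_batch) are required;
+    // any other argument count prints the usage message and exits.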
ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000..f4043936c3 --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp @@ -0,0 +1,401 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +#include "ck/utility/scheduler_enum.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; +using Scale = ck::tensor_operation::element_wise::Scale; +using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp; + +using A0DataType = F16; +using A1DataType = F32; +using AsDataType = ck::Tuple; +using B0DataType = F16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using A1Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = 
ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK +// clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; + + std::vector> a0_tensors; + std::vector> a1_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> e_host_tensors; + std::vector> e_device_tensors; + + a0_tensors.reserve(group_count); + a1_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + e_host_tensors.reserve(group_count); + e_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, a1_tensors_device, b_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + a1_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + 
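+        // Each group has two A tensors: A0 (fp16) and A1 (fp32); the AddScale AElementOp
+        // combines them elementwise into the effective A operand consumed by the GEMM.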
a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + a1_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + e_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + e_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << e_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 2; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + a1_tensors_device.emplace_back( + std::make_unique(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(), + a1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A1DataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); + 
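+        // D0 is the bias row consumed by the Add epilogue: only N elements are allocated and
+        // the zero stride passed in the kernel arguments below broadcasts it across M.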
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + {1, 1}, + {problem_size.stride_Bs[i]}, + {0}, + 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer(), + a1_tensors_device[i]->GetDeviceBuffer()}, + std::array{b_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i], + problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i]}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + constexpr float scale = 1.f; + auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(&argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k)); + } + } + + c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(), + e_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + e_host_tensors[i], + PassThrough{}, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize 
problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(64); + problem_size.Ks.push_back(64); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index ccc443df83..62f47758d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -481,10 +481,12 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK // TODO: use occupancy api to calculate appropriate batch size. } - Argument(std::vector>& p_As, - std::vector>& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, + // Client is expected to manually copy the kernel arguments to the device therefore there is + // no point in setting tensor device pointers for the argument structure. + Argument(std::vector>&, + std::vector>&, + std::vector>&, + std::vector&, std::vector& gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -499,16 +501,6 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK grid_size_{0}, k_batch_{kbatch} { - - if(!(group_count_ == ck::type_convert(p_As.size()) && - group_count_ == ck::type_convert(p_Bs.size()) && - ((NumDTensor == 0 && p_Ds.size() == 0) || - group_count_ == ck::type_convert(p_Ds.size())) && - group_count_ == ck::type_convert(p_Es.size()))) - { - throw std::runtime_error("wrong! 
group_count_ != p_As/b/d/e.size"); - } - gemm_desc_kernel_arg_.reserve(group_count_); index_t group_id = 0; @@ -603,7 +595,7 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK auto karg = GemmTransKernelArg({p_as_grid, p_bs_grid, p_ds_grid, - type_convert(p_Es[g]), + nullptr, AverM, N, K, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 67e4148177..574e561925 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -736,6 +736,19 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static auto MakeInvoker() { return Invoker{}; } + static auto MakeArgument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + // polymorphic std::unique_ptr MakeArgumentPointer(std::vector>& p_As, diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp index 304e523ca3..9c79422406 100644 --- a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp @@ -352,23 +352,23 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, EDataType, remove_cvref_t>>::type; + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultiABD; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + for(std::size_t i = 0; i < group_count; i++) { - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmMultiABD; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = ref_gemm.MakeArgument(g_as_m_k[i], g_bs_k_n[i], g_ds_m_n[i], diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index ad41bc797d..3379fd15d1 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -58,7 +58,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp) list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_multi_abd_fixed_nk.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) @@ -296,22 +295,3 @@ message(VERBOSE "ckProfiler libs: ${PROFILER_LIBS}") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS}) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) - -## Defining specific operation targets - -macro(define_profiler_target NAME SOURCES LIBS) - add_executable(${NAME} profiler.cpp ${SOURCES}) - target_compile_options(${NAME} PRIVATE -Wno-global-constructors) - - if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) - target_compile_options(${NAME} PRIVATE --offload-compress) - 
endif() - - target_link_libraries(${NAME} PRIVATE utility getopt::getopt ${LIBS}) - - rocm_install(TARGETS ${NAME} COMPONENT profiler) -endmacro() - -define_profiler_target(ckProfiler_fixed_nk - "profile_grouped_gemm_multi_abd_fixed_nk.cpp" - "device_grouped_gemm_fixed_nk_multi_abd_instance") From 06dc334994c5b4775874ac31211b6b36622d641a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Lakatos?= Date: Tue, 20 Jan 2026 14:41:07 +0000 Subject: [PATCH 07/10] code cleanups --- ...gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp | 2 +- ...device_grouped_gemm_multi_abd_fixed_nk.hpp | 5 -- ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 19 ++---- ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 24 +++++++- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 2 +- .../cpu/reference_gemm_multi_abd.hpp | 24 ++++---- .../gpu/grouped_gemm_multi_abd_fixed_nk.hpp | 2 - ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | 2 - ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | 2 - ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp | 2 - ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 13 +++-- test/CMakeLists.txt | 2 +- .../test_grouped_gemm_multi_abd_fixed_nk.cpp | 58 +++++++++---------- 13 files changed, 78 insertions(+), 79 deletions(-) diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp index f4043936c3..2f3ca64f0f 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp @@ -68,7 +68,7 @@ using CDEElementOp = Add; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK -// clang-format off + // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp index bb889e8026..9c185923ca 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp @@ -5,7 +5,6 @@ #include #include -#include #include 
"device_grouped_gemm_multi_abd.hpp" @@ -72,10 +71,6 @@ struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD { - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index 62f47758d1..4059da5822 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -727,8 +727,8 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK // vector load size. if constexpr(GemmSpec != GemmSpecialization::Default) { - // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} - // layout, thus we have to adapt it to the {M,K} or {N,K} layout. + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; @@ -737,19 +737,8 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); - bool isABlockTransferValid = (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); - if(!isABlockTransferValid && ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - printf("Invalid block transfer for A block.\n"); - } - - bool isBBlockTransferValid = (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); - if(!isBBlockTransferValid && ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - printf("Invalid block transfer for B block.\n"); - } - - supported &= isABlockTransferValid && isBBlockTransferValid; + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 574e561925..f898c6ff4d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -725,6 +725,26 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK } } + for(index_t i = 0; i < arg.group_count_; i++) + { + if(get_warp_size() == 64) + { + if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + true) + { + supported = false; + } + } + else + { + if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + true) + { + supported = false; + } + } + } + return supported; } @@ -734,8 +754,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeInvoker() { return Invoker{}; } - static auto MakeArgument(std::vector>& p_As, std::vector>& p_Bs, std::vector>& p_Ds, 
@@ -749,6 +767,8 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; } + static auto MakeInvoker() { return Invoker{}; } + // polymorphic std::unique_ptr MakeArgumentPointer(std::vector>& p_As, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index e66ede1afe..b7b88d4920 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -400,7 +400,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); } - __host__ __device__ __device__ static auto CalculateMPadded(index_t M) + __host__ __device__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp index aef3c8a45c..8a06322ce4 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp @@ -39,7 +39,7 @@ template struct ReferenceGemmMultiABD : public device::BaseOperator -{ +{ // Argument struct Argument : public device::BaseArgument { @@ -93,9 +93,9 @@ struct ReferenceGemmMultiABD : public device::BaseOperator // result auto data_refs1 = ck::tie(a_m_k(m, k)); // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return arg.as_m_k_[Number{}](m, k); }, - Number{}); + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.as_m_k_[Number{}](m, k); }, + Number{}); auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); unpack(arg.a_element_op_, data_refs); } @@ -109,9 +109,9 @@ struct ReferenceGemmMultiABD : public device::BaseOperator // result auto data_refs1 = ck::tie(b_k_n(k, n)); // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return arg.bs_k_n_[Number{}](k, n); }, - Number{}); + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.bs_k_n_[Number{}](k, n); }, + Number{}); auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); unpack(arg.b_element_op_, data_refs); } @@ -130,8 +130,8 @@ struct ReferenceGemmMultiABD : public device::BaseOperator auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); ref_invoker.Run(ref_argument); @@ -142,9 +142,9 @@ struct ReferenceGemmMultiABD : public device::BaseOperator // compulsory auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n)); // optional (if multiple Ds) - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return arg.ds_m_n_[Number{}](m, n); }, - Number{}); + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.ds_m_n_[Number{}](m, n); }, + Number{}); auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); unpack(arg.cde_element_op_, data_refs); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp index 7820549ad8..0879bea4ea 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp @@ -20,7 +20,6 @@ using Multiply = ck::tensor_operation::element_wise::Multiply; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; - #if defined(CK_USE_XDL) // RRR void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( @@ -342,7 +341,6 @@ void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_insta PassThrough>>>& instances); #endif // CK_USE - // GEMM + Add + Gelu template , S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp index 7bbfe30a4a..fb71c260e4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -8,7 +8,6 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -84,7 +83,6 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp index df2468a2e1..95365c82e7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -8,7 +8,6 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" 
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -84,7 +83,6 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp index 9c79422406..b0d6b03fe3 100644 --- a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp @@ -285,7 +285,7 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, ds_stride[i] = StrideDs[g]; }); - g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * Ms[g] * Ns[g]); + g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * Ms[g] * Ns[g]); g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer(); g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ @@ -387,11 +387,12 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, auto argument_ptr = gemm_ptr->MakeArgumentPointer( g_as_device_view, g_bs_device_view, g_ds_device_view, g_e_device_view, g_gemm_descs); - if (!gemm_ptr->IsSupportedArgument(argument_ptr.get())) + if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Gemm incompatible with runtime set parameters. Skipping..." << std::endl; + std::cout << "Gemm incompatible with runtime set parameters. Skipping..." 
+ << std::endl; } continue; @@ -400,13 +401,15 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, DeviceMem gemm_workspace_dev(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace_dev.GetDeviceBuffer()); - DeviceMem grouped_gemm_kernel_args_dev(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + DeviceMem grouped_gemm_kernel_args_dev( + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), grouped_gemm_kernel_args_host.data(), gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), hipMemcpyHostToDevice)); - gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); gemm_ptr->SetElementwiseOps(argument_ptr.get(), a_element_op, b_element_op, cde_element_op); auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 488a15a11e..9fee3b5697 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -251,7 +251,7 @@ function(add_gtest_executable TEST_NAME) endfunction() add_compile_options(-Wno-c++20-extensions) -# add_subdirectory(ck_tile) +add_subdirectory(ck_tile) add_subdirectory(magic_number_division) add_subdirectory(space_filling_curve) add_subdirectory(conv_util) diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp index c5a55bac22..4eaf54a7b5 100644 --- a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp +++ b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp @@ -128,8 +128,8 @@ class TestGroupedGemmMultiABDFixedNK : public testing::Test template void SetTupleStrides(std::vector& strides, - const std::vector& rows, - const std::vector& cols) const + const std::vector& rows, + const std::vector& cols) const { if constexpr(Layouts::Size() > 0) { @@ -146,7 +146,7 @@ class TestGroupedGemmMultiABDFixedNK : public testing::Test const std::vector& StrideAs = {}, const std::vector& StrideBs = {}, const std::vector& StrideDs = {}, - const std::vector& StrideE = {}) + const std::vector& StrideE = {}) { std::vector stride_as = StrideAs; std::vector stride_bs = StrideBs; @@ -190,32 +190,32 @@ class TestGroupedGemmMultiABDFixedNK : public testing::Test const std::vector& StrideDs, const std::vector& StrideE) { - bool pass = ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl( - verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideDs, - StrideE, - k_batches_, - n_warmup_, - n_iter_); + bool pass = + ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideDs, + StrideE, + k_batches_, + n_warmup_, + n_iter_); EXPECT_TRUE(pass); } }; From 96d5a969676c154d8089487e877e354506d70451 Mon Sep 17 00:00:00 2001 From: illsilin_amdeng Date: Tue, 20 Jan 2026 13:35:38 -0800 Subject: [PATCH 08/10] fix clang format --- .../impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 2 +- .../gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 2 +- profiler/include/profiler/profile_gemm_multi_abd_impl.hpp | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
index 4059da5822..44966d0395 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
@@ -76,7 +76,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     if(karg.M == 0 || karg.N == 0 || karg.K == 0)
         return;
 
-    // using e_data_type = remove_cvref_t>;
+    // using e_data_type = remove_cvref_t>;
 #if defined(__gfx11__)
     // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
     if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp
index f1595d8f15..a34170df88 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp
@@ -333,7 +333,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
     using typename Base::DsGridPointer;
     using AsDataType_ = AsDataType;
     using BsDataType_ = BsDataType;
-    using EDataType_ = EDataType;
+    using EDataType_ = EDataType;
 
     struct Problem
     {
diff --git a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp
index 70baf11b5f..83851234ac 100644
--- a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp
+++ b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp
@@ -201,8 +201,8 @@ bool profile_gemm_multi_abd_impl(int do_verification,
                                  AComputeType,
                                  BComputeType>;
 
-        auto ref_gemm = ReferenceGemmInstance{};
-        auto ref_invoker = ref_gemm.MakeInvoker();
+        auto ref_gemm = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
         auto ref_argument =
             ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});

From 06bac70bf4cd2d4eeffd9436ca0fa242d3aeec30 Mon Sep 17 00:00:00 2001
From: Zoltan Lakatos
Date: Wed, 21 Jan 2026 09:38:07 +0000
Subject: [PATCH 09/10] typo and compilation fixes related to reference gemm

---
 ...emm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp | 2 +-
 ...d_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp | 2 +-
 ...gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp | 2 +-
 ...ed_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp | 6 +++---
 ...ice_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 3 +--
 ...vice_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 2 +-
 .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 2 +-
 .../profiler/profile_gemm_multi_abd_impl.hpp | 17 +++++++++--------
 ...ile_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 4 ++--
 .../test_grouped_gemm_multi_abd_fixed_nk.cpp | 5 ++++-
 10 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp
index 5aaff10f13..644726d53c 100644
--- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp
+++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp
@@ -398,7 +398,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: time kernel (0=n0, 1=yes)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4: k_batch (>0)\n");
         exit(0);
     }
diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp
index 2f3ca64f0f..da946c71d4 100644
--- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp
+++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp
@@ -392,7 +392,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: time kernel (0=n0, 1=yes)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4: k_batch (>0)\n");
         exit(0);
     }
diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp
index 28b3fa9213..bf503cb2a1 100644
--- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp
+++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp
@@ -398,7 +398,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: time kernel (0=n0, 1=yes)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4: k_batch (>0)\n");
         exit(0);
     }
diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
index 032842b9eb..f9b5dec9d5 100644
--- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
+++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
@@ -47,9 +47,9 @@ using B0DataType = F16;
 using BsDataType = ck::Tuple;
 using AccDataType = F32;
 using CShuffleDataType = F32;
-using D0DataType = F32;
+using D0DataType = F16;
 using DsDataType = ck::Tuple;
-using EDataType = F32;
+using EDataType = F16;
 
 using A0Layout = Row;
 using A1Layout = Row;
@@ -394,7 +394,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: time kernel (0=n0, 1=yes)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4: k_batch (>0)\n");
         exit(0);
     }
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
index 44966d0395..17f1f573a8 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
@@ -76,7 +76,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     if(karg.M == 0 || karg.N == 0 || karg.K == 0)
         return;
 
-    // using e_data_type = remove_cvref_t>;
 #if defined(__gfx11__)
     // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
     if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
@@ -651,7 +650,7 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK
     {
         if(arg.grouped_gemm_kernel_args_dev == nullptr)
         {
-            throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
+            throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");
         }
 
         float ave_time = 0;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
index f898c6ff4d..637864d907 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
@@ -605,7 +605,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
 
         if(arg.grouped_gemm_kernel_args_dev == nullptr)
         {
-            throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
+            throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");
         }
 
         float ave_time = 0;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
index 71f2d737e6..6555ead7d2 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
@@ -696,7 +696,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK>>::type;
         using ReferenceGemmInstance =
-            ck::tensor_operation::host::ReferenceGemmMultiABD;
         auto ref_gemm = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
-        auto ref_argument =
-            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+        auto ref_argument = ref_gemm.MakeArgument(
+            as_m_k, bs_k_n, ds_m_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op);
         ref_invoker.Run(ref_argument);
     }
diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp
index b0d6b03fe3..eea72b324d 100644
--- a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp
@@ -97,7 +97,7 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification,
     auto generateInputTupleA = [&](std::size_t g) {
         if constexpr(NumATensor == 0)
         {
-            return ck::Tuple<>();
+            static_assert("Gemm problem should have at least 1 A tensor.");
         }
         else
         {
@@ -114,7 +114,7 @@ bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification,
     auto generateInputTupleB = [&](std::size_t g) {
         if constexpr(NumBTensor == 0)
        {
-            return ck::Tuple<>();
+            static_assert("Gemm problem should have at least 1 B tensor.");
         }
         else
         {
diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp
index 4eaf54a7b5..74948d58e4 100644
--- a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp
+++ b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp
@@ -273,7 +273,10 @@ TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular)
 int main(int argc, char** argv)
 {
     testing::InitGoogleTest(&argc, argv);
-    if(argc == 1) {}
+    if(argc == 1)
+    {
+        // Run with default arguments.
+    }
     else if(argc == 3)
     {
         param_mask = strtol(argv[1], nullptr, 0);

From 15b26a3c5edd3c47737eb35092a5d0143adea548 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Zolt=C3=A1n=20Lakatos?=
Date: Wed, 21 Jan 2026 16:30:02 +0000
Subject: [PATCH 10/10] fix compilation error due to std::remove_cvref_t

---
 test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp
index 74948d58e4..ea5a06f194 100644
--- a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp
+++ b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp
@@ -9,6 +9,7 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/ck.hpp"
+#include "ck/utility/type.hpp"
 
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
 #include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp"
@@ -134,7 +135,7 @@ class TestGroupedGemmMultiABDFixedNK : public testing::Test
         if constexpr(Layouts::Size() > 0)
         {
             // As of now multi ABD implementation supports only tensors with matching layouts.
-            using Layout = std::remove_cvref_t{}, Layouts>>;
+            using Layout = ck::remove_cvref_t{}, Layouts>>;
             SetStrides(strides, rows, cols);
         }
     }