diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index ec536f72878..ce34bc2c70d 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -13,6 +13,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(${EXE_NAME} gemm_quant.cpp gemm_abquant_quantgrouped.cpp + gemm_abquant_quantgrouped_fp8.cpp + gemm_abquant_quantgrouped_bf8.cpp + gemm_abquant_quantgrouped_fp4.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 155f19881ea..d1051232f0c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -6,127 +6,17 @@ template using GemmConfig = GemmConfigQuantPrefill; +void abquant_quantgrouped_fp4_instance_factory( + std::unordered_map>& lut); +void abquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); +void abquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); + void abquant_quantgrouped_instance_factory( std::unordered_map>& lut) { - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; 
- using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - 
TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; + abquant_quantgrouped_fp4_instance_factory(lut); + abquant_quantgrouped_fp8_instance_factory(lut); + abquant_quantgrouped_bf8_instance_factory(lut); } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp new file mode 100644 index 00000000000..f2d542ba490 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +void abquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, 
+ ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp new file mode 100644 index 00000000000..2124503a27d --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +void abquant_quantgrouped_fp4_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings( + {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp new file mode 100644 index 00000000000..5147046e201 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +void abquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, 
+ ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 8de58b0a309..7f12e2c47f4 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8") + "or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 607c53d9afd..d969f33063c 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -9,6 +9,7 @@ #include #include #include +#include #include "ck_tile/core/config.hpp" #include "ck_tile/ops/common/utils.hpp" @@ -33,10 +34,9 @@ template ); - using ComputeDataType = std::conditional_t; + + // Use automatically determined compute type from + using ComputeDataType = void; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -77,7 +77,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, ck_tile::BaseGemmPipelineAgBgCrMem, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::ABQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ 
-181,30 +184,28 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - 1, - false, - 1, - TiledPermuteN>>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -563,8 +564,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { if constexpr(std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -604,14 +604,32 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( a_m_k); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-6.0f, 6.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + + if constexpr(std::is_same_v) + { ck_tile::FillUniformDistribution{-5.0f, 5.0f, 
fill_seed(gen)}( b_k_n); } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-6.0f, 6.0f, fill_seed(gen)}( + b_k_n); + } else { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( @@ -812,10 +830,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if(arg_parser.get_int("v") == 1) { + std::cout << "Performing CPU verification..." << std::endl; + ck_tile::HostTensor c_m_n_host_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); + // Track start time for reference operation + auto start_reference_tick = std::chrono::high_resolution_clock::now(); if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( @@ -889,6 +914,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); + // "Stop" our timer + auto verification_finished_tick = std::chrono::high_resolution_clock::now(); + if(!pass) { std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) @@ -896,6 +924,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, << std::endl; } std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; + + // Calculate and display reference timing + using DurationType = std::chrono::duration; + double reference_sec = std::chrono::duration_cast(verification_finished_tick - + start_reference_tick) + .count(); + double verification_sec = std::chrono::duration_cast( + verification_finished_tick - start_verification_tick) + .count(); + float reference_msec = static_cast(reference_sec * 1e3); + float verification_msec = static_cast(verification_sec * 1e3); + + std::cout << std::fixed << std::setprecision(1) << "CPU reference GEMM took " + << reference_msec << "ms, verification took " << verification_msec << "ms." + << std::endl; } else if(arg_parser.get_int("v") == 2) { @@ -926,6 +969,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) } if constexpr(std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 01e1d00b591..7a0d9dd3fcc 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -91,6 +91,7 @@ #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/mfma_compute_types.hpp" #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/print.hpp" #include "ck_tile/core/utility/random.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7af2f558add..8f9dd30bda6 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1544,7 +1544,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), 
+ (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 9f9770df1b5..42886b8ced2 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1414,7 +1414,7 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index cc23ce71a83..f79bbd1be9a 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -6,6 +6,7 @@ #include #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" #if defined(__gfx950__) @@ -23,6 +24,12 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +#if CK_TILE_USE_CUSTOM_DATA_TYPE +using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2))); +#else +using fp8x2_t = fp8_t __attribute__((ext_vector_type(2))); +#endif + // Helpers: constexpr-safe access to elements of ext_vector_type(2) // Some compilers don't allow operator[] in constant expressions for vector types. // We use bit_cast to a trivially copyable representation to extract lanes. 
@@ -98,6 +105,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); } CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } @@ -105,6 +114,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); } CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); } CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } + CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); } + CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); } template CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number) const @@ -145,6 +156,26 @@ struct pk_float4_e2m1_t bit_cast(static_cast(0xC400)), // -4 bit_cast(static_cast(0xC600)) // -6 }; + + // FP8 = E4M3. 
Finite and normal values should be bit-compatible between FNUZ and OCP + static constexpr fp8_t e2m1_to_fp8_table[16] = { + fp8_t(static_cast(0x00)), // 0 + fp8_t(static_cast(0x30)), // 0.5 + fp8_t(static_cast(0x38)), // 1 + fp8_t(static_cast(0x3C)), // 1.5 + fp8_t(static_cast(0x40)), // 2 + fp8_t(static_cast(0x44)), // 3 + fp8_t(static_cast(0x48)), // 4 + fp8_t(static_cast(0x4C)), // 6 + fp8_t(static_cast(0x00)), // -0 + fp8_t(static_cast(0xB0)), // -0.5 + fp8_t(static_cast(0xB8)), // -1 + fp8_t(static_cast(0xBC)), // -1.5 + fp8_t(static_cast(0xC0)), // -2 + fp8_t(static_cast(0xC4)), // -3 + fp8_t(static_cast(0xC8)), // -4 + fp8_t(static_cast(0xCC)) // -6 + }; #endif }; @@ -408,6 +439,27 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + // NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8 + // would be better than the naive implementation below + // #if CK_TILE_FP4_CVT_DEVICE + // return impl::_from_f4(data, scale); + // #else + return fp8_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; + // #endif +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + // NOTE: No specialized fp4 to fp8 instructions are available. 
Unsure whether fp4 to fp16 to fp8 + // would be better than the naive implementation below + // #if CK_TILE_FP4_CVT_DEVICE + // return impl::_from_f4(data, scale); + // #else + return fp8x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; + // #endif +} #else CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { @@ -415,7 +467,8 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale}; + return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, + e2m1_to_fp32_table[_unpack(number<1>{})] * scale}; } CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { @@ -428,6 +481,16 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<1>{})]) * scale)}; } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + return type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale; +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + return fp8x2_t{ + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale)}; +} #endif } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 13a43f8b5c9..9365e65fed7 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include 
"ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/random.hpp" #include @@ -23,6 +24,11 @@ struct pk_int4_t type data; CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {} CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {} + + // NOTE: added for interface compatibility with pk_fp4_t + // Other data types could be added for greater similarity + CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const; + CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } }; // limits @@ -186,4 +192,9 @@ CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const +{ + return pk_int4_t_to_fp32x2_t(*this); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 90ddc2a56ea..def054f4155 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/e8m0.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/core/utility/mfma_compute_types.hpp b/include/ck_tile/core/utility/mfma_compute_types.hpp new file mode 100644 index 00000000000..540793850c8 --- /dev/null +++ b/include/ck_tile/core/utility/mfma_compute_types.hpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" + +#include + +namespace ck_tile { + +namespace detail { +template +struct t +{ + using type = T; +}; + +// Helper method to automatically determine compute type +// Selects the largest type of the two. 
If both of them are packed data types, defaults to fp8. +template +struct auto_compute_type +{ + static constexpr auto Resolve() + { + using LargestInputType = largest_type_t; + if constexpr(is_packed_type_v) + { + return t{}; + } + else + { + return t{}; + } + } + + using type = typename decltype(Resolve())::type; +}; + +// Helper method to determine compute type, defaulting an explicitly passed-in compute type +template +struct mfma_compute_type +{ + using type = std::conditional_t, + typename auto_compute_type::type, + ComputeDataType>; +}; + +}; // namespace detail + +template +using mfma_compute_type_t = + typename detail::mfma_compute_type::type; + +// Helper method to determine compute type, defaulting to input data type +// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed, +// ComputeDataType is used. +template +using mfma_compute_type_from_input_t = std::conditional_t< + is_packed_type_v, + std::conditional_t, ComputeDataType, OtherDataType>, + ThisDataType>; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index f07e25e19cb..c11d180839b 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" + #include #include #include @@ -187,4 +189,19 @@ template using tuple_element_or_default_t = typename tuple_element_or_default::type; +// Helper struct to determine if a type is packed (more than 1 element per byte) +template +struct is_packed_type +{ + static constexpr bool value = numeric_traits::PackedSize > 1; +}; + +template +static constexpr bool is_packed_type_v = is_packed_type::value; + +// Helper definition to take the largest sizes type +template +using largest_type_t = + std::conditional_t= sizeof(BDataType), ADataType, BDataType>; + } // namespace ck_tile diff --git 
a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 9ad5af8264c..f021e1a85f3 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -141,43 +141,46 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, const std::size_t N = b_k_n.get_length(1); const std::size_t K = a_m_k.get_length(1); + // Pre-convert A/B tensors to AccData type + // This prevents doing slow reconversions for each row/column + HostTensor a_acc(a_m_k.mDesc); + HostTensor b_acc(b_k_n.mDesc); + + a_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const ADataType pk_val = a_element_op(a_m_k(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[1] & 1) ? fp32_val.hi : fp32_val.lo; + } + else + { + self(index) = ck_tile::type_convert(a_element_op(a_m_k(index))); + } + }); + + b_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const BDataType pk_val = b_element_op(b_k_n(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[0] & 1) ? fp32_val.hi : fp32_val.lo; + } + else if constexpr(std::is_same_v) + { + self(index) = fp8_to_float_raw(b_element_op(b_k_n(index))); + } + else + { + self(index) = ck_tile::type_convert(b_element_op(b_k_n(index))); + } + }); + auto f_mn = [&](auto m, auto n) { AccDataType v_acc = 0; constexpr std::size_t kGroupK = BQuantGroupSize::kK; - // ---- A loader: dequant A(m,k) into AccDataType ---- - auto load_a = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - return (k & 1) ? 
fp32_val.hi : fp32_val.lo; - } - else - { - return ck_tile::type_convert(a_element_op(a_m_k(m, k))); - } - }; - - // ---- B loader: dequant B(k,n) into AccDataType ---- - auto load_b = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - return (k & 1) ? fp32_val.hi : fp32_val.lo; - } - else if constexpr(std::is_same_v) - { - return fp8_to_float_raw(b_element_op(b_k_n(k, n))); - } - else - { - return ck_tile::type_convert(b_element_op(b_k_n(k, n))); - } - }; - // ---- a scale loader for a given K-group index ---- auto load_scale_a = [&](ck_tile::index_t k_group) -> float { const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM; @@ -224,8 +227,8 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, // unscaled accumulation within this K-group for(std::size_t k = k_begin; k < k_end; ++k) { - const AccDataType v_a = load_a(k); - const AccDataType v_b = load_b(k); + const AccDataType v_a = a_acc(m, k); + const AccDataType v_b = b_acc(k, n); v_block_acc += v_a * v_b; } diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 10c2a1e4df7..fe16fe1418c 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -8,7 +8,7 @@ namespace ck_tile { -template +template struct InterleavedPKTypeLoader { template @@ -21,10 +21,15 @@ struct InterleavedPKTypeLoader constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; const auto in_dstr_tensors = load_tile(warp_window); + // NOTE: we rely on types packing neatly here + using RawSrcType = typename SrcDataType::type; + constexpr auto PackedSize = numeric_traits::PackedSize; + + using SrcVectorType = RawSrcType __attribute__((ext_vector_type(UnaryOpSize / PackedSize))); using DstVectorType = DstDataType 
__attribute__((ext_vector_type(UnaryOpSize))); static_for<0, thread_buffer_size, 1>{}([&](auto i) { elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); + in_dstr_tensors.get_thread_buffer().template get_as()[i]); }); } }; @@ -37,10 +42,11 @@ template CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) { - if constexpr(std::is_same_v) + if constexpr(numeric_traits::PackedSize > 1) { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); + InterleavedPKTypeLoader::load_interleaved_pk_type( + dst, src); } else if constexpr(LoadTranspose) { diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index ca9af0a7a88..3f58eceb333 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -397,6 +397,29 @@ struct PassThroughPack8 y.hi = i4_to_bf8x4(bit_cast(x) >> 8); #endif } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_fp4x4_t& x) const + { + pk_fp4_t f0 = pk_fp4_t{x[0]}; + pk_fp4_t f1 = pk_fp4_t{x[1]}; + pk_fp4_t f2 = pk_fp4_t{x[2]}; + pk_fp4_t f3 = pk_fp4_t{x[3]}; + + fp8x2_t x0 = f0.to_fp8x2(); + fp8x2_t x1 = f1.to_fp8x2(); + fp8x2_t x2 = f2.to_fp8x2(); + fp8x2_t x3 = f3.to_fp8x2(); + + y[0] = x0[0]; + y[1] = x0[1]; + y[2] = x1[0]; + y[3] = x1[1]; + y[4] = x2[0]; + y[5] = x2[1]; + y[6] = x3[0]; + y[7] = x3[1]; + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1ff95b157cb..f922e95d372 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" @@ -252,11 +253,20 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - using WarpGemm = WarpGemmDispatcher; + using BTypeToUse = mfma_compute_type_from_input_t; + + using WarpGemm = WarpGemmDispatcher f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -188,7 +190,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg typename BFlatBlockTensor, typename AQBlockTensor, typename BQBlockTensor, - typename ABlockWindow> + typename ABlockWindow, + index_t UnaryOpSize = 8> CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, ABlockTensor& a_warp_tensor, BFlatBlockTensor& b_warp_tensor, @@ -232,8 +235,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows(number{})(number{})); + + load_int4_tile( + a_warp_tensor(number{}), + a_warp_windows(number{})(number{})); } // barrier // Could be deleted diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp 
b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index c44d330d139..631ce56bcab 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -106,9 +106,11 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase // 4. i4, bf8, (fp8/fp32) -> f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -133,9 +135,10 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; + // A/B DataType gets converted from PkInt4/PkFp4 during loading + using OverrideADataType = ComputeDataType; + using OverrideBDataType = ComputeDataType; + using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; @@ -261,9 +264,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_int4_tile( + // If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS + load_int4_tile( a_warp_tile_, a_block_window); - // If B datatype were pkint4 it would be converted prior to storing in LDS load_int4_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp index 095275e60b0..180bc77ecaa 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp +++ 
b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp @@ -10,9 +10,11 @@ namespace ck_tile { -struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy +struct GemmABQuantPipelineAgBgCrDefaultPolicy : + // public UniversalGemmPipelineAgBgCrPolicy + public UniversalGemmBasePolicy { - using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base = UniversalGemmBasePolicy; using Base::I0; using Base::I1; using Base::I2; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index cd70c2ca862..5798b44e4e0 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -34,9 +34,6 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using AQuantGroupSize = remove_cvref_t; using BQuantGroupSize = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -67,6 +64,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3())>; + // A/B DataType gets converted from PkInt4/PkFp4 during loading + using OverrideADataType = BlockGemm::OverrideADataType; + using OverrideBDataType = BlockGemm::OverrideBDataType; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -277,9 +278,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); + Base::template GetABLdsTensorViews(p_smem); constexpr auto a_lds_load_tile_distr = 
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -299,9 +300,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = decltype(make_static_distributed_tensor(AQBlockTileDistr{})); using BQBlockTile = @@ -354,7 +355,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -366,7 +367,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -402,7 +403,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ABDataType PkInt4/PkFp4 gets converted during loading earlier + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -413,7 +415,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); @@ -486,7 +488,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ADataType gets converted during loading from PkInt4/PkFp4 + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template 
MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -536,9 +539,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, m, @@ -586,9 +589,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + // Note: ADataType PkInt4/PkFp4 gets converted during loading + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - // Note: BDataType PkInt4 gets converted during loading + // Note: BDataType PkInt4/PkFp4 gets converted during loading [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 39b00d2501b..28d7bcb2d02 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -21,23 +21,26 @@ template -struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase +struct GemmQuantPipelineProblemBase + : public GemmPipelineProblemBase> { - using Base = GemmPipelineProblemBase; + + using Base = + GemmPipelineProblemBase>; using Traits = typename Base::Traits; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index 80e41cad458..d4fcc378daa 100644 --- 
a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -93,13 +93,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - - using WarpGemm = WarpGemmDispatcher #include "ck_tile/core.hpp" +#include "ck_tile/core/utility/mfma_compute_types.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" @@ -234,36 +236,42 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe make_tensor_view(p_a_lds_pong, a_lds_block_desc); // A DRAM tile window for load + auto a_dram_tile_distribution = + PipelinePolicy::template MakeADramTileDistribution(); + auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); // ping-pong window for A LDS + auto a_warp_tile_distribution = + make_static_tile_distribution(typename WG::AWarpDstrEncoding{}); + auto 
a_warp_window_ping_tmp = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); auto a_warp_window_pong_tmp = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); statically_indexed_array< statically_indexed_array, @@ -308,8 +316,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe b_flat_dram_block_window_tmp.get_window_origin(), b_flat_distribution); - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = mfma_compute_type_from_input_t; using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); // pingpong buffer for B @@ -349,7 +356,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -389,15 +396,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe block_sync_lds(); // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor; + using ATypeToUse = mfma_compute_type_from_input_t; + using ATileType = + decltype(make_static_distributed_tensor(a_warp_tile_distribution)); + statically_indexed_array a_warp_tensor; static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); @@ -430,7 +438,7 @@ struct 
WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -442,8 +450,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // Next K @@ -455,7 +463,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -485,8 +493,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); iCounter--; HotLoopScheduler(); @@ -503,7 +511,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -525,8 +533,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; 
constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // GEMM loopK diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 5749a8d3b27..0a15dc32ab8 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -44,6 +44,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base + test_gemm_quant_abquant_a4w4_base.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_padding + test_gemm_quant_abquant_a4w4_padding.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_preshuffle + test_gemm_quant_abquant_a4w4_preshuffle.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # AQuant tests add_gtest_executable(test_tile_gemm_quant_aquant_prefill test_gemm_quant_aquant_prefill.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp new file mode 100644 index 00000000000..5e2403f7d17 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp new file mode 100644 index 00000000000..1e496d5b642 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadK) +{ + this->run_test_with_validation(1024, 1024, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadN) +{ + this->run_test_with_validation(1024, 832, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadM) +{ + this->run_test_with_validation(832, 1024, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadMNK) +{ + this->run_test_with_validation(832, 832, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadNK) +{ + this->run_test_with_validation(1024, 832, 832); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp new file mode 100644 index 00000000000..43051c8d088 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // RCR layout with RowMajor AQ, ColumnMajor BQ + // PreshuffleB = true && TransposeC = false + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 8c9955da749..da06b098f7e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -207,7 +207,7 @@ template <> struct QuantTypeTraits { template - using ComputeDataType = BDataType; // For AQuant, compute type is BDataType + using ComputeDataType = void; // Use automatically determined compute type static constexpr const char* name = "abquant"; }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 79c86935efc..d9bd01051d8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -924,8 +924,8 @@ class TestCkTileGemmABQuant : public 
TestCkTileGemmQuantBase>; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType,