From 9891903f215ecc5900adb9b925c0240e6050e7fd Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Wed, 14 Jan 2026 08:07:21 +0000 Subject: [PATCH 01/14] chore: split block scale example instances in more separate files to speed up compile times --- .../38_block_scale_gemm/CMakeLists.txt | 11 + .../gemm_abquant_quantgrouped.cpp | 130 +-------- .../gemm_abquant_quantgrouped_bf8.cpp | 72 +++++ .../gemm_abquant_quantgrouped_fp4.cpp | 27 ++ .../gemm_abquant_quantgrouped_fp8.cpp | 72 +++++ .../gemm_bquant_quantgrouped_preshuffleb.cpp | 218 +------------- ...mm_bquant_quantgrouped_preshuffleb_bf8.cpp | 67 +++++ ..._bquant_quantgrouped_preshuffleb_bf8i4.cpp | 69 +++++ ...mm_bquant_quantgrouped_preshuffleb_fp8.cpp | 67 +++++ ..._bquant_quantgrouped_preshuffleb_fp8i4.cpp | 69 +++++ ...mm_bquant_quantgrouped_preshufflequant.cpp | 271 +----------------- ...quant_quantgrouped_preshufflequant_bf8.cpp | 75 +++++ ...ant_quantgrouped_preshufflequant_bf8i4.cpp | 77 +++++ ...quant_quantgrouped_preshufflequant_fp8.cpp | 76 +++++ ...ant_quantgrouped_preshufflequant_fp8i4.cpp | 77 +++++ 15 files changed, 795 insertions(+), 583 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 28e52b92754..cdadd3d3137 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -13,6 +13,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(${EXE_NAME} gemm_quant.cpp gemm_abquant_quantgrouped.cpp + gemm_abquant_quantgrouped_fp8.cpp + gemm_abquant_quantgrouped_bf8.cpp + gemm_abquant_quantgrouped_fp4.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp @@ -21,7 +24,15 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp gemm_bquant_quantgrouped_preshufflequant.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp gemm_quant_rowcol.cpp 
gemm_quant_tensor.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 155f19881ea..d1051232f0c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -6,127 +6,17 @@ template using GemmConfig = GemmConfigQuantPrefill; +void abquant_quantgrouped_fp4_instance_factory( + std::unordered_map>& lut); +void abquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); +void abquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); + void abquant_quantgrouped_instance_factory( std::unordered_map>& lut) { - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - 
"preshuffleb", - "non-preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "preshuffleb", - "non-preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; + abquant_quantgrouped_fp4_instance_factory(lut); + abquant_quantgrouped_fp8_instance_factory(lut); + abquant_quantgrouped_bf8_instance_factory(lut); } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp new file mode 100644 index 00000000000..f2d542ba490 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_bf8.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +void abquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp 
b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp new file mode 100644 index 00000000000..dec62622612 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +void abquant_quantgrouped_fp4_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings( + {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp new file mode 100644 index 00000000000..5147046e201 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp8.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +void abquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp 
b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp index b32356c29d7..81bdc74fcad 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp @@ -11,212 +11,20 @@ template using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; #endif +void bquant_quantgrouped_preshuffleb_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( + std::unordered_map>& lut); + void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut) { - lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "preshuffleb", - 
"non-preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; + bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut); + 
bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut); } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp new file mode 100644 index 00000000000..5c399bee854 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +void bquant_quantgrouped_preshuffleb_bf8_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp new file mode 100644 index 00000000000..24d22804013 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp new file mode 100644 index 00000000000..5c5dae1413d --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +void bquant_quantgrouped_preshuffleb_fp8_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp new file mode 100644 index 00000000000..a4cb78bf65a --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index 62ca34b057b..8915a607c85 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -6,265 +6,20 @@ template using GemmConfig = GemmConfigPreshuffleBQuantPrefill; +void bquant_quantgrouped_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut); + void bquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut) { - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - 
}; - lut[hash_multiple_strings({"fp8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const 
ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; + bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut); + bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut); + bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut); + 
bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut); } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp new file mode 100644 index 00000000000..956ca893be4 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +void bquant_quantgrouped_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp new file mode 100644 index 00000000000..5218eadaf88 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp new file mode 100644 index 00000000000..3caac757e63 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp @@ -0,0 +1,76 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +void bquant_quantgrouped_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp new file mode 100644 index 00000000000..b9604dc9fd7 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} From e43fd3375bfacee8aaf76a7d0555d9dd135ebe12 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Wed, 14 Jan 2026 08:08:30 +0000 Subject: [PATCH 02/14] wip: fp4 scaffolding for abquant --- .../38_block_scale_gemm/gemm_quant.cpp | 2 +- .../run_gemm_quant_example.inc | 77 ++++++++++++------- .../arch/amd_buffer_addressing_builtins.hpp | 2 +- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 8 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 26 ++++--- 5 files changed, 73 insertions(+), 42 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 940c1b8cf3f..b6fadf08239 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8") + "or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 607c53d9afd..47be325c30a 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -9,6 +9,7 @@ #include #include #include +#include #include "ck_tile/core/config.hpp" #include "ck_tile/ops/common/utils.hpp" @@ -19,6 +20,18 @@ #include "ck_tile/ops/epilogue.hpp" #include "gemm_utils.hpp" +template +struct fallback_data_type +{ + using type = std::conditional_t || + std::is_same_v, + FallbackDataType, + DataType>; +}; + +template +using fallback_data_type_t = fallback_data_type::type; + template ); - using ComputeDataType = std::conditional_t; + + using CandidateComputeDataType = + std::conditional_t; + + using ComputeDataType = fallback_data_type_t; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -77,7 +94,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, ck_tile::BaseGemmPipelineAgBgCrMem, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::ABQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -181,30 +201,28 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - 1, - false, - 1, - TiledPermuteN>>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -926,6 +944,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) } if constexpr(std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 9f9770df1b5..42886b8ced2 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1414,7 
+1414,7 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index c44d330d139..46e43a6adc4 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -106,9 +106,13 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase // 4. i4, bf8, (fp8/fp32) -> f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 004fb18e0b5..b6855de989a 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -408,19 +408,23 @@ struct QuantGemmKernel const index_t k_size, const index_t i_m) { + constexpr auto packing_factor = std::is_same_v ? 2 : 1; + // Step 1: Create tensor view for A const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, k_size), + make_tuple(kargs.M / packing_factor, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } else { + static_assert(packing_factor == 1); + return make_naive_tensor_view( a_ptr, make_tuple(k_size, kargs.M), @@ -434,13 +438,15 @@ struct QuantGemmKernel const auto& a_pad_view = [&]() { if constexpr(std::is_same_v) { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view( + a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { + static_assert(packing_factor == 1); return pad_tensor_view(a_tensor_view, make_tuple(number{}, number{}), @@ -452,13 +458,15 @@ struct QuantGemmKernel const auto& a_block_window = [&]() { if constexpr(std::is_same_v) { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); + return make_tile_window( + a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); } else { + static_assert(packing_factor == 1); return make_tile_window(a_pad_view, make_tuple(number{}, number{}), From 6dea234c3cb13dc0cf0784b6ec01afd74c79c693 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Wed, 14 Jan 2026 14:59:02 +0000 Subject: [PATCH 03/14] feat: add fp4 decoding-while-loading to abquant pipeline --- .../gemm_abquant_quantgrouped_fp4.cpp | 4 +- .../run_gemm_quant_example.inc | 14 +----- include/ck_tile/core/numeric/pk_fp4.hpp | 29 ++++++++++++ include/ck_tile/core/numeric/vector_type.hpp | 1 + .../ops/common/load_interleaved_pk_type.hpp | 16 +++++-- .../unary_element_wise_operation.hpp | 23 +++++++++ ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 47 +++++++++++++++---- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 26 
++++------ .../gemm_abquant_pipeline_ag_bg_cr_policy.hpp | 6 ++- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 36 +++++++------- 10 files changed, 139 insertions(+), 63 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp index dec62622612..2124503a27d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_fp4.cpp @@ -14,8 +14,8 @@ void abquant_quantgrouped_fp4_instance_factory( [](const ck_tile::ArgParser& arg_parser) { using AQuantGroupSize = ck_tile::QuantGroupShape>; using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 47be325c30a..639531b0997 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -20,18 +20,6 @@ #include "ck_tile/ops/epilogue.hpp" #include "gemm_utils.hpp" -template -struct fallback_data_type -{ - using type = std::conditional_t || - std::is_same_v, - FallbackDataType, - DataType>; -}; - -template -using fallback_data_type_t = fallback_data_type::type; - template ; - using ComputeDataType = fallback_data_type_t; + using ComputeDataType = ck_tile::detail::compute_type_t; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index cc23ce71a83..0120c755607 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -6,6 +6,7 @@ #include #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" #if defined(__gfx950__) @@ -23,6 +24,12 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +#if CK_TILE_USE_CUSTOM_DATA_TYPE +using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2))); +#else +using fp8x2_t = fp8_t __attribute__((ext_vector_type(2))); +#endif + // Helpers: constexpr-safe access to elements of ext_vector_type(2) // Some compilers don't allow operator[] in constant expressions for vector types. // We use bit_cast to a trivially copyable representation to extract lanes. 
@@ -98,6 +105,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); } CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } @@ -105,6 +114,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); } CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); } CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } + CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); } + CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); } template CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number) const @@ -408,6 +419,24 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } + +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_from_f4(data, scale); +#else + return fp8_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; +#endif +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_from_f4(data, scale); +#else + return fp8x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; +#endif +} #else CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 90ddc2a56ea..def054f4155 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/e8m0.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 10c2a1e4df7..4743d572f72 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -8,7 +8,7 @@ namespace ck_tile { -template +template struct InterleavedPKTypeLoader { template @@ -21,10 +21,15 @@ struct InterleavedPKTypeLoader constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; const auto in_dstr_tensors = load_tile(warp_window); + // NOTE: we rely on types packing neatly here + using RawSrcType = typename SrcDataType::type; + constexpr auto PackedSize = numeric_traits::PackedSize; + + using SrcVectorType = RawSrcType __attribute__((ext_vector_type(UnaryOpSize / PackedSize))); using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize))); static_for<0, thread_buffer_size, 1>{}([&](auto i) { elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); + in_dstr_tensors.get_thread_buffer().template 
get_as()[i]); }); } }; @@ -37,10 +42,11 @@ template CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || std::is_same_v) { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); + InterleavedPKTypeLoader::load_interleaved_pk_type( + dst, src); } else if constexpr(LoadTranspose) { diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index ca9af0a7a88..3f58eceb333 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -397,6 +397,29 @@ struct PassThroughPack8 y.hi = i4_to_bf8x4(bit_cast(x) >> 8); #endif } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_fp4x4_t& x) const + { + pk_fp4_t f0 = pk_fp4_t{x[0]}; + pk_fp4_t f1 = pk_fp4_t{x[1]}; + pk_fp4_t f2 = pk_fp4_t{x[2]}; + pk_fp4_t f3 = pk_fp4_t{x[3]}; + + fp8x2_t x0 = f0.to_fp8x2(); + fp8x2_t x1 = f1.to_fp8x2(); + fp8x2_t x2 = f2.to_fp8x2(); + fp8x2_t x3 = f3.to_fp8x2(); + + y[0] = x0[0]; + y[1] = x0[1]; + y[2] = x1[0]; + y[3] = x1[1]; + y[4] = x2[0]; + y[5] = x2[1]; + y[6] = x3[0]; + y[7] = x3[1]; + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 46e43a6adc4..4f3f03cd5ad 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -12,6 +12,36 @@ namespace ck_tile { +// Utility functions for dealing with packed data types +namespace detail { +template +struct is_packed_type +{ + static constexpr bool value = numeric_traits::PackedSize > 1; +}; + +template +constexpr bool is_packed_type_v = is_packed_type::value; + +template +struct compute_type +{ + using type = std::conditional_t, FallbackType, DataType>; +}; + +template +using compute_type_t = typename compute_type::type; + +template +struct gemm_type +{ + using type = compute_type_t>; +}; + +template +using gemm_type_t = typename gemm_type::type; +} // namespace detail + // A is block window on shared memory // AQ (scale tensor) is block distributed tensor. // BQ (scale tensor) is block distributed tensor. 
@@ -107,12 +137,10 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase static_assert( (std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || - std::is_same_v) && + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || - std::is_same_v) && + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -137,9 +165,12 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading + // A/B DataTypes get converted from PkInt4/PkFp4 during loading + using OverrideADataType = + detail::gemm_type_t; using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; + detail::gemm_type_t; + using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; @@ -265,9 +296,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_int4_tile( + // If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS + load_int4_tile( a_warp_tile_, a_block_window); - // If B datatype were pkint4 it would be converted prior to storing in LDS load_int4_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index b6855de989a..004fb18e0b5 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -408,23 +408,19 @@ struct QuantGemmKernel const index_t k_size, const index_t i_m) { - constexpr auto packing_factor = std::is_same_v ?
2 : 1; - // Step 1: Create tensor view for A const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M / packing_factor, k_size), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } else { - static_assert(packing_factor == 1); - return make_naive_tensor_view( a_ptr, make_tuple(k_size, kargs.M), @@ -438,15 +434,13 @@ struct QuantGemmKernel const auto& a_pad_view = [&]() { if constexpr(std::is_same_v) { - return pad_tensor_view( - a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { - static_assert(packing_factor == 1); return pad_tensor_view(a_tensor_view, make_tuple(number{}, number{}), @@ -458,15 +452,13 @@ struct QuantGemmKernel const auto& a_block_window = [&]() { if constexpr(std::is_same_v) { - return make_tile_window( - a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); } else { - static_assert(packing_factor == 1); return make_tile_window(a_pad_view, make_tuple(number{}, number{}), diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp index 095275e60b0..180bc77ecaa 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp @@ -10,9 +10,11 @@ namespace ck_tile { -struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy +struct GemmABQuantPipelineAgBgCrDefaultPolicy : + // public UniversalGemmPipelineAgBgCrPolicy + public UniversalGemmBasePolicy { - using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base = UniversalGemmBasePolicy; using Base::I0; using Base::I1; using Base::I2; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index cd70c2ca862..5798b44e4e0 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -34,9 +34,6 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using AQuantGroupSize = remove_cvref_t; using BQuantGroupSize = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -67,6 +64,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3())>; + // A/B DataType gets converted from PkInt4/PkFp8 during loading + using OverrideADataType = BlockGemm::OverrideADataType; + using OverrideBDataType = BlockGemm::OverrideBDataType; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -277,9 +278,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); + Base::template GetABLdsTensorViews(p_smem); constexpr auto a_lds_load_tile_distr = 
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -299,9 +300,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = decltype(make_static_distributed_tensor(AQBlockTileDistr{})); using BQBlockTile = @@ -354,7 +355,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -366,7 +367,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -402,7 +403,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ABDataType PkInt4/PkFp4 gets converted during loading earlier + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -413,7 +415,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); @@ -486,7 +488,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ADataType gets converted during loading from PkInt4/PkFp4 + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -536,9 +539,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, m, @@ -586,9 +589,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + // Note: ADataType PkInt4/PkFp4 gets converted during loading + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - // Note: BDataType PkInt4 gets converted during loading + // Note: BDataType PkInt4/PkFp4 gets converted during loading [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, From d0cd610a6ddb2831525b1de014d248cc2ce492b5 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Thu, 15 Jan 2026 06:43:02 +0000 Subject: [PATCH 04/14] feat: add support for fp4 CPU verification in abquant --- .../run_gemm_quant_example.inc | 25 ++++++++++++++++--- include/ck_tile/core/numeric/pk_int4.hpp | 11 ++++++++ .../ck_tile/host/reference/reference_gemm.hpp | 14 ++++++----- 3 files changed, 41 insertions(+), 9 deletions(-) diff --git 
a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 639531b0997..4e94b96183e 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -569,8 +569,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { if constexpr(std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -610,14 +609,32 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( a_m_k); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-6.0f, 6.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + + if constexpr(std::is_same_v) + { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( b_k_n); } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-6.0f, 6.0f, fill_seed(gen)}( + b_k_n); + } else { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( @@ -818,6 +835,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if(arg_parser.get_int("v") == 1) { + std::cout << "Performing CPU verification..." << std::endl; + ck_tile::HostTensor c_m_n_host_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 13a43f8b5c9..9365e65fed7 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/random.hpp" #include @@ -23,6 +24,11 @@ struct pk_int4_t type data; CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {} CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {} + + // NOTE: added for interface compatibility with pk_fp4_t + // Other data types could be added for greater similarity + CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const; + CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } }; // limits @@ -186,4 +192,9 @@ CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const +{ + return pk_int4_t_to_fp32x2_t(*this); +} + } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 9ad5af8264c..061587cee46 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -148,10 +148,11 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, // ---- A loader: dequant A(m,k) into AccDataType ---- auto load_a = [&](std::size_t k) -> 
AccDataType { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); + const ADataType pk_val = a_element_op(a_m_k(m, k)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); return (k & 1) ? fp32_val.hi : fp32_val.lo; } else @@ -162,10 +163,11 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, // ---- B loader: dequant B(k,n) into AccDataType ---- auto load_b = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); + const BDataType pk_val = b_element_op(b_k_n(k, n)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); return (k & 1) ? fp32_val.hi : fp32_val.lo; } else if constexpr(std::is_same_v) From 58088a516e324ddbde7c43dd2b8fb1e102f0c78b Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Thu, 15 Jan 2026 07:29:01 +0000 Subject: [PATCH 05/14] chore: add time tracking to reference calculation --- .../run_gemm_quant_example.inc | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 4e94b96183e..73db1369f53 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -841,6 +841,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); + // Track start time for reference operation + auto start_reference_tick = std::chrono::high_resolution_clock::now(); if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( @@ -914,6 +919,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); + // "Stop" our timer + auto verification_finished_tick = std::chrono::high_resolution_clock::now(); + if(!pass) { std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) @@ -921,6 +929,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, << std::endl; } std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + + // Calculate and display reference timing + using DurationType = std::chrono::duration; + double reference_sec = std::chrono::duration_cast(verification_finished_tick - + start_reference_tick) + .count(); + double verification_sec = std::chrono::duration_cast( + verification_finished_tick - start_verification_tick) + .count(); + float reference_msec = static_cast(reference_sec * 1e3); + float verification_msec = static_cast(verification_sec * 1e3); + + std::cout << std::fixed << std::setprecision(1) << "CPU reference GEMM took " + << reference_msec << "ms, verification took " << verification_msec << "ms." 
+ << std::endl; } else if(arg_parser.get_int("v") == 2) { From 3d8bfdb507d9e445d10fac3631434c675374b29c Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Fri, 16 Jan 2026 16:28:32 +0000 Subject: [PATCH 06/14] feat: add a4w4 test for blockscale gemm --- .../run_gemm_quant_example.inc | 13 ++-- .../core/arch/amd_buffer_addressing.hpp | 3 +- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 36 +--------- .../pipeline/gemm_quant_pipeline_problem.hpp | 71 +++++++++++++++---- ..._abquant_pipeline_ag_bg_cr_base_policy.hpp | 9 +-- test/ck_tile/gemm_block_scale/CMakeLists.txt | 5 ++ .../test_gemm_quant_abquant_a4w4.cpp | 48 +++++++++++++ .../gemm_block_scale/test_gemm_quant_base.hpp | 2 +- .../test_gemm_quant_fixtures.hpp | 4 +- 9 files changed, 124 insertions(+), 67 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 73db1369f53..d969f33063c 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -35,13 +35,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str { static_assert(std::is_same_v); - using CandidateComputeDataType = - std::conditional_t; - - using ComputeDataType = ck_tile::detail::compute_type_t; + // Use automatically determined compute type from + using ComputeDataType = void; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -191,8 +186,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str } using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, typename TypeConfig::AccDataType, typename TypeConfig::CDataType, diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7af2f558add..8f9dd30bda6 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1544,7 +1544,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), "wrong! 
not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 4f3f03cd5ad..631ce56bcab 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -12,36 +12,6 @@ namespace ck_tile { -// Utility functions for dealing with packed data types -namespace detail { -template -struct is_packed_type -{ - static constexpr bool value = numeric_traits::PackedSize > 1; -}; - -template -constexpr bool is_packed_type_v = is_packed_type::value; - -template -struct compute_type -{ - using type = std::conditional_t, FallbackType, DataType>; -}; - -template -using compute_type_t = typename compute_type::type; - -template -struct gemm_type -{ - using type = compute_type_t>; -}; - -template -using gemm_type_t = typename gemm_type::type; -} // namespace detail - // A is block window on shared memory // AQ (scale tensor) is block distributed tensor. // BQ (scale tensor) is block distributed tensor. @@ -166,10 +136,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using CDataType = remove_cvref_t; // A/B DataType get converted from PkInt4/PkFp8 during loading - using OverrideADataType = - detail::gemm_type_t; - using OverrideBDataType = - detail::gemm_type_t; + using OverrideADataType = ComputeDataType; + using OverrideBDataType = ComputeDataType; using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index c8acb785cf6..9b72cf10212 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -11,6 +11,47 @@ namespace ck_tile { +namespace detail { +template +struct t +{ + using type = T; +}; + +// Helper method to automatically determine compute type +// Selects the largest type of the two. If both of them are packed data types, defaults to fp8. 
+template +struct compute_type +{ + static constexpr auto Resolve() + { + if constexpr(std::is_void_v) + { + constexpr bool AIsLarger = sizeof(ADataType) >= sizeof(BDataType); + using LargestInputType = std::conditional_t; + if constexpr(numeric_traits::PackedSize > 1) + { + return t{}; + } + else + { + return t{}; + } + } + else + { + // If there's an explicitly defined compute type, use that + return t{}; + } + } + + using type = typename decltype(Resolve())::type; +}; + +template +using compute_type_t = compute_type::type; +}; // namespace detail + template -struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase +struct GemmQuantPipelineProblemBase + : public GemmPipelineProblemBase< + ADataType_, + BDataType_, + CDataType_, + BlockGemmShape_, + Traits_, + detail::compute_type_t> { - using Base = GemmPipelineProblemBase; + + using Base = + GemmPipelineProblemBase>; using Traits = typename Base::Traits; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index 80e41cad458..753202bf7ba 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -93,13 +93,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - - using WarpGemm = WarpGemmDispatcher +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple + //std::tuple, + //std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 8c9955da749..da06b098f7e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -207,7 +207,7 @@ template <> struct QuantTypeTraits { template - using ComputeDataType = BDataType; // For AQuant, compute type is BDataType + using ComputeDataType = void; // Use automatically determined compute type static constexpr const char* name = "abquant"; }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 3798cc44430..450283bf744 100644 --- 
a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -931,8 +931,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase>; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, From 761ba1b32565d8f1519d3bc7fd7e2a92f1850be6 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Mon, 19 Jan 2026 08:17:33 +0000 Subject: [PATCH 07/14] feat: optimize reference calculation by preconverting values to AccType --- .../ck_tile/host/reference/reference_gemm.hpp | 73 ++++++++++--------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 061587cee46..f021e1a85f3 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -141,45 +141,46 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, const std::size_t N = b_k_n.get_length(1); const std::size_t K = a_m_k.get_length(1); + // Pre-convert A/B tensors to AccData type + // This prevents doing slow reconversions for each row/column + HostTensor a_acc(a_m_k.mDesc); + HostTensor b_acc(b_k_n.mDesc); + + a_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const ADataType pk_val = a_element_op(a_m_k(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[1] & 1) ? fp32_val.hi : fp32_val.lo; + } + else + { + self(index) = ck_tile::type_convert(a_element_op(a_m_k(index))); + } + }); + + b_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const BDataType pk_val = b_element_op(b_k_n(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[0] & 1) ? fp32_val.hi : fp32_val.lo; + } + else if constexpr(std::is_same_v) + { + self(index) = fp8_to_float_raw(b_element_op(b_k_n(index))); + } + else + { + self(index) = ck_tile::type_convert(b_element_op(b_k_n(index))); + } + }); + auto f_mn = [&](auto m, auto n) { AccDataType v_acc = 0; constexpr std::size_t kGroupK = BQuantGroupSize::kK; - // ---- A loader: dequant A(m,k) into AccDataType ---- - auto load_a = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v || - std::is_same_v) - { - const ADataType pk_val = a_element_op(a_m_k(m, k)); - const fp32x2_t fp32_val = pk_val.to_fp32x2(); - return (k & 1) ? fp32_val.hi : fp32_val.lo; - } - else - { - return ck_tile::type_convert(a_element_op(a_m_k(m, k))); - } - }; - - // ---- B loader: dequant B(k,n) into AccDataType ---- - auto load_b = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v || - std::is_same_v) - { - const BDataType pk_val = b_element_op(b_k_n(k, n)); - const fp32x2_t fp32_val = pk_val.to_fp32x2(); - return (k & 1) ? 
fp32_val.hi : fp32_val.lo; - } - else if constexpr(std::is_same_v) - { - return fp8_to_float_raw(b_element_op(b_k_n(k, n))); - } - else - { - return ck_tile::type_convert(b_element_op(b_k_n(k, n))); - } - }; - // ---- a scale loader for a given K-group index ---- auto load_scale_a = [&](ck_tile::index_t k_group) -> float { const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM; @@ -226,8 +227,8 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, // unscaled accumulation within this K-group for(std::size_t k = k_begin; k < k_end; ++k) { - const AccDataType v_a = load_a(k); - const AccDataType v_b = load_b(k); + const AccDataType v_a = a_acc(m, k); + const AccDataType v_b = b_acc(k, n); v_block_acc += v_a * v_b; } From a477fb8422ff29f2ac678c2101cc1da25d2d2af8 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Mon, 19 Jan 2026 10:23:49 +0000 Subject: [PATCH 08/14] feat: add fp4 to fp8 look-up table --- include/ck_tile/core/numeric/pk_fp4.hpp | 34 +++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index 0120c755607..f6884561ca4 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -156,6 +156,26 @@ struct pk_float4_e2m1_t bit_cast(static_cast(0xC400)), // -4 bit_cast(static_cast(0xC600)) // -6 }; + + // FP8 = E4M3. Finite and normal values should be bit-compatible between FNUZ and OCP + static constexpr fp8_t e2m1_to_fp8_table[16] = { + fp8_t(static_cast(0x00)), // 0 + fp8_t(static_cast(0x30)), // 0.5 + fp8_t(static_cast(0x38)), // 1 + fp8_t(static_cast(0x3C)), // 1.5 + fp8_t(static_cast(0x40)), // 2 + fp8_t(static_cast(0x44)), // 3 + fp8_t(static_cast(0x48)), // 4 + fp8_t(static_cast(0x4C)), // 6 + fp8_t(static_cast(0x00)), // -0 + fp8_t(static_cast(0xB0)), // -0.5 + fp8_t(static_cast(0xB8)), // -1 + fp8_t(static_cast(0xBC)), // -1.5 + fp8_t(static_cast(0xC0)), // -2 + fp8_t(static_cast(0xC4)), // -3 + fp8_t(static_cast(0xC8)), // -4 + fp8_t(static_cast(0xCC)) // -6 + }; #endif }; @@ -419,7 +439,6 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } - CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const { #if CK_TILE_FP4_CVT_DEVICE @@ -444,7 +463,8 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale}; + return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, + e2m1_to_fp32_table[_unpack(number<1>{})] * scale}; } CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { @@ -457,6 +477,16 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<1>{})]) * scale)}; } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + return type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale; +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + return fp8x2_t{ + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale)}; +} #endif } // namespace ck_tile From 72a94bdaf9177314aca3ab20dcf0a932da928002 Mon Sep 17 00:00:00 
2001 From: Erwin Terpstra Date: Mon, 19 Jan 2026 10:50:58 +0000 Subject: [PATCH 09/14] fix: reference to wrong ComputeDataType field in QuantProblem --- .../gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index 753202bf7ba..d4fcc378daa 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -93,8 +93,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmDispatcher Date: Tue, 20 Jan 2026 07:09:13 +0000 Subject: [PATCH 10/14] feat: type utilities for determining MFMA compute types --- include/ck_tile/core.hpp | 1 + .../core/utility/mfma_compute_types.hpp | 65 +++++++++++++++++++ include/ck_tile/core/utility/type_traits.hpp | 17 +++++ 3 files changed, 83 insertions(+) create mode 100644 include/ck_tile/core/utility/mfma_compute_types.hpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 01e1d00b591..7a0d9dd3fcc 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -91,6 +91,7 @@ #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/mfma_compute_types.hpp" #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/print.hpp" #include "ck_tile/core/utility/random.hpp" diff --git a/include/ck_tile/core/utility/mfma_compute_types.hpp b/include/ck_tile/core/utility/mfma_compute_types.hpp new file mode 100644 index 00000000000..540793850c8 --- /dev/null +++ b/include/ck_tile/core/utility/mfma_compute_types.hpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" + +#include + +namespace ck_tile { + +namespace detail { +template +struct t +{ + using type = T; +}; + +// Helper method to automatically determine compute type +// Selects the largest type of the two. If both of them are packed data types, defaults to fp8. +template +struct auto_compute_type +{ + static constexpr auto Resolve() + { + using LargestInputType = largest_type_t; + if constexpr(is_packed_type_v) + { + return t{}; + } + else + { + return t{}; + } + } + + using type = typename decltype(Resolve())::type; +}; + +// Helper method to determine compute type, defaulting an explicitly passed-in compute type +template +struct mfma_compute_type +{ + using type = std::conditional_t, + typename auto_compute_type::type, + ComputeDataType>; +}; + +}; // namespace detail + +template +using mfma_compute_type_t = + typename detail::mfma_compute_type::type; + +// Helper method to determine compute type, defaulting to input data type +// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed, +// ComputeDataType is used. 
+template +using mfma_compute_type_from_input_t = std::conditional_t< + is_packed_type_v, + std::conditional_t, ComputeDataType, OtherDataType>, + ThisDataType>; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index f07e25e19cb..c11d180839b 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" + #include #include #include @@ -187,4 +189,19 @@ template using tuple_element_or_default_t = typename tuple_element_or_default::type; +// Helper struct to determine if a type is packed (more than 1 element per byte) +template +struct is_packed_type +{ + static constexpr bool value = numeric_traits::PackedSize > 1; +}; + +template +static constexpr bool is_packed_type_v = is_packed_type::value; + +// Helper definition to take the largest sizes type +template +using largest_type_t = + std::conditional_t= sizeof(BDataType), ADataType, BDataType>; + } // namespace ck_tile From 37af2173e32c6dc56c3716f204133af476e42748 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Tue, 20 Jan 2026 07:10:37 +0000 Subject: [PATCH 11/14] feat: packed fp4 for abquant weight preshuffle --- .../ops/common/load_interleaved_pk_type.hpp | 2 +- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 20 +++++-- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 15 +++-- .../pipeline/gemm_quant_pipeline_problem.hpp | 56 +++---------------- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 52 +++++++++-------- .../test_gemm_quant_abquant_a4w4.cpp | 5 +- 6 files changed, 67 insertions(+), 83 deletions(-) diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 4743d572f72..fe16fe1418c 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -42,7 +42,7 @@ template CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) { - if constexpr(std::is_same_v || std::is_same_v) + if constexpr(numeric_traits::PackedSize > 1) { static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); InterleavedPKTypeLoader::load_interleaved_pk_type( diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1ff95b157cb..f922e95d372 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" @@ -252,11 +253,20 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - using WarpGemm = WarpGemmDispatcher; + using BTypeToUse = mfma_compute_type_from_input_t; + + using WarpGemm = WarpGemmDispatcher f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - 
std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -188,7 +190,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg typename BFlatBlockTensor, typename AQBlockTensor, typename BQBlockTensor, - typename ABlockWindow> + typename ABlockWindow, + index_t UnaryOpSize = 8> CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, ABlockTensor& a_warp_tensor, BFlatBlockTensor& b_warp_tensor, @@ -232,8 +235,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows(number{})(number{})); + + load_int4_tile( + a_warp_tensor(number{}), + a_warp_windows(number{})(number{})); } // barrier // Could be deleted diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 38304a01a95..28d7bcb2d02 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -11,47 +11,6 @@ namespace ck_tile { -namespace detail { -template -struct t -{ - using type = T; -}; - -// Helper method to automatically determine compute type -// Selects the largest type of the two. If both of them are packed data types, defaults to fp8. -template -struct compute_type -{ - static constexpr auto Resolve() - { - if constexpr(std::is_void_v) - { - constexpr bool AIsLarger = sizeof(ADataType) >= sizeof(BDataType); - using LargestInputType = std::conditional_t; - if constexpr(numeric_traits::PackedSize > 1) - { - return t{}; - } - else - { - return t{}; - } - } - else - { - // If there's an explicitly defined compute type, use that - return t{}; - } - } - - using type = typename decltype(Resolve())::type; -}; - -template -using compute_type_t = compute_type::type; -}; // namespace detail - template struct GemmQuantPipelineProblemBase - : public GemmPipelineProblemBase< - ADataType_, - BDataType_, - CDataType_, - BlockGemmShape_, - Traits_, - detail::compute_type_t> + : public GemmPipelineProblemBase> { using Base = @@ -82,7 +40,7 @@ struct GemmQuantPipelineProblemBase CDataType_, BlockGemmShape_, Traits_, - detail::compute_type_t>; + mfma_compute_type_t>; using Traits = typename Base::Traits; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 0f3951ffccc..f99e52a5a84 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -7,6 +7,8 @@ #include #include "ck_tile/core.hpp" +#include "ck_tile/core/utility/mfma_compute_types.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" @@ -234,36 +236,42 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe make_tensor_view(p_a_lds_pong, a_lds_block_desc); // A DRAM tile window for load + auto a_dram_tile_distribution = + PipelinePolicy::template MakeADramTileDistribution(); + auto 
a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); // ping-pong window for A LDS + auto a_warp_tile_distribution = + make_static_tile_distribution(typename WG::AWarpDstrEncoding{}); + auto a_warp_window_ping_tmp = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); auto a_warp_window_pong_tmp = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); statically_indexed_array< statically_indexed_array, @@ -308,8 +316,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe b_flat_dram_block_window_tmp.get_window_origin(), b_flat_distribution); - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + using BTypeToUse = mfma_compute_type_from_input_t; using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); // pingpong buffer for B @@ -349,7 +356,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -389,15 +396,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe block_sync_lds(); // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor; + using ATypeToUse = mfma_compute_type_from_input_t; + using ATileType = + decltype(make_static_distributed_tensor(a_warp_tile_distribution)); + statically_indexed_array a_warp_tensor; static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); @@ -430,7 +438,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -442,8 +450,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // Next K @@ -455,7 +463,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public 
WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -485,8 +493,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); iCounter--; HotLoopScheduler(); @@ -503,7 +511,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -525,8 +533,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // GEMM loopK diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4.cpp index 836320aee26..5cdeaaf4d57 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4.cpp @@ -32,9 +32,12 @@ using GroupSize2D = ck_tile::QuantGroupShape>; // clang-format off using ABQuantTypes = ::testing::Types< // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple + std::tuple, //std::tuple, //std::tuple + + // PreshuffleB = true && TransposeC = false + std::tuple >; // clang-format on From 32d5757ce47da4c122878eb8886ef52de362710b Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Tue, 20 Jan 2026 11:22:45 +0000 Subject: [PATCH 12/14] feat: add separate tests for a4w4 base case, padding and preshuffleB --- test/ck_tile/gemm_block_scale/CMakeLists.txt | 21 ++++-- .../test_gemm_quant_abquant_a4w4_base.cpp | 44 +++++++++++++ .../test_gemm_quant_abquant_a4w4_padding.cpp | 65 +++++++++++++++++++ ...st_gemm_quant_abquant_a4w4_preshuffle.cpp} | 9 +-- 4 files changed, 126 insertions(+), 13 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp rename test/ck_tile/gemm_block_scale/{test_gemm_quant_abquant_a4w4.cpp => test_gemm_quant_abquant_a4w4_preshuffle.cpp} (67%) diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index bfe83ad8ca0..0a15dc32ab8 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -34,11 +34,6 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_tile_gemm_quant_abquant_a4w4 - test_gemm_quant_abquant_a4w4.cpp - ) - target_compile_options(test_tile_gemm_quant_abquant_a4w4 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - 
add_gtest_executable(test_tile_gemm_quant_abquant_padding test_gemm_quant_abquant_padding.cpp ) @@ -49,6 +44,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base + test_gemm_quant_abquant_a4w4_base.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_padding + test_gemm_quant_abquant_a4w4_padding.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_preshuffle + test_gemm_quant_abquant_a4w4_preshuffle.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # AQuant tests add_gtest_executable(test_tile_gemm_quant_aquant_prefill test_gemm_quant_aquant_prefill.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp new file mode 100644 index 00000000000..5e2403f7d17 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp new file mode 100644 index 00000000000..d126f77ffbd --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadK) +{ + this->run_test_with_validation(1024, 1024, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadN) +{ + this->run_test_with_validation(1024, 832, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadM) +{ + this->run_test_with_validation(832, 1024, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadMNK) +{ + this->run_test_with_validation(832, 832, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadNK) +{ + this->run_test_with_validation(1024, 832, 832); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp similarity index 67% rename from test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4.cpp rename to test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp index 5cdeaaf4d57..43051c8d088 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp @@ -12,10 +12,7 @@ // Type aliases for readability using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; using PkFP4 = ck_tile::pk_fp4_t; using ABQuantGrouped = std::integral_constant; @@ -31,11 +28,7 @@ using GroupSize2D = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout> // clang-format off using ABQuantTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - //std::tuple, - //std::tuple - + // RCR layout with RowMajor AQ, ColumnMajor BQ // PreshuffleB = true && TransposeC = false std::tuple >; From f55d9025a5c7340948dbd53dfc00452c3fa46538 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Tue, 20 Jan 2026 12:15:17 +0000 Subject: [PATCH 13/14] fix: fp4 conversion on gfx950 attempting to use non-supported method --- include/ck_tile/core/numeric/pk_fp4.hpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index f6884561ca4..f79bbd1be9a 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -441,20 +441,24 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const } 
 CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
 {
-#if CK_TILE_FP4_CVT_DEVICE
-    return impl::_from_f4(data, scale);
-#else
+    // NOTE: No specialized fp4-to-fp8 conversion instructions are available. It is unclear whether
+    // converting fp4 -> fp16 -> fp8 would be faster than the naive implementation below.
+    // #if CK_TILE_FP4_CVT_DEVICE
+    // return impl::_from_f4(data, scale);
+    // #else
     return fp8_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))};
-#endif
+    // #endif
 }

 CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
 {
-#if CK_TILE_FP4_CVT_DEVICE
-    return impl::_from_f4(data, scale);
-#else
+    // NOTE: No specialized fp4-to-fp8 conversion instructions are available. It is unclear whether
+    // converting fp4 -> fp16 -> fp8 would be faster than the naive implementation below.
+    // #if CK_TILE_FP4_CVT_DEVICE
+    // return impl::_from_f4(data, scale);
+    // #else
     return fp8x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)),
                    type_convert(convert_to_float(_unpack(number<1>{}), scale))};
-#endif
+    // #endif
 }
 #else
 CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const

From c0d869b4e505451d143de599e91d7394ff6628fe Mon Sep 17 00:00:00 2001
From: Erwin Terpstra
Date: Tue, 20 Jan 2026 12:59:57 +0000
Subject: [PATCH 14/14] fix: test case was using quant group sizes which don't
 work on gfx950 due to the larger mfma tile size

---
 .../gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp
index d126f77ffbd..1e496d5b642 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp
@@ -18,10 +18,10 @@ using ABQuantGrouped =
     std::integral_constant;

 // 1d block sizes for AQuant
-using GroupSize1D = ck_tile::QuantGroupShape>;
+using GroupSize1D = ck_tile::QuantGroupShape>;

 // 2d block sizes for BQuant
-using GroupSize2D = ck_tile::QuantGroupShape>;
+using GroupSize2D = ck_tile::QuantGroupShape>;

 // Type combinations for ABQuant tests
 // Tuple format: