diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 7451ee25b02..d77e3c93229 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -21,7 +21,6 @@ if(has_supported_gpu) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") - add_executable(tile_example_flatmm_basic flatmm_basic.cpp) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index d6c84f3064b..1141717545c 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor& src) const int K = src_lengths[0]; const int N = src_lengths[1]; constexpr int packed_size = ck_tile::numeric_traits::PackedSize; - int KPack = 16 * packed_size; // fp4:32 or fp8:16 - int NLane = N_Warp_Tile; - int KLane = 64 / NLane; - int K0 = K / (KLane * KPack); + int KPack = + std::is_same_v ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16 + int NLane = N_Warp_Tile; + int KLane = 64 / NLane; + int K0 = K / (KLane * KPack); ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({N * K}, {1})); @@ -295,7 +296,14 @@ int run_mx_flatmm_example(int argc, char* argv[]) } else if(mx_prec == "fp6" || mx_prec == "fp6xfp6") { - throw std::runtime_error("fp6xfp6 is not supported."); + if(persistent_opt == 0) + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + else + throw std::runtime_error("Only support non-persistent kernel now!"); } else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") { diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index 0b6185590fa..d4922bb44c7 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -44,6 +44,38 @@ struct MXfp4_FlatmmConfig16 static constexpr bool TiledMMAPermuteN = false; }; +struct MXfp6_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; + struct MXfp8_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake index 
5e86cd71332..9250dbe7ae6 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake @@ -8,13 +8,14 @@ function(mx_flatmm_instance_generate FILE_LIST) set(C_LAYOUT ROW) set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16") set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16") + set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16") set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16") set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16") # foreach(PERSISTENT false true) # TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions. foreach(PERSISTENT false) - foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8) + foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8) set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}}) string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE}) list(GET DATA_TYPE_AB 0 A_DATA_TYPE) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in index 9675d3345b2..e6d612f0d6e 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in @@ -19,6 +19,7 @@ using FP4 = ck_tile::pk_fp4_t; using FP8 = ck_tile::fp8_t; +using FP6 = ck_tile::pk_fp6x16_t; using FP16 = ck_tile::fp16_t; using BF16 = ck_tile::bf16_t; diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index b4d1fe237b3..54c23e22662 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -68,24 +68,47 @@ int run_mx_flatmm_with_layouts(int argc, M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout))); ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout))); - - if(init_method == 0) - { - ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); - ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); - ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); - ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); - } - else if(init_method == 1) + if constexpr(std::is_same_v) { - ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); + auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b); + std::vector random_bufA(a_buffer_bytes); + std::vector random_bufB(b_buffer_bytes); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(1, 4); + + for(size_t i = 0; i < a_buffer_bytes; ++i) + random_bufA[i] = static_cast(dis(gen)); + + for(size_t i = 0; i < b_buffer_bytes; ++i) + random_bufB[i] = static_cast(dis(gen)); + + memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); + memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); } else { - throw std::runtime_error("wrong! 
Unexpected init_method"); + if(init_method == 0) + { + ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + } + else + { + throw std::runtime_error("wrong! Unexpected init_method"); + } } const auto b_shuffled_host = preShuffleWeight(b_origin_host); diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index f8b1736801c..ef129f9bc08 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -639,7 +639,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, reg_c.template AsType()[Number<0>{}], - 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 ; 3 FP6 E3M2; 4 FP4 E2M1} 2, // blgp 0, // OPSEL 0, diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 01e1d00b591..f43254a31f1 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,6 +54,7 @@ #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/pk_fp6.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7af2f558add..057e5b20abd 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1417,7 +1417,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) { - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64, "wrong! 
not implemented"); using rtn_type = thread_buffer; @@ -1457,6 +1457,15 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, return bit_cast(tmp); } + else if constexpr(N == 12) + { + auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } else if constexpr(N == 16) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 9f9770df1b5..08aaab4eb96 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1129,6 +1129,23 @@ llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i32x3_(int32x3_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v3i32"); + +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x3( + dwordx3_union vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) +{ + int32x3_t v_reg; + v_reg[0] = vdata.as_i32[0]; + v_reg[1] = vdata.as_i32[1]; + v_reg[2] = vdata.as_i32[2]; + llvm_amdgcn_raw_buffer_store_i32x3_(v_reg, rsrc, voffset, soffset, glc_slc); +}; + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, int32x4_t rsrc, @@ -1285,7 +1302,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) { - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 | N == 16 || N == 32 || N == 64, "wrong! 
not implemented"); using rtn_type = thread_buffer; @@ -1325,6 +1342,18 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, return bit_cast(tmp); } + else if constexpr(N == 12) + { + auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + dwordx3_union ret; + ret.as_i32[0] = tmp[0]; + ret.as_i32[1] = tmp[1]; + ret.as_i32[2] = tmp[2]; + return bit_cast(ret); + } else if constexpr(N == 16) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, @@ -1406,7 +1435,10 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && @@ -1414,7 +1446,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))) || + (std::is_same::value && (N == 1)), "wrong! not implemented"); using rtn_type = thread_buffer; @@ -1745,7 +1778,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer(coherence)); } + else if constexpr(N == 12) + { + llvm_amdgcn_raw_buffer_store_i32x3(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } else if constexpr(N == 16) { llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), @@ -1854,10 +1895,13 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + std::is_same::value && (N == 1), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 diff --git a/include/ck_tile/core/numeric/pk_fp6.hpp b/include/ck_tile/core/numeric/pk_fp6.hpp new file mode 100644 index 00000000000..146d294b643 --- /dev/null +++ b/include/ck_tile/core/numeric/pk_fp6.hpp @@ -0,0 +1,110 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <cmath>
+#include "ck_tile/core/config.hpp"
+#include "ck_tile/core/numeric/half.hpp"
+#include "ck_tile/core/numeric/mxfp_convert.hpp"
+
+namespace ck_tile {
+
+template <index_t pk_size>
+struct pk_fp6_t
+{
+    static constexpr index_t num_bits_elem = 6;
+    using element_type = int32_t; // element storage fundamental type
+    static constexpr index_t packed_size = pk_size;
+    static constexpr index_t num_bits_vec_elem =
+        sizeof(element_type) * 8; // 32-bit uint for storage
+    static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
+                  "Packed elements must fit exactly into the element storage.");
+    static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
+    element_type data_[vector_size]; // packed data
+    using type = pk_fp6_t<pk_size>;
+
+    CK_TILE_HOST_DEVICE constexpr pk_fp6_t(){};
+    CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value)
+    {
+        for(size_t i = 0; i < vector_size; ++i)
+        {
+            data_[i] = value;
+        }
+    }
+
+    void pack(const int32_t x, const index_t i)
+    {
+        uint32_t bits        = static_cast<uint32_t>(x) & 0x3F;
+        const int bit_pos    = i * num_bits_elem;
+        const int arr_index  = bit_pos / num_bits_vec_elem;
+        const int bit_offset = bit_pos % num_bits_vec_elem;
+        const int overhang   = bit_offset + num_bits_elem - num_bits_vec_elem;
+        uint32_t old_value   = static_cast<uint32_t>(data_[arr_index]);
+
+        // insert bits into the current 32-bit block (unsigned, so the shift cannot overflow)
+        old_value |= (bits << bit_offset);
+        data_[arr_index] = static_cast<element_type>(old_value);
+
+        // if it crosses into the next block, shift the remainder
+        if(overhang > 0 && (arr_index + 1) < vector_size)
+        {
+            uint32_t next_value = static_cast<uint32_t>(data_[arr_index + 1]);
+            next_value |= (bits >> (num_bits_elem - overhang));
+            data_[arr_index + 1] = static_cast<element_type>(next_value);
+        }
+    }
+
+    template <typename T>
+    static inline int32_t unpack(const T& pk, const index_t i)
+    {
+        const int bit_pos    = i * num_bits_elem;
+        const int arr_idx    = bit_pos / num_bits_vec_elem;
+        const int bit_offset = bit_pos % num_bits_vec_elem;
+        const int overhang   = bit_offset + num_bits_elem - num_bits_vec_elem;
+
+        // logical (unsigned) shift so a set sign bit cannot smear into the extracted field
+        uint32_t bits = static_cast<uint32_t>(pk.data_[arr_idx]) >> bit_offset;
+        if(overhang > 0 && (arr_idx + 1) < vector_size)
+        {
+            bits |= (static_cast<uint32_t>(pk.data_[arr_idx + 1]) & ((1u << overhang) - 1))
+                    << (num_bits_elem - overhang);
+        }
+
+        return static_cast<int32_t>(bits & 0x3F);
+    }
+
+    inline int32_t unpack(const index_t i) const { return unpack(*this, i); }
+
+    CK_TILE_HOST_DEVICE int32_t operator[](index_t i) const { return data_[i]; }
+
+    static float fp6_e2m3_to_float(int32_t fp6_bits)
+    {
+        fp6_bits = fp6_bits & 0x3F;
+
+        uint32_t sign     = (fp6_bits >> 5) & 0x1; // bit 5
+        uint32_t exponent = (fp6_bits >> 3) & 0x3; // bits 4-3
+        uint32_t mantissa = fp6_bits & 0x7;        // bits 2-0
+
+        float result;
+        if(exponent == 0 && mantissa == 0)
+        {
+            result = 0.f;
+        }
+        else if(exponent != 0)
+        {
+            result               = std::pow(2, exponent - 1);
+            float mantissa_value = 1.0f + mantissa / 8.0f;
+            result *= mantissa_value;
+        }
+        else
+        {
+            result = mantissa / 8.0f; // subnormal: exponent field 0, implicit leading 0
+        }
+        return sign == 1 ? -1 * result : result;
+    }
+};
+
+using pk_fp6x16_t = pk_fp6_t<16>;
+using pk_fp6x32_t = pk_fp6_t<32>;
+
+template <>
+struct numeric_traits<pk_fp6x16_t>
+{
+    static constexpr int PackedSize = 16;
+};
+} // namespace ck_tile
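pk_fp6_t stores each value in 6 bits across 32-bit words, so 16 values occupy exactly three dwords, the same 12 bytes moved by the dwordx3 buffer path added elsewhere in this PR. A standalone round-trip sketch of that packing scheme (hypothetical helper names, not the ck_tile implementation):

```cpp
// Standalone sketch of the 6-bit packing used by pk_fp6_t above: 16 * 6 bits
// = 96 bits = three 32-bit words. Helper names here are made up.
#include <cassert>
#include <cstdint>

static uint32_t words[3] = {0, 0, 0};

void pack6(uint32_t bits, int i) // insert the i-th 6-bit field
{
    bits &= 0x3F;
    const int pos = i * 6, w = pos / 32, off = pos % 32, over = off + 6 - 32;
    words[w] |= bits << off;
    if(over > 0) words[w + 1] |= bits >> (6 - over); // field straddles two words
}

uint32_t unpack6(int i)
{
    const int pos = i * 6, w = pos / 32, off = pos % 32, over = off + 6 - 32;
    uint32_t bits = words[w] >> off;
    if(over > 0) bits |= (words[w + 1] & ((1u << over) - 1)) << (6 - over);
    return bits & 0x3F;
}

int main()
{
    // 0b101100: sign=1, exp=01, mant=100 -> -(2^0 * (1 + 4/8)) = -1.5 in e2m3
    for(int i = 0; i < 16; ++i) pack6(0x2C, i);
    for(int i = 0; i < 16; ++i) assert(unpack6(i) == 0x2C);
    return 0;
}
```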
diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp
index deaa9e0bd90..634b8457258 100644
--- a/include/ck_tile/core/numeric/type_convert.hpp
+++ b/include/ck_tile/core/numeric/type_convert.hpp
@@ -72,6 +72,7 @@ CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
 } // namespace ck_tile
 
 #include "ck_tile/core/numeric/pk_fp4.hpp"
+#include "ck_tile/core/numeric/pk_fp6.hpp"
 
 namespace ck_tile {
diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp
index 90ddc2a56ea..29178e956e5 100644
--- a/include/ck_tile/core/numeric/vector_type.hpp
+++ b/include/ck_tile/core/numeric/vector_type.hpp
@@ -159,6 +159,40 @@
 using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
 using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
 using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
 
+struct int32x3_tt
+{
+    int32_t data[3];
+};
+
+struct int32x6_tt
+{
+    int32_t data[6];
+};
+
+template <>
+struct impl::ext_vector<int8_t, 12>
+{
+    static constexpr index_t N = 12;
+    using value_type = int32x3_tt;
+    using type       = int32x3_tt;
+};
+
+template <>
+struct impl::ext_vector<int32x3_tt, 1>
+{
+    static constexpr index_t N = 1;
+    using value_type = int32x3_tt;
+    using type       = int32x3_tt; // this is dangerous
+};
+
+template <>
+struct impl::ext_vector<int32x3_tt, 2>
+{
+    static constexpr index_t N = 2;
+    using value_type = int32x6_tt;
+    using type       = int32x6_tt; // this is dangerous
+};
+
 // u32
 // using uint32_t = ...
 using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp
index f3aeed6e614..59f82939b99 100644
--- a/include/ck_tile/core/tensor/buffer_view.hpp
+++ b/include/ck_tile/core/tensor/buffer_view.hpp
@@ -303,7 +303,6 @@ struct buffer_view
-            using buf_t = ext_vector_t<typename scalar_type<remove_cvref_t<T>>::scalar_type,
-                                       scalar_per_t_vector * scalar_per_x_vector>;
-            // using buf_t = ushort __attribute__((ext_vector_type(8)));
-            auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
-            return bit_cast<X>(rtn);
+            constexpr index_t load_elts = scalar_per_t_vector * scalar_per_x_vector;
+            if constexpr(load_elts == 12 && sizeof(typename X::value_type) == 1)
+            {
+                auto rtn = reinterpret_cast<const int32_t*>(p_data_) + (i + linear_offset) / 4;
+                struct
+                {
+                    int32_t x, y, z;
+                } tmp = {rtn[0], rtn[1], rtn[2]};
+                return bit_cast<X>(tmp);
+            }
+            else
+            {
+                using buf_t = ext_vector_t<typename scalar_type<remove_cvref_t<T>>::scalar_type,
+                                           scalar_per_t_vector * scalar_per_x_vector>;
+                auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
+                return bit_cast<X>(rtn);
+            }
 #endif
         }
         else
@@ -968,6 +979,7 @@ struct buffer_view
              std::is_same_v, int8x16_t> && std::is_same_v, int8x16_t>) ||
             // int8 on thread buffer
             (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) ||
+            (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+             std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 12>>) ||
             (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) ||
             (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) ||
             (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) ||
@@ -1033,6 +1045,11 @@ struct buffer_view
             *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x);
         }
+        else if constexpr(std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 12>>)
+        {
+            *c_style_pointer_cast<int32x3_tt*>(&p_data_[i]) =
+                *c_style_pointer_cast<const int32x3_tt*>(&x);
+        }
         else if constexpr((std::is_same_v, int8_t> &&
                            std::is_same_v, int8x16_t>) ||
                           (std::is_same_v, int8_t> &&
@@ -1075,6 
+1092,12 @@ struct buffer_view(&p_data_[i]) = *c_style_pointer_cast(&x); } + else + { + static_assert(false, + "wrong! not implemented for this combination, please add " + "implementation"); + } } } else diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index a1be8027b28..9d66764b4ff 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -720,4 +720,56 @@ std::enable_if_t<(std::is_same_v, ranges::range_val return err_count == 0; } +/** + * @brief Check errors between pk_fp6x16_t ranges + * + * Compares two ranges of pk_fp6x16_t without tolerance. + * This specialization handles ck_tile::pk_fp6x16_t type. + * + * @tparam Range Type of output range + * @tparam RefRange Type of reference range + * @param out Output range to check + * @param ref Reference range to check against + * @param msg Error message to display if check fails + * @return True if check passes, false otherwise + */ +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, pk_fp6x16_t>), + bool> + CK_TILE_HOST check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double = 0) +{ + if(check_size_mismatch(out, ref, msg)) + return false; + + int err_count = 0; + + auto update_err = [&](float o, float r, std::size_t index) { + if(std::fabs(o - r) > 1e-8) + { + std::cerr << msg << " out[" << index << "] != ref[" << index << "]: " << o + << " != " << r << std::endl; + ++err_count; + } + }; + for(std::size_t i = 0; i < ref.size(); ++i) + { + const pk_fp6x16_t o = *std::next(std::begin(out), i); + const pk_fp6x16_t r = *std::next(std::begin(ref), i); + for(std::size_t j = 0; j < numeric_traits::PackedSize; j++) + { + update_err(o.unpack(j), r.unpack(j), i * numeric_traits::PackedSize + j); + } + } + if(err_count > 0) + { + report_error_stats(err_count, numeric::max(), ref.size()); + } + return err_count == 0; +} + } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 9ad5af8264c..01b3ab5df1e 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -617,6 +617,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, a_m_k_scaled(m, k) = a_f4_lo * a_scale; a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale; } + else if constexpr(std::is_same_v) + { + if(k % pk_fp6x16_t::packed_size != 0) + continue; + auto a_scale = ck_tile::type_convert(scale_a(m, k / ScaleBlockSize)); + for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++) + { + a_m_k_scaled(m, k + k_) = + pk_fp6x16_t::fp6_e2m3_to_float(a_m_k(m, k).unpack(k_)) * a_scale; + } + } else { a_m_k_scaled(m, k) = @@ -645,6 +656,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, b_k_n_scaled(k, n) = b_f4_lo * b_scale; b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale; } + else if constexpr(std::is_same_v) + { + if(k % pk_fp6x16_t::packed_size != 0) + continue; + auto b_scale = ck_tile::type_convert(scale_b(k / ScaleBlockSize, n)); + for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++) + { + b_k_n_scaled(k + k_, n) = + pk_fp6x16_t::fp6_e2m3_to_float(b_k_n(k, n).unpack(k_)) * b_scale; + } + } else { b_k_n_scaled(k, n) = diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 425083a9de3..4a30e3af163 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -22,6 +22,7 @@ 
template <> struct DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp6x16"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; template struct memOpToStr; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index bc7d2323d05..23d7a9fca99 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -118,8 +118,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 + ? 16 + : 16 /*dwordx4*/ * APackedSize / sizeof(ADataType); + static constexpr index_t BK1 = std::is_same_v + ? 16 + : 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType); static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload @@ -537,24 +542,26 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); auto a_store_lds_window_pong = make_tile_window( // a_lds_block_pong, - make_tuple(number{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); // ping-pong window for A LDS - auto a_warp_window_ping = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); - auto a_warp_window_pong = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); + auto a_warp_window_ping = make_tile_window( + a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); + auto a_warp_window_pong = make_tile_window( + a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); // B flat DRAM window for load @@ -621,7 +628,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { @@ -663,7 +670,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, MIterPerWarp> @@ -683,7 +690,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); __builtin_amdgcn_sched_barrier(0); @@ -750,7 +758,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_ping, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished @@ -760,7 +768,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); HotLoopScheduler(); @@ -839,7 +848,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_pong, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished @@ -849,7 +858,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); HotLoopScheduler(); }; @@ 
-874,7 +884,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 0); } - // TAIL if constexpr(TailNum == TailNumber::Even) { @@ -933,7 +942,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_ping, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished @@ -947,7 +956,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); Last2ndHotLoopScheduler(); @@ -977,12 +987,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = - load_tile_with_offset(a_warp_window_pong, - tuple, - number>{}); + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, + number>{}); } }); LastHotLoopScheduler(); @@ -1014,12 +1024,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = - load_tile_with_offset(a_warp_window_ping, - tuple, - number>{}); + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, + number>{}); } }); LastHotLoopScheduler(); diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index c4ab1d4a78c..1d32b33bc96 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -17,6 +17,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static constexpr index_t kDramLoadPackBytes = 128; static constexpr index_t DWORDx4 = 16; + static constexpr index_t DWORDx3 = 12; static constexpr int MXdlPack = 2; static constexpr int NXdlPack = 2; @@ -77,15 +78,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_DEVICE static constexpr auto MakeMX_ABytesDramTileDistribution() { - constexpr index_t K2 = DWORDx4; // 16 bytes - constexpr index_t K1 = kDramLoadPackBytes / K2; // 8 - constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize + constexpr index_t K2 = std::is_same_v ? DWORDx3 : DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8 + constexpr index_t K0 = + KPerBlock / APackedSize * sizeof(ADataType) / (K1 * K2); // KPerBlock/256/packsize constexpr index_t M2 = WaveSize / K1; // 8 constexpr index_t M1 = BlockSize / WaveSize; // 4 constexpr index_t M0 = MPerBlock / (M2 * M1); static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); - static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + static_assert(K0 * K1 * K2 == KPerBlock / APackedSize * sizeof(ADataType), "K0, K1, K2 must cover whole KPerBlock!"); return make_static_tile_distribution( @@ -107,9 +109,9 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); - constexpr index_t K2 = DWORDx4; // 16 bytes - constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 - const index_t K0 = cols / (K1 * K2 * APackedSize); + constexpr index_t K2 = std::is_same_v ? 
DWORDx3 : DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8 + const index_t K0 = cols / (K1 * K2 / sizeof(ADataType) * APackedSize); const auto col_lens = make_tuple(K0, number{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds @@ -138,19 +140,23 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, desc); - auto&& origin_tmp = window_tmp.get_window_origin(); + auto&& origin_tmp = window_tmp.get_window_origin(); + constexpr index_t test1 = APackedSize / sizeof(ADataType); return make_tile_window(byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / APackedSize}, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / test1}, MakeMX_ABytesDramTileDistribution()); } CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBytesBlockDescriptor() { - constexpr index_t K2 = AK1 / APackedSize; // 16 - constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 - constexpr index_t K0 = KPerBlock / (K1 * AK1); // KPerBlock/256 - static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + constexpr index_t K2 = std::is_same_v ? DWORDx3 : AK1 / APackedSize; + constexpr index_t K2_Pad = 16; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 + constexpr index_t K0 = std::is_same_v + ? KPerBlock / (K1 * K2 / sizeof(ADataType) * APackedSize) + : KPerBlock / (K1 * AK1); // KPerBlock/256 + static_assert(K0 * K1 * K2 / sizeof(ADataType) * APackedSize == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); constexpr index_t M3 = 4; // so that we can use imm offset to load lds @@ -169,12 +175,12 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy number{}, number{}, number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - number{}, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, number<1>{}), number{}, number<1>{}); @@ -216,7 +222,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - if constexpr(K_Thread == AK1) + if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -225,7 +231,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<0, 2>>, sequence<2>, sequence<1>>{}); - else + else if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -235,6 +241,19 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<1, 2>>, sequence<2, 2>, sequence<0, 2>>{}); + else if constexpr(std::is_same_v) + // K_Lane=4, K_Thread=32 + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<2, 2>, + sequence<1, 2>>{}); + else + static_assert(false, "unsupported datatype"); } CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() @@ -245,17 +264,17 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - if constexpr(BK1 == K_Thread) + if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, tuple, // 4 2 - sequence>, // 1 64 32 
+ sequence>, // 1 64 16 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, sequence<2>>{}); - else + else if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -265,6 +284,21 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<2>>, sequence<2, 2>, sequence<0, 3>>{}); + else if constexpr(std::is_same_v) + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, // 4 2 + sequence>, // 64 1 2 12 + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<2, 2>, + sequence<2, 3>>{}); + else + static_assert(false, "unsupported datatype"); } template @@ -280,21 +314,27 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile; auto&& byte_tensor_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple( - flat_n, flat_k / flat_k_per_block, number{})), + make_naive_tensor_descriptor_packed( + make_tuple(flat_n, + flat_k / flat_k_per_block, + number{})), make_tuple(make_pass_through_transform(flat_n), make_merge_transform_v3_division_mod(make_tuple( - flat_k / flat_k_per_block, number{}))), + flat_k / flat_k_per_block, + number{}))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, byte_tensor_desc); auto&& origin_tmp = window_tmp.get_window_origin(); + auto origin_n = origin_tmp[0]; + auto origin_k = static_cast(origin_tmp[1] * sizeof(BDataType) / BPackedSize); return make_tile_window( byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / BPackedSize}, + make_tuple(number{}, + number{}), + {origin_n, origin_k}, MakeMX_BFlatBytesDramTileDistribution()); } @@ -372,7 +412,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + if constexpr(!std::is_same_v) + { + return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + } + else + { + return MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + } } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index bd65f533839..9b7c3743f7b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1570,7 +1570,8 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 return make_tuple(number<0>{}, int32x8_t{}); else if constexpr(std::is_same_v) return make_tuple(number<1>{}, int32x8_t{}); - // else if e2m3 => make_tuple(number<2>{}, int32x6_t{}) + else if constexpr(std::is_same_v) + return make_tuple(number<2>{}, pk_fp6x32_t{}); // else if e3m2 => make_tuple(number<3>{}, int32x6_t{}) else if constexpr(std::is_same_v) return make_tuple(number<4>{}, int32x4_t{}); diff --git a/test/ck_tile/memory_copy/test_copy.cpp b/test/ck_tile/memory_copy/test_copy.cpp index 2a43b596e4c..208b92e7026 100644 --- a/test/ck_tile/memory_copy/test_copy.cpp +++ 
b/test/ck_tile/memory_copy/test_copy.cpp
@@ -20,6 +20,25 @@ struct MemoryCopyParam
     ck_tile::index_t warp_id;
 };
 
+template <typename... Ts>
+struct type_list
+{
+};
+
+template <int I, typename List>
+struct type_at;
+
+template <int I, typename Head, typename... Tail>
+struct type_at<I, type_list<Head, Tail...>> : type_at<I - 1, type_list<Tail...>>
+{
+};
+
+template <typename Head, typename... Tail>
+struct type_at<0, type_list<Head, Tail...>>
+{
+    using type = Head;
+};
+
 template
 class TestCkTileMemoryCopy : public ::testing::TestWithParam>
 {
@@ -33,48 +52,47 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam ? 1 : 0;
         ck_tile::HostTensor<DataType> x_host({m, n});
         ck_tile::HostTensor<DataType> y_host_dev({m, n});
+        ck_tile::HostTensor<int8_t> host_init_buf({x_host.get_element_space_size_in_bytes()});
 
         std::cout << "input: " << x_host.mDesc << std::endl;
         std::cout << "output: " << y_host_dev.mDesc << std::endl;
 
-        ck_tile::index_t value = 1;
-        for(int i = 0; i < m; i++)
-        {
-            value = 1;
-            for(int j = 0; j < n; j++)
-            {
-                value        = (value + 1) % 127;
-                x_host(i, j) = static_cast<DataType>(value);
-            }
-        }
-
+        for(size_t i = 0; i < x_host.get_element_space_size_in_bytes(); i++)
+            host_init_buf.mData[i] = i % 64;
+        memcpy(x_host.mData.data(),
+               host_init_buf.mData.data(),
+               x_host.get_element_space_size_in_bytes());
         ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
         ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
         x_buf.ToDevice(x_host.data());
 
-        using BlockWaves = ck_tile::sequence<2, 1>;
-        using BlockTile  = ck_tile::sequence<64, 8>;
-        using WaveTile   = ck_tile::sequence<64, 8>;
-        using Vector     = ck_tile::sequence<1, dword_bytes / sizeof(DataType)>;
+        using BlockTileList = type_list<ck_tile::sequence<64, 8>, ck_tile::sequence<16, 96>>;
+        using VectorList    = type_list<ck_tile::sequence<1, dword_bytes / sizeof(DataType)>,
+                                        ck_tile::sequence<1, 24>>;
+        using BlockWaves    = ck_tile::sequence<2, 1>;
+        using BlockTile     = type_at<CpyCfg, BlockTileList>::type;
+        using WaveTile      = type_at<CpyCfg, BlockTileList>::type;
+        using Vector        = type_at<CpyCfg, VectorList>::type;
 
         ck_tile::index_t kGridSize =
             ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{}));
 
         using Shape   = ck_tile::TileCopyShape;
-        using Problem = ck_tile::TileCopyProblem<DataType, Shape, async_copy>;
+        using Problem = ck_tile::TileCopyProblem<DataType, Shape, async_copy, CpyCfg>;
         using Kernel  = ck_tile::TileCopy<Problem>;
 
         constexpr ck_tile::index_t kBlockSize  = 128;
         constexpr ck_tile::index_t kBlockPerCu = 1;
+        // when copying an fp6x16 buffer, treat it as an int8 buffer and recompute the n-dim size.
+        ck_tile::index_t cpy_n =
+            CpyCfg == 1 ? n * sizeof(DataType) /
+                              (sizeof(int8_t) * ck_tile::numeric_traits<DataType>::PackedSize)
                        : n;
 
         auto ms = launch_kernel(
             ck_tile::stream_config{nullptr, true},
@@ -85,21 +103,28 @@ class TestCkTileMemoryCopy
                 Kernel{}(static_cast<DataType*>(x_buf.GetDeviceBuffer()),
                          static_cast<DataType*>(y_buf.GetDeviceBuffer()),
                          m,
-                         n,
+                         cpy_n,
                          warp_id));
 
-        auto bytes = 2 * m * n * sizeof(DataType);
+        auto bytes = 2 * m * n * sizeof(DataType) / ck_tile::numeric_traits<DataType>::PackedSize;
         std::cout << "elapsed: " << ms << " (ms)" << std::endl;
         std::cout << (bytes * 1e-6 / ms) << " (GB/s)" << std::endl;
 
         // reference
         y_buf.FromDevice(y_host_dev.mData.data());
         bool pass = ck_tile::check_err(y_host_dev, x_host);
-
         EXPECT_TRUE(pass);
     }
 };
 
+class TestCkTileMemoryCopyF6x16Async : public TestCkTileMemoryCopy
+{
+};
+
+class TestCkTileMemoryCopyF6x16 : public TestCkTileMemoryCopy
+{
+};
+
 class TestCkTileMemoryCopyHalfAsync : public TestCkTileMemoryCopy
 {
 };
@@ -116,6 +141,18 @@ class TestCkTileMemoryCopyFP8Async : public TestCkTileMemoryCopy
 {
 };
 
+TEST_P(TestCkTileMemoryCopyF6x16, TestCorrectness)
+{
+    auto [M, N, warp_id] = GetParam();
+    this->Run({M, N, warp_id});
+}
+
+TEST_P(TestCkTileMemoryCopyF6x16Async, TestCorrectness)
+{
+    auto [M, N, warp_id] = GetParam();
+    this->Run({M, N, warp_id});
+}
+
 TEST_P(TestCkTileMemoryCopyHalfAsync, TestCorrectness)
 {
     auto [M, N, warp_id] = GetParam();
@@ -140,6 +177,20 @@ TEST_P(TestCkTileMemoryCopyFP8Async, TestCorrectness)
     this->Run({M, N, warp_id});
 }
 
+INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite,
+                         TestCkTileMemoryCopyF6x16,
+                         ::testing::Values(std::tuple{32, 128, 0},
+                                           std::tuple{64, 256, 0},
+                                           std::tuple{32, 128, 1},
+                                           std::tuple{64, 256, 1}));
+
+INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite,
+                         TestCkTileMemoryCopyF6x16Async,
+                         ::testing::Values(std::tuple{32, 128, 0},
+                                           std::tuple{64, 256, 0},
+                                           std::tuple{32, 128, 1},
+                                           std::tuple{64, 256, 1}));
+
 INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite,
                          TestCkTileMemoryCopyHalfAsync,
                          ::testing::Values(std::tuple{64, 8, 0},
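The type_list/type_at helpers above select tile and vector shapes by the CpyCfg index at compile time. A minimal standalone sketch of the same mechanism, with plain types standing in for the test's ck_tile::sequence shapes:

```cpp
// Minimal standalone version of the type_list/type_at selector used above
// (illustrative only; the test's real lists hold ck_tile::sequence types).
#include <type_traits>

template <typename... Ts>
struct type_list {};

template <int I, typename List>
struct type_at;

template <int I, typename Head, typename... Tail>
struct type_at<I, type_list<Head, Tail...>> : type_at<I - 1, type_list<Tail...>> {};

template <typename Head, typename... Tail>
struct type_at<0, type_list<Head, Tail...>> { using type = Head; };

using Shapes = type_list<int, long, double>;
static_assert(std::is_same_v<type_at<0, Shapes>::type, int>);
static_assert(std::is_same_v<type_at<2, Shapes>::type, double>); // CpyCfg-style index pick

int main() { return 0; }
```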
diff --git a/test/ck_tile/memory_copy/test_copy.hpp b/test/ck_tile/memory_copy/test_copy.hpp
index 847763881b5..2ce4982a04d 100644
--- a/test/ck_tile/memory_copy/test_copy.hpp
+++ b/test/ck_tile/memory_copy/test_copy.hpp
@@ -51,12 +51,15 @@ struct TileCopyShape
                   "Inconsistent wave group size!");
 };
 
-template <typename XDataType_, typename BlockShape_, bool AsyncCopy_>
+template <typename XDataType_, typename BlockShape_, bool AsyncCopy_, int CpyCfg_ = 0>
 struct TileCopyProblem
 {
     using XDataType  = remove_cvref_t<XDataType_>;
     using BlockShape = remove_cvref_t<BlockShape_>;
     static constexpr bool AsyncCopy = AsyncCopy_;
+    // 0: copy 1-, 2-, or 4-byte data types
+    // 1: copy 12-byte (dwordx3) data types
+    static constexpr int CpyCfg = CpyCfg_;
 };
 
 template <typename Problem>
@@ -67,6 +70,7 @@ struct TileCopy
     static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
     static constexpr bool AsyncCopy     = Problem::AsyncCopy;
+    static constexpr int CpyCfg         = Problem::CpyCfg;
 
     template
     CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()
@@ -98,8 +102,40 @@ struct TileCopy
         return make_static_tile_distribution(outer_encoding);
     }
 
+    template
+    CK_TILE_DEVICE static constexpr auto MakeDwordx3DRAMDistribution()
+    {
+        using S = typename Problem::BlockShape;
+
+        constexpr index_t warp_size = get_warp_size();
+        constexpr index_t X0 = S::ThreadPerWarp_N; // threads needed along N dimension, fastest
+                                                   // changing with given vector size.
+        constexpr index_t X1 =
+            S::Block_N; // no. of elements along the N dimension to be read by each thread.
+
+        constexpr index_t X2 = 12; // load/store dwordx3 (12 bytes) per thread
+
+        constexpr index_t Y0 =
+            S::WaveNum / S::WaveGroups; // number of active warps working in this thread block.
+        constexpr index_t Y2 =
+            warp_size / X0; // number of threads in a warp needed along M dimension.
+        constexpr index_t Y1 =
+            S::Warp_M /
+            Y2; // number of iterations each warp needs to perform to cover the entire tile window.
+        constexpr auto outer_encoding = tile_distribution_encoding<
+            sequence,
+            tuple, sequence>, // Y2==16, X0==4
+            tuple, sequence<1, 2>>,
+            tuple, sequence<2, 1>>,
+            sequence<1, 2, 2>,
+            sequence<1, 0, 2>>{};
+
+        return make_static_tile_distribution(outer_encoding);
+    }
+
     CK_TILE_DEVICE void
-    operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const
+    run_normal_cpy(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const
     {
         using S = typename Problem::BlockShape;
@@ -170,6 +206,124 @@ struct TileCopy
             move_tile_window(y_block_window, {0, S::Block_N});
         }
     }
-};
+
+    CK_TILE_DEVICE void
+    run_dwordx3_cpy(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const
+    {
+        using S = typename Problem::BlockShape;
+        constexpr index_t X0 = S::ThreadPerWarp_N;
+        constexpr index_t X1 = S::Block_N;
+        constexpr index_t X2 = 12; // load/store dwordx3 (12 bytes) per thread
+
+        // LDS buffer
+        constexpr int dim1_stride =
+            AsyncCopy ? 16 : 12; // async_load dwordx3 writes 3 dwords & skips 1 dword in lds.
+        constexpr int repeat_num = X1 / (X0 * X2);
+        __shared__ int8_t x_lds[repeat_num * S::Block_M * X0 * dim1_stride];
+
+        constexpr auto block_dims  = make_tuple(number{}, number{});
+        constexpr auto block_dims_ = make_tuple(number{},
+                                                number{},
+                                                number{},
+                                                number{});
+        constexpr auto block_strides = make_tuple(number{},
+                                                  number{},
+                                                  number{},
+                                                  number<1>{});
+
+        const auto x_lds_desc_ =
+            make_naive_tensor_descriptor(block_dims_, block_strides, number<12>{}, number<1>{});
+        const auto x_lds_desc = transform_tensor_descriptor(
+            x_lds_desc_,
+            make_tuple(make_pass_through_transform(number{}),
+                       make_merge_transform_v3_division_mod(make_tuple(
+                           number<2>{}, number{}, number{}))),
+            make_tuple(sequence<1>{}, sequence<0, 2, 3>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));
+
+        auto x_lds_view =
+            make_tensor_view(reinterpret_cast(x_lds), x_lds_desc);
+
+        auto x_block_lds_write_window = make_tile_window(x_lds_view, block_dims, {0, 0});
+
+        auto x_block_lds_read_window = make_tile_window(
+            x_lds_view, block_dims, {0, 0}, MakeDwordx3DRAMDistribution());
+
+        const index_t iM = __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_M);
+        // Input tensor
+        const auto x_m_n =
+            make_naive_tensor_view(reinterpret_cast(p_x),
+                                   make_tuple(M, N),
+                                   make_tuple(N, 1),
+                                   number{},
+                                   number<1>{});
+        auto x_block_window =
+            make_tile_window(x_m_n, block_dims, {iM, 0}, MakeDwordx3DRAMDistribution());
+
+        // Output tensor
+        const auto y_m =
+            make_naive_tensor_view(reinterpret_cast(p_y),
+                                   make_tuple(M, N),
+                                   make_tuple(N, 1),
+                                   number{},
+                                   number<1>{});
+        auto y_block_window = make_tile_window(y_m, block_dims, {iM, 0});
+
+        const index_t num_n_tile_iteration =
+            __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
+        const index_t my_id = __builtin_amdgcn_readfirstlane(get_warp_id());
+        constexpr index_t async_copy_fence_cnt = 0;
+        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
+        {
+            if(my_id == warp_id)
+            {
+                if constexpr(AsyncCopy)
+                {
+                    async_load_tile(x_block_lds_write_window, x_block_window);
+                    // No prefetch here; wait for the data to come back immediately.
+                    // Wait all async-load insts complete.
+                    // Wait all waves synced.
+                    s_waitcnt_barrier();
+                    auto lds_tile = load_tile(x_block_lds_read_window);
+                    // store from registers to DRAM
+                    store_tile(y_block_window, lds_tile);
+                }
+                else
+                {
+                    // load from DRAM to registers
+                    auto dram_tile = load_tile(x_block_window);
+                    // store in lds
+                    store_tile(x_block_lds_write_window, dram_tile);
+                    // Wait all lds write insts complete.
+                    // Wait all waves synced.
+                    block_sync_lds();
+                    // read from lds to registers
+                    auto lds_tile = load_tile(x_block_lds_read_window);
+                    // store from registers to DRAM
+                    store_tile(y_block_window, lds_tile);
+                }
+            }
+
+            move_tile_window(x_block_window, {0, S::Block_N});
+            move_tile_window(y_block_window, {0, S::Block_N});
+        }
+    }
+
+    CK_TILE_DEVICE void
+    operator()(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const
+    {
+        if constexpr(CpyCfg == 1)
+        {
+            run_dwordx3_cpy(p_x, p_y, M, N, warp_id);
+        }
+        else if constexpr(CpyCfg == 0)
+        {
+            run_normal_cpy(p_x, p_y, M, N, warp_id);
+        }
+        else
+        {
+            static_assert(false, "unsupported copy config type.");
+        }
+    }
+};
 } // namespace ck_tile
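The run_dwordx3_cpy LDS layout above pads each lane's 12-byte slot to 16 bytes on the async path (dim1_stride), presumably so each slot stays dword-x4 aligned, while the direct path packs lanes tightly. A small host-side check of that addressing, assuming the stride rationale stated in the comment (helper name is hypothetical):

```cpp
// Standalone check (not from the PR) of the dwordx3 LDS addressing above:
// async loads deposit 12 bytes per lane but advance 16 bytes (3 dwords
// written, 1 dword skipped); the direct path packs lanes at 12 bytes.
#include <cassert>

constexpr int lds_offset(bool async_copy, int lane, int elem_byte) // hypothetical helper
{
    const int dim1_stride = async_copy ? 16 : 12; // mirrors run_dwordx3_cpy
    return lane * dim1_stride + elem_byte;
}

int main()
{
    // lane 1's 12-byte slot starts at 16 with async padding, 12 without
    static_assert(lds_offset(true, 1, 0) == 16);
    static_assert(lds_offset(false, 1, 0) == 12);
    // four lanes therefore span 64 bytes async vs 48 bytes direct
    static_assert(lds_offset(true, 4, 0) == 64 && lds_offset(false, 4, 0) == 48);
    return 0;
}
```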