From f334ba5be0022450b0930cb36465be78f1ab7cbb Mon Sep 17 00:00:00 2001 From: ZheWang Date: Fri, 26 Dec 2025 03:00:18 +0000 Subject: [PATCH 01/14] add fp6 data-type and support sync/async dwordx3 load/store --- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 9 +- include/ck/utility/amd_xdlops.hpp | 2 +- include/ck_tile/core.hpp | 1 + .../core/arch/amd_buffer_addressing.hpp | 11 +- .../arch/amd_buffer_addressing_builtins.hpp | 54 +++++- include/ck_tile/core/numeric/pk_fp6.hpp | 108 ++++++++++++ include/ck_tile/core/numeric/type_convert.hpp | 1 + include/ck_tile/core/numeric/vector_type.hpp | 21 +++ include/ck_tile/core/tensor/buffer_view.hpp | 14 +- include/ck_tile/host/check_err.hpp | 52 ++++++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 3 +- test/ck_tile/memory_copy/test_copy.cpp | 101 ++++++++--- test/ck_tile/memory_copy/test_copy.hpp | 160 +++++++++++++++++- 13 files changed, 495 insertions(+), 42 deletions(-) create mode 100644 include/ck_tile/core/numeric/pk_fp6.hpp diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index d6c84f3064b..4e7a209b8f6 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor& src) const int K = src_lengths[0]; const int N = src_lengths[1]; constexpr int packed_size = ck_tile::numeric_traits::PackedSize; - int KPack = 16 * packed_size; // fp4:32 or fp8:16 - int NLane = N_Warp_Tile; - int KLane = 64 / NLane; - int K0 = K / (KLane * KPack); + int KPack = + std::is_same_v ? 
32 : 16 * packed_size; // fp4/fp6:32 or fp8:16 + int NLane = N_Warp_Tile; + int KLane = 64 / NLane; + int K0 = K / (KLane * KPack); ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({N * K}, {1})); diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index f8b1736801c..ef129f9bc08 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -639,7 +639,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, reg_c.template AsType()[Number<0>{}], - 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 ; 3 FP6 E3M2; 4 FP4 E2M1} 2, // blgp 0, // OPSEL 0, diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 01e1d00b591..c3b0d80376a 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -53,6 +53,7 @@ #include "ck_tile/core/numeric/mxfp_convert.hpp" #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_fp6.hpp" #include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7af2f558add..04b814efe41 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1417,7 +1417,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) { - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64, "wrong! 
not implemented"); using rtn_type = thread_buffer; @@ -1457,6 +1457,15 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, return bit_cast(tmp); } + else if constexpr(N == 12) + { + auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } else if constexpr(N == 16) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 9f9770df1b5..35a91b2ae80 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1129,6 +1129,23 @@ llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i32x3_(int32x3_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v3i32"); + +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x3( + dwordx3_union vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) +{ + int32x3_t v_reg; + v_reg[0] = vdata.as_i32[0]; + v_reg[1] = vdata.as_i32[1]; + v_reg[2] = vdata.as_i32[2]; + llvm_amdgcn_raw_buffer_store_i32x3_(v_reg, rsrc, voffset, soffset, glc_slc); +}; + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, int32x4_t rsrc, @@ -1285,7 +1302,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) { - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 | N == 16 || N == 32 || N == 64, "wrong! 
not implemented"); using rtn_type = thread_buffer; @@ -1325,6 +1342,18 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, return bit_cast(tmp); } + else if constexpr(N == 12) + { + auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + dwordx3_union ret; + ret.as_i32[0] = tmp[0]; + ret.as_i32[1] = tmp[1]; + ret.as_i32[2] = tmp[2]; + return bit_cast(ret); + } else if constexpr(N == 16) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, @@ -1406,7 +1435,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && @@ -1414,7 +1444,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))) || + (std::is_same::value && (N == 1)), "wrong! 
not implemented"); using rtn_type = thread_buffer; @@ -1745,7 +1776,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer(coherence)); } + else if constexpr(N == 12) + { + llvm_amdgcn_raw_buffer_store_i32x3(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } else if constexpr(N == 16) { llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), @@ -1854,10 +1893,13 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + std::is_same::value && (N == 1), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 diff --git a/include/ck_tile/core/numeric/pk_fp6.hpp b/include/ck_tile/core/numeric/pk_fp6.hpp new file mode 100644 index 00000000000..1fd414c6f21 --- /dev/null +++ b/include/ck_tile/core/numeric/pk_fp6.hpp @@ -0,0 +1,108 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/mxfp_convert.hpp" + +namespace ck_tile { +template +struct pk_fp6_t +{ + static constexpr index_t num_bits_elem = 6; + using element_type = uint32_t; // element storage fundamental type + static constexpr index_t packed_size = pk_size; + static constexpr index_t num_bits_vec_elem = + sizeof(element_type) * 8; // 32-bit uint for storage + static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0, + "Packed elements must fit exactly into the element storage."); + static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; + element_type data_[vector_size]; // packed data + using type = pk_fp6_t; + CK_TILE_HOST_DEVICE pk_fp6_t() {}; + CK_TILE_HOST_DEVICE explicit pk_fp6_t(int value) + { + for(size_t i = 0; i < vector_size; ++i) + { + data_[i] = value; + } + } + void pack(const uint32_t x, const index_t i) + { + uint32_t bits = static_cast(x) & 0x3F; + const int bit_pos = i * num_bits_elem; + const int arr_index = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = data_[arr_index]; + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + data_[arr_index] = old_value; + + // if it crosses into the next block, shift the remainder + if(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = data_[arr_index + 1]; + next_value |= (bits >> (num_bits_elem - overhang)); + data_[arr_index + 1] = next_value; + } + } + + template + static inline uint32_t unpack(const T& pk, const index_t i) + { + const int bit_pos = i * num_bits_elem; + const int arr_idx = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - 
num_bits_vec_elem; + + uint32_t bits = pk.data_[arr_idx] >> bit_offset; + if(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); + } + + return bits & 0x3F; + } + + inline uint32_t unpack(const index_t i) const { return unpack(*this, i); } + + float fp6_e2m3_to_float(uint32_t fp6_bits) + { + fp6_bits = fp6_bits & 0x3F; + + uint32_t sign = (fp6_bits >> 5) & 0x1; // bit 5 + uint32_t exponent = (fp6_bits >> 3) & 0x3; // bits 4-3 + uint32_t mantissa = fp6_bits & 0x7; // bits 2-0 + + float result; + if(exponent == 0 && mantissa == 0) + { + result = 0.f; + } + else if(exponent != 0) + { + result = std::pow(2, exponent - 1); + float mantissa_value = 1.0f + mantissa / 8.0f; + result *= mantissa_value; + } + else + { + result = mantissa / 8.0f; + } + return sign == 1 ? -1 * result : result; + } +}; + +using pk_fp6x16_t = pk_fp6_t<16>; +using pk_fp6x32_t = pk_fp6_t<32>; +template <> +struct numeric_traits +{ + static constexpr int PackedSize = 16; +}; +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index deaa9e0bd90..634b8457258 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -72,6 +72,7 @@ CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2) } // namespace ck_tile #include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/pk_fp6.hpp" namespace ck_tile { diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 90ddc2a56ea..9ba8b9cf3eb 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -46,6 +46,27 @@ struct ext_vector +struct ext_vector +{ + static constexpr index_t N = 12; + using value_type = int32x3_t; + using type = int32x3_t; +}; + +template +struct ext_vector +{ + static constexpr index_t N = N_; + using 
value_type = pk_fp6x16_t; + using type = pk_fp6x16_t; // this is danguous +}; + template struct ext_vector::type>>> { diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index f3aeed6e614..2f112961bbd 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -637,7 +637,7 @@ struct buffer_view, int32_t>; -#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT +#elif (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = std::is_same_v, float> || (std::is_same_v, half_t> && scalar_per_x_vector % 2 == 0) @@ -968,6 +968,7 @@ struct buffer_view, int8x16_t> && std::is_same_v, int8x16_t>) || // int8 on thread buffer (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || @@ -1033,6 +1034,11 @@ struct buffer_view(&p_data_[i]) = *c_style_pointer_cast(&x); } + else if constexpr(std::is_same_v, thread_buffer>) + { + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } else if constexpr((std::is_same_v, int8_t> && std::is_same_v, int8x16_t>) || (std::is_same_v, int8_t> && @@ -1075,6 +1081,12 @@ struct buffer_view(&p_data_[i]) = *c_style_pointer_cast(&x); } + else + { + static_assert(false, + "wrong! 
not implemented for this combination, please add " + "implementation"); + } } } else diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index a1be8027b28..9d66764b4ff 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -720,4 +720,56 @@ std::enable_if_t<(std::is_same_v, ranges::range_val return err_count == 0; } +/** + * @brief Check errors between pk_fp6x16_t ranges + * + * Compares two ranges of pk_fp6x16_t without tolerance. + * This specialization handles ck_tile::pk_fp6x16_t type. + * + * @tparam Range Type of output range + * @tparam RefRange Type of reference range + * @param out Output range to check + * @param ref Reference range to check against + * @param msg Error message to display if check fails + * @return True if check passes, false otherwise + */ +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, pk_fp6x16_t>), + bool> + CK_TILE_HOST check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double = 0) +{ + if(check_size_mismatch(out, ref, msg)) + return false; + + int err_count = 0; + + auto update_err = [&](float o, float r, std::size_t index) { + if(std::fabs(o - r) > 1e-8) + { + std::cerr << msg << " out[" << index << "] != ref[" << index << "]: " << o + << " != " << r << std::endl; + ++err_count; + } + }; + for(std::size_t i = 0; i < ref.size(); ++i) + { + const pk_fp6x16_t o = *std::next(std::begin(out), i); + const pk_fp6x16_t r = *std::next(std::begin(ref), i); + for(std::size_t j = 0; j < numeric_traits::PackedSize; j++) + { + update_err(o.unpack(j), r.unpack(j), i * numeric_traits::PackedSize + j); + } + } + if(err_count > 0) + { + report_error_stats(err_count, numeric::max(), ref.size()); + } + return err_count == 0; +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp 
b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index bd65f533839..9b7c3743f7b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1570,7 +1570,8 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 return make_tuple(number<0>{}, int32x8_t{}); else if constexpr(std::is_same_v) return make_tuple(number<1>{}, int32x8_t{}); - // else if e2m3 => make_tuple(number<2>{}, int32x6_t{}) + else if constexpr(std::is_same_v) + return make_tuple(number<2>{}, pk_fp6x32_t{}); // else if e3m2 => make_tuple(number<3>{}, int32x6_t{}) else if constexpr(std::is_same_v) return make_tuple(number<4>{}, int32x4_t{}); diff --git a/test/ck_tile/memory_copy/test_copy.cpp b/test/ck_tile/memory_copy/test_copy.cpp index 2a43b596e4c..6b4a827f472 100644 --- a/test/ck_tile/memory_copy/test_copy.cpp +++ b/test/ck_tile/memory_copy/test_copy.cpp @@ -20,6 +20,25 @@ struct MemoryCopyParam ck_tile::index_t warp_id; }; +template +struct type_list +{ +}; + +template +struct type_at; + +template +struct type_at> : type_at> +{ +}; + +template +struct type_at<0, type_list> +{ + using type = Head; +}; + template class TestCkTileMemoryCopy : public ::testing::TestWithParam> { @@ -33,48 +52,47 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam ? 
1 : 0; ck_tile::HostTensor x_host({m, n}); ck_tile::HostTensor y_host_dev({m, n}); + ck_tile::HostTensor host_init_buf({x_host.get_element_space_size_in_bytes()}); std::cout << "input: " << x_host.mDesc << std::endl; std::cout << "output: " << y_host_dev.mDesc << std::endl; - ck_tile::index_t value = 1; - for(int i = 0; i < m; i++) - { - value = 1; - for(int j = 0; j < n; j++) - { - value = (value + 1) % 127; - x_host(i, j) = static_cast(value); - } - } - + for(size_t i = 0; i < x_host.get_element_space_size_in_bytes(); i++) + host_init_buf.mData[i] = i % 64; + memcpy(x_host.mData.data(), + host_init_buf.mData.data(), + x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); - using BlockWaves = ck_tile::sequence<2, 1>; - using BlockTile = ck_tile::sequence<64, 8>; - using WaveTile = ck_tile::sequence<64, 8>; - using Vector = ck_tile::sequence<1, dword_bytes / sizeof(DataType)>; + using BlockTileList = type_list, ck_tile::sequence<16, 96>>; + using VectorList = type_list, + ck_tile::sequence<1, 24>>; + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = type_at::type; + using WaveTile = type_at::type; + using Vector = type_at::type; ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{})); using Shape = ck_tile::TileCopyShape; - using Problem = ck_tile::TileCopyProblem; + using Problem = ck_tile::TileCopyProblem; using Kernel = ck_tile::TileCopy; constexpr ck_tile::index_t kBlockSize = 128; constexpr ck_tile::index_t kBlockPerCu = 1; + // when copy fp6x16 buffer, tread it as int8 buffer and recompute n-dim size. + ck_tile::index_t cpy_n = + cpy_cfg == 1 ? 
n * sizeof(DataType) / + (sizeof(int8_t) * ck_tile::numeric_traits::PackedSize) + : n; auto ms = launch_kernel( ck_tile::stream_config{nullptr, true}, @@ -85,21 +103,28 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam(x_buf.GetDeviceBuffer()), static_cast(y_buf.GetDeviceBuffer()), m, - n, + cpy_n, warp_id)); - auto bytes = 2 * m * n * sizeof(DataType); + auto bytes = 2 * m * n * sizeof(DataType) / ck_tile::numeric_traits::PackedSize; std::cout << "elapsed: " << ms << " (ms)" << std::endl; std::cout << (bytes * 1e-6 / ms) << " (GB/s)" << std::endl; // reference y_buf.FromDevice(y_host_dev.mData.data()); bool pass = ck_tile::check_err(y_host_dev, x_host); - EXPECT_TRUE(pass); } }; +class TestCkTileMemoryCopyF6x16Async : public TestCkTileMemoryCopy +{ +}; + +class TestCkTileMemoryCopyF6x16 : public TestCkTileMemoryCopy +{ +}; + class TestCkTileMemoryCopyHalfAsync : public TestCkTileMemoryCopy { }; @@ -116,6 +141,18 @@ class TestCkTileMemoryCopyFP8Async : public TestCkTileMemoryCopy { }; +TEST_P(TestCkTileMemoryCopyF6x16, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + +TEST_P(TestCkTileMemoryCopyF6x16Async, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + TEST_P(TestCkTileMemoryCopyHalfAsync, TestCorrectness) { auto [M, N, warp_id] = GetParam(); @@ -140,6 +177,20 @@ TEST_P(TestCkTileMemoryCopyFP8Async, TestCorrectness) this->Run({M, N, warp_id}); } +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyF6x16, + ::testing::Values(std::tuple{32, 128, 0}, + std::tuple{64, 256, 0}, + std::tuple{32, 128, 1}, + std::tuple{64, 256, 1})); + +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyF6x16Async, + ::testing::Values(std::tuple{32, 128, 0}, + std::tuple{64, 256, 0}, + std::tuple{32, 128, 1}, + std::tuple{64, 256, 1})); + INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, TestCkTileMemoryCopyHalfAsync, ::testing::Values(std::tuple{64, 
8, 0}, diff --git a/test/ck_tile/memory_copy/test_copy.hpp b/test/ck_tile/memory_copy/test_copy.hpp index 847763881b5..0583b0f3d86 100644 --- a/test/ck_tile/memory_copy/test_copy.hpp +++ b/test/ck_tile/memory_copy/test_copy.hpp @@ -51,12 +51,15 @@ struct TileCopyShape "Inconsistent wave group size!"); }; -template +template struct TileCopyProblem { using XDataType = remove_cvref_t; using BlockShape = remove_cvref_t; static constexpr bool AsyncCopy = AsyncCopy_; + // 0: copy 1, 2, 4 bytes data type + // 1: copy dwordx3 bytes data type + static constexpr int cpy_cfg = cpy_cfg_; }; template @@ -67,6 +70,7 @@ struct TileCopy static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr bool AsyncCopy = Problem::AsyncCopy; + static constexpr int cpy_cfg = Problem::cpy_cfg; template CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() @@ -98,8 +102,40 @@ struct TileCopy return make_static_tile_distribution(outer_encoding); } + template + // CK_TILE_DEVICE static constexpr auto MakeDwordx3DRAMDistribution() + CK_TILE_DEVICE static constexpr auto MakeDwordx3DRAMDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t X0 = S::ThreadPerWarp_N; // threads needed along N dimension, fastest + // changing with given vector size. + constexpr index_t X1 = + S::Block_N; // no. of elements along N dimensions to be read by each thread. + + constexpr index_t X2 = 12; // l/w dwordx3 bytes + + constexpr index_t Y0 = + S::WaveNum / S::WaveGroups; // number of active warps working in this thread block. + constexpr index_t Y2 = + warp_size / X0; // number of threads in a warp needed along M dimension. + constexpr index_t Y1 = + S::Warp_M / + Y2; // number of iterations each warp needs to perform to cover the entire tile window. 
+ constexpr auto outer_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, // Y2==16,X0==4 + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<1, 0, 2>>{}; + + return make_static_tile_distribution(outer_encoding); + } + CK_TILE_DEVICE void - operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + run_normal_cpy(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const { using S = typename Problem::BlockShape; @@ -170,6 +206,124 @@ struct TileCopy move_tile_window(y_block_window, {0, S::Block_N}); } } -}; + CK_TILE_DEVICE void + run_dwordx3_cpy(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + using S = typename Problem::BlockShape; + constexpr index_t X0 = S::ThreadPerWarp_N; + constexpr index_t X1 = S::Block_N; + constexpr index_t X2 = 12; // l/w dwordx3 bytes + + // LDS buffer + constexpr int dim1_stride = + AsyncCopy ? 16 : 12; // async_load dwordx3 will write 3 bytes & skip 1 bytes in lds. 
+ constexpr int repeat_num = X1 / (X0 * X2); + __shared__ int8_t x_lds[repeat_num * S::Block_M * X0 * dim1_stride]; + + constexpr auto block_dims = make_tuple(number{}, number{}); + constexpr auto block_dims_ = make_tuple(number{}, + number{}, + number{}, + number{}); + constexpr auto block_strides = make_tuple(number{}, + number{}, + number{}, + number<1>{}); + + const auto x_lds_desc_ = + make_naive_tensor_descriptor(block_dims_, block_strides, number<12>{}, number<1>{}); + const auto x_lds_desc = transform_tensor_descriptor( + x_lds_desc_, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number<2>{}, number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto x_lds_view = + make_tensor_view(reinterpret_cast(x_lds), x_lds_desc); + + auto x_block_lds_write_window = make_tile_window(x_lds_view, block_dims, {0, 0}); + + auto x_block_lds_read_window = make_tile_window( + x_lds_view, block_dims, {0, 0}, MakeDwordx3DRAMDistribution()); + + const index_t iM = __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_M); + // Input tensor + const auto x_m_n = + make_naive_tensor_view(reinterpret_cast(p_x), + make_tuple(M, N), + make_tuple(N, 1), + number{}, + number<1>{}); + auto x_block_window = + make_tile_window(x_m_n, block_dims, {iM, 0}, MakeDwordx3DRAMDistribution()); + + // Output tensor + const auto y_m = + make_naive_tensor_view(reinterpret_cast(p_y), + make_tuple(M, N), + make_tuple(N, 1), + number{}, + number<1>{}); + auto y_block_window = make_tile_window(y_m, block_dims, {iM, 0}); + + const index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + const index_t my_id = __builtin_amdgcn_readfirstlane(get_warp_id()); + constexpr index_t async_copy_fence_cnt = 0; + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + if(my_id == warp_id) + { + if 
constexpr(AsyncCopy) + { + async_load_tile(x_block_lds_write_window, x_block_window); + // We don't have prefetch here, wait the data back immediately. + // Wait all asyncload insts complete. + // Wait all waves synced + s_waitcnt_barrier(); + auto lds_tile = load_tile(x_block_lds_read_window); + // store from registers to DRAM + store_tile(y_block_window, lds_tile); + } + else + { + // load from DRAM to registers + auto dram_tile = load_tile(x_block_window); + // store in lds + store_tile(x_block_lds_write_window, dram_tile); + // Wait all lds write insts complete + // Wait all waves synced + block_sync_lds(); + // read from lds to registers + auto lds_tile = load_tile(x_block_lds_read_window); + // store from registers to DRAM + store_tile(y_block_window, lds_tile); + } + } + + move_tile_window(x_block_window, {0, S::Block_N}); + move_tile_window(y_block_window, {0, S::Block_N}); + } + } + + CK_TILE_DEVICE void + operator()(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + if constexpr(cpy_cfg == 1) + { + run_dwordx3_cpy(p_x, p_y, M, N, warp_id); + } + else if constexpr(cpy_cfg == 0) + { + run_normal_cpy(p_x, p_y, M, N, warp_id); + } + else + { + static_assert(false, "unsupported copy config type."); + } + } +}; } // namespace ck_tile From 7965a94e0310a60a027070d7d4d5b45341ac5e87 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Fri, 26 Dec 2025 03:30:03 +0000 Subject: [PATCH 02/14] clang-format --- include/ck_tile/core/arch/amd_buffer_addressing.hpp | 6 +++--- include/ck_tile/core/tensor/buffer_view.hpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 04b814efe41..057e5b20abd 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1460,9 +1460,9 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, else if constexpr(N 
== 12) { auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 2f112961bbd..af1a8481ccc 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -637,7 +637,7 @@ struct buffer_view, int32_t>; -#elif (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT +#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = std::is_same_v, float> || (std::is_same_v, half_t> && scalar_per_x_vector % 2 == 0) From 9ce013168048dfdfa254ea8dc5fbe941f6ed26cb Mon Sep 17 00:00:00 2001 From: ZheWang Date: Fri, 26 Dec 2025 04:14:06 +0000 Subject: [PATCH 03/14] pre-commit --- include/ck_tile/core.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index c3b0d80376a..f43254a31f1 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -53,8 +53,8 @@ #include "ck_tile/core/numeric/mxfp_convert.hpp" #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" -#include "ck_tile/core/numeric/pk_fp6.hpp" #include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/pk_fp6.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" From 702b69c66c8727357b48aabb813cdf359b5c0b54 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Mon, 5 Jan 2026 13:37:42 +0000 Subject: [PATCH 04/14] 1st commit --- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 7 +- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp | 32 +++++ 
.../18_flatmm/mxgemm/mx_flatmm_instance.cmake | 3 +- .../mxgemm/mx_flatmm_instance.cpp.in | 1 + .../arch/amd_buffer_addressing_builtins.hpp | 2 + include/ck_tile/core/numeric/pk_fp6.hpp | 24 ++-- include/ck_tile/core/numeric/vector_type.hpp | 4 +- include/ck_tile/core/tensor/buffer_view.hpp | 8 +- .../ck_tile/host/reference/reference_gemm.hpp | 22 ++++ include/ck_tile/ops/common/utils.hpp | 1 + ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 34 +++--- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 114 ++++++++++++------ 12 files changed, 186 insertions(+), 66 deletions(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 4e7a209b8f6..be2de38ef59 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -296,7 +296,12 @@ int run_mx_flatmm_example(int argc, char* argv[]) } else if(mx_prec == "fp6" || mx_prec == "fp6xfp6") { - throw std::runtime_error("fp6xfp6 is not supported."); + // throw std::runtime_error("fp6xfp6 is not supported."); + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); } else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") { diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index 0b6185590fa..f5f36840d36 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -44,6 +44,38 @@ struct MXfp4_FlatmmConfig16 static constexpr bool TiledMMAPermuteN = false; }; +struct MXfp6_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 512; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static 
constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; + struct MXfp8_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake index 5e86cd71332..9250dbe7ae6 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake @@ -8,13 +8,14 @@ function(mx_flatmm_instance_generate FILE_LIST) set(C_LAYOUT ROW) set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16") set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16") + set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16") set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16") set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16") # foreach(PERSISTENT false true) # TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions. 
foreach(PERSISTENT false) - foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8) + foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8) set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}}) string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE}) list(GET DATA_TYPE_AB 0 A_DATA_TYPE) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in index 9675d3345b2..a7f65a12eae 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in @@ -19,6 +19,7 @@ using FP4 = ck_tile::pk_fp4_t; using FP8 = ck_tile::fp8_t; +using FP6 = ck_tile::pk_fp6x16_t; using FP16 = ck_tile::fp16_t; using BF16 = ck_tile::bf16_t; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 35a91b2ae80..08aaab4eb96 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1437,6 +1437,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && diff --git a/include/ck_tile/core/numeric/pk_fp6.hpp b/include/ck_tile/core/numeric/pk_fp6.hpp index 1fd414c6f21..146d294b643 100644 --- a/include/ck_tile/core/numeric/pk_fp6.hpp +++ b/include/ck_tile/core/numeric/pk_fp6.hpp @@ -13,7 +13,7 @@ template struct pk_fp6_t { static constexpr index_t num_bits_elem = 6; - using element_type = uint32_t; // element storage fundamental type + using element_type = int32_t; // element storage fundamental type static constexpr index_t 
packed_size = pk_size; static constexpr index_t num_bits_vec_elem = sizeof(element_type) * 8; // 32-bit uint for storage @@ -22,22 +22,22 @@ struct pk_fp6_t static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; element_type data_[vector_size]; // packed data using type = pk_fp6_t; - CK_TILE_HOST_DEVICE pk_fp6_t() {}; - CK_TILE_HOST_DEVICE explicit pk_fp6_t(int value) + CK_TILE_HOST_DEVICE constexpr pk_fp6_t(){}; + CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value) { for(size_t i = 0; i < vector_size; ++i) { data_[i] = value; } } - void pack(const uint32_t x, const index_t i) + void pack(const int32_t x, const index_t i) { - uint32_t bits = static_cast(x) & 0x3F; + int32_t bits = static_cast(x) & 0x3F; const int bit_pos = i * num_bits_elem; const int arr_index = bit_pos / num_bits_vec_elem; const int bit_offset = bit_pos % num_bits_vec_elem; const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = data_[arr_index]; + int32_t old_value = data_[arr_index]; // insert bits into the current 32-bit block old_value |= (bits << bit_offset); @@ -46,21 +46,21 @@ struct pk_fp6_t // if it crosses into the next block, shift the remainder if(overhang > 0 && (arr_index + 1) < vector_size) { - uint32_t next_value = data_[arr_index + 1]; + int32_t next_value = data_[arr_index + 1]; next_value |= (bits >> (num_bits_elem - overhang)); data_[arr_index + 1] = next_value; } } template - static inline uint32_t unpack(const T& pk, const index_t i) + static inline int32_t unpack(const T& pk, const index_t i) { const int bit_pos = i * num_bits_elem; const int arr_idx = bit_pos / num_bits_vec_elem; const int bit_offset = bit_pos % num_bits_vec_elem; const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t bits = pk.data_[arr_idx] >> bit_offset; + int32_t bits = pk.data_[arr_idx] >> bit_offset; if(overhang > 0 && (arr_idx + 1) < vector_size) { bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) 
- 1)) << (num_bits_elem - overhang); @@ -69,9 +69,11 @@ struct pk_fp6_t return bits & 0x3F; } - inline uint32_t unpack(const index_t i) const { return unpack(*this, i); } + inline int32_t unpack(const index_t i) const { return unpack(*this, i); } - float fp6_e2m3_to_float(uint32_t fp6_bits) + CK_TILE_HOST_DEVICE int32_t operator[](index_t i) const { return data_[i]; } + + static float fp6_e2m3_to_float(int32_t fp6_bits) { fp6_bits = fp6_bits & 0x3F; diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 9ba8b9cf3eb..a572d40a622 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -63,8 +63,8 @@ template struct ext_vector { static constexpr index_t N = N_; - using value_type = pk_fp6x16_t; - using type = pk_fp6x16_t; // this is danguous + using value_type = pk_fp6_t; + using type = pk_fp6_t; // this is dangerous }; template diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index af1a8481ccc..48ccadc2159 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -825,8 +825,12 @@ struct buffer_view>::scalar_type, - scalar_per_t_vector * scalar_per_x_vector>; + constexpr index_t load_elts = scalar_per_t_vector * scalar_per_x_vector; + using buf_t = typename std::conditional< + load_elts == 12 && sizeof(typename X::value_type) == 1, + thread_buffer, + ext_vector_t>::scalar_type, + scalar_per_t_vector * scalar_per_x_vector>>::type; // using buf_t = ushort __attribute__((ext_vector_type(8))); auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); return bit_cast(rtn); diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 9ad5af8264c..01b3ab5df1e 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -617,6 +617,17
CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, a_m_k_scaled(m, k) = a_f4_lo * a_scale; a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale; } + else if constexpr(std::is_same_v) + { + if(k % pk_fp6x16_t::packed_size != 0) + continue; + auto a_scale = ck_tile::type_convert(scale_a(m, k / ScaleBlockSize)); + for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++) + { + a_m_k_scaled(m, k + k_) = + pk_fp6x16_t::fp6_e2m3_to_float(a_m_k(m, k).unpack(k_)) * a_scale; + } + } else { a_m_k_scaled(m, k) = @@ -645,6 +656,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, b_k_n_scaled(k, n) = b_f4_lo * b_scale; b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale; } + else if constexpr(std::is_same_v) + { + if(k % pk_fp6x16_t::packed_size != 0) + continue; + auto b_scale = ck_tile::type_convert(scale_b(k / ScaleBlockSize, n)); + for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++) + { + b_k_n_scaled(k + k_, n) = + pk_fp6x16_t::fp6_e2m3_to_float(b_k_n(k, n).unpack(k_)) * b_scale; + } + } else { b_k_n_scaled(k, n) = diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 425083a9de3..4a30e3af163 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -22,6 +22,7 @@ template <> struct DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp6x16"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; template struct memOpToStr; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 
bc7d2323d05..65137061ab5 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -132,8 +132,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 + ? 16 + : 16 /*dwordx4*/ * APackedSize / sizeof(ADataType); + static constexpr index_t BK1 = std::is_same_v + ? 16 + : 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType); static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload @@ -537,24 +541,26 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); auto a_store_lds_window_pong = make_tile_window( // a_lds_block_pong, - make_tuple(number{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); // ping-pong window for A LDS - auto a_warp_window_ping = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); - auto a_warp_window_pong = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); + auto a_warp_window_ping = make_tile_window( + a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); + auto a_warp_window_pong = make_tile_window( + a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); // B flat DRAM window for load diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index c4ab1d4a78c..54be5899095 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ 
b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -17,6 +17,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static constexpr index_t kDramLoadPackBytes = 128; static constexpr index_t DWORDx4 = 16; + static constexpr index_t DWORDx3 = 12; static constexpr int MXdlPack = 2; static constexpr int NXdlPack = 2; @@ -77,15 +78,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_DEVICE static constexpr auto MakeMX_ABytesDramTileDistribution() { - constexpr index_t K2 = DWORDx4; // 16 bytes - constexpr index_t K1 = kDramLoadPackBytes / K2; // 8 - constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize + constexpr index_t K2 = std::is_same_v ? DWORDx3 : DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8 + constexpr index_t K0 = + KPerBlock / APackedSize * sizeof(ADataType) / (K1 * K2); // KPerBlock/256/packsize constexpr index_t M2 = WaveSize / K1; // 8 constexpr index_t M1 = BlockSize / WaveSize; // 4 constexpr index_t M0 = MPerBlock / (M2 * M1); static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); - static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + static_assert(K0 * K1 * K2 == KPerBlock / APackedSize * sizeof(ADataType), "K0, K1, K2 must cover whole KPerBlock!"); return make_static_tile_distribution( @@ -107,9 +109,9 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); - constexpr index_t K2 = DWORDx4; // 16 bytes - constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 - const index_t K0 = cols / (K1 * K2 * APackedSize); + constexpr index_t K2 = std::is_same_v ? 
DWORDx3 : DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8 + const index_t K0 = cols / (K1 * K2 / sizeof(ADataType) * APackedSize); const auto col_lens = make_tuple(K0, number{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds @@ -138,19 +140,23 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, desc); - auto&& origin_tmp = window_tmp.get_window_origin(); + auto&& origin_tmp = window_tmp.get_window_origin(); + constexpr index_t test1 = APackedSize / sizeof(ADataType); return make_tile_window(byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / APackedSize}, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / test1}, MakeMX_ABytesDramTileDistribution()); } CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBytesBlockDescriptor() { - constexpr index_t K2 = AK1 / APackedSize; // 16 - constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 - constexpr index_t K0 = KPerBlock / (K1 * AK1); // KPerBlock/256 - static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + constexpr index_t K2 = std::is_same_v ? DWORDx3 : AK1 / APackedSize; + constexpr index_t K2_Pad = 16; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 + constexpr index_t K0 = std::is_same_v + ? 
KPerBlock / (K1 * K2 / sizeof(ADataType) * APackedSize) + : KPerBlock / (K1 * AK1); // KPerBlock/256 + static_assert(K0 * K1 * K2 / sizeof(ADataType) * APackedSize == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); constexpr index_t M3 = 4; // so that we can use imm offset to load lds @@ -169,12 +175,12 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy number{}, number{}, number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - number{}, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, number<1>{}), number{}, number<1>{}); @@ -216,7 +222,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - if constexpr(K_Thread == AK1) + if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -225,7 +231,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<0, 2>>, sequence<2>, sequence<1>>{}); - else + else if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -235,6 +241,19 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<1, 2>>, sequence<2, 2>, sequence<0, 2>>{}); + else if constexpr(std::is_same_v) + // K_Lane=4, K_Thread=32 + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<2, 2>, + sequence<0, 2>>{}); + else + static_assert(false, "unsupported datatype"); } CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() @@ -245,17 +264,17 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - if constexpr(BK1 == K_Thread) + if constexpr(std::is_same_v) return 
make_static_tile_distribution( tile_distribution_encoding< // sequence, tuple, // 4 2 - sequence>, // 1 64 32 + sequence>, // 1 64 16 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, sequence<2>>{}); - else + else if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -265,6 +284,21 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<2>>, sequence<2, 2>, sequence<0, 3>>{}); + else if constexpr(std::is_same_v) + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, // 4 2 + sequence>, // 2 1 64 12 + tuple, sequence<2>>, + tuple, sequence<2>>, + sequence<2, 2>, + sequence<0, 3>>{}); + else + static_assert(false, "unsupported datatype"); } template @@ -280,22 +314,25 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile; auto&& byte_tensor_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple( - flat_n, flat_k / flat_k_per_block, number{})), + make_naive_tensor_descriptor_packed( + make_tuple(flat_n, + flat_k / flat_k_per_block, + number{})), make_tuple(make_pass_through_transform(flat_n), make_merge_transform_v3_division_mod(make_tuple( - flat_k / flat_k_per_block, number{}))), + flat_k / flat_k_per_block, + number{}))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, byte_tensor_desc); - auto&& origin_tmp = window_tmp.get_window_origin(); - return make_tile_window( - byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / BPackedSize}, - MakeMX_BFlatBytesDramTileDistribution()); + auto&& origin_tmp = window_tmp.get_window_origin(); + 
constexpr index_t test2 = BPackedSize / sizeof(BDataType); + return make_tile_window(byte_tensor_view, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / test2}, + MakeMX_BFlatBytesDramTileDistribution()); } CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() @@ -372,7 +409,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + if constexpr(!std::is_same_v) + { + return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + } + else + { + return MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + } } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); } From 11c15624e7a6595f0a8036f13254ab05fd9b23c0 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Tue, 6 Jan 2026 04:34:23 +0000 Subject: [PATCH 05/14] default mnk pass ut --- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 49 ++++++++++++++----- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 8 +-- 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index b4d1fe237b3..efa0a9bf290 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -68,24 +68,49 @@ int run_mx_flatmm_with_layouts(int argc, M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout))); ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout))); - - if(init_method == 0) - { - ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); - ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); - ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); - 
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); - } - else if(init_method == 1) + if constexpr(std::is_same_v) { - ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); + auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); + auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + std::vector random_bufA(a_buffer_bytes); + std::vector random_bufB(b_buffer_bytes); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(1, 1); + + for(size_t i = 0; i < a_buffer_bytes; ++i) + { + random_bufA[i] = static_cast(dis(gen)); + } + for(size_t i = 0; i < b_buffer_bytes; ++i) + { + random_bufB[i] = static_cast(dis(gen)); + } + memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); + memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); } else { - throw std::runtime_error("wrong! Unexpected init_method"); + if(init_method == 0) + { + ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + } + else + { + throw std::runtime_error("wrong! 
Unexpected init_method"); + } } const auto b_shuffled_host = preShuffleWeight(b_origin_host); diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 65137061ab5..15c13d30a21 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -118,8 +118,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); __builtin_amdgcn_sched_barrier(0); From 76ee1de9a116c0de7cf1fef2d73941ea07df955c Mon Sep 17 00:00:00 2001 From: ZheWang Date: Tue, 6 Jan 2026 08:43:32 +0000 Subject: [PATCH 06/14] fix a distribution --- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 12 ++++-- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 42 ++++++++++++++++--- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 1 + ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 17 +++++--- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 6 +-- 5 files changed, 61 insertions(+), 17 deletions(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index be2de38ef59..596eacd7bac 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -115,9 +115,15 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, CDEElementWise, split_k_.value, has_hot_loop_v, - tail_num_v>( - args, - ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + tail_num_v>(args, + ck_tile::stream_config{nullptr, + true, + 1, + n_warmup - n_warmup, + n_repeat - n_repeat + 1, + true, + true, + 50}); }; return (args.k_batch == 1) ?
invoke_splitk_path(std::false_type{}) : invoke_splitk_path(std::true_type{}); diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index efa0a9bf290..57aa1b78bee 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -69,24 +69,56 @@ int run_mx_flatmm_with_layouts(int argc, ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout))); if constexpr(std::is_same_v) + // if constexpr(std::is_same_v) { auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b); std::vector random_bufA(a_buffer_bytes); std::vector random_bufB(b_buffer_bytes); std::random_device rd; std::mt19937 gen(rd()); - std::uniform_int_distribution dis(1, 1); + std::uniform_int_distribution dis(1, 4); + std::uniform_int_distribution dis2(1, 1); for(size_t i = 0; i < a_buffer_bytes; ++i) { - random_bufA[i] = static_cast(dis(gen)); + // auto row= i / 192; + auto col = i % 192; + random_bufA[i] = static_cast(col / 24); + // random_bufA[i] = static_cast(dis2(gen)); } for(size_t i = 0; i < b_buffer_bytes; ++i) { - random_bufB[i] = static_cast(dis(gen)); + random_bufB[i] = static_cast(dis2(gen)); + } + memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); + memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); + } + else if constexpr(std::is_same_v) + { + auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); + auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a); + 
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b); + std::vector random_bufA(a_buffer_bytes); + std::vector random_bufB(b_buffer_bytes); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(1, 4); + std::uniform_int_distribution dis2(1, 1); + + for(size_t i = 0; i < a_buffer_bytes; ++i) + { + // auto row= i / 192; + auto col = i % 128; + random_bufA[i] = static_cast(col / 16); + // random_bufA[i] = static_cast(dis2(gen)); + } + for(size_t i = 0; i < b_buffer_bytes; ++i) + { + random_bufB[i] = static_cast(dis2(gen)); } memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index a58d71c7901..8d813d3ab70 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -458,6 +458,7 @@ struct MXFlatmmKernel : FlatmmKernel{}(a_warp_tensor(number<2>{})); + } static_for_product, number, number, @@ -1022,12 +1027,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = - load_tile_with_offset(a_warp_window_ping, - tuple, - number>{}); + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, + number>{}); } }); LastHotLoopScheduler(); diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 54be5899095..4a2223d47b9 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -247,11 +247,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy 
tile_distribution_encoding< // sequence, tuple, - sequence>, + sequence>, tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, sequence<2, 2>, - sequence<0, 2>>{}); + sequence<1, 2>>{}); else static_assert(false, "unsupported datatype"); } From 0a67f96bddcd09189d03172be0805bc1c4e91111 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Tue, 6 Jan 2026 09:50:39 +0000 Subject: [PATCH 07/14] fix --- .../mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index eac3bf6a9e8..3e15bfda7e3 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -758,7 +758,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_ping, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished @@ -847,7 +847,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_pong, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished @@ -941,7 +941,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_ping, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished @@ -985,12 +985,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = - load_tile_with_offset(a_warp_window_pong, - tuple, - number>{}); + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + 
a_warp_window_pong, + tuple, + number>{}); } }); LastHotLoopScheduler(); From 5a931cbd18a4d74782c6a54e234bfdbe56260adb Mon Sep 17 00:00:00 2001 From: ZheWang Date: Tue, 6 Jan 2026 11:47:49 +0000 Subject: [PATCH 08/14] fix bdram distr --- example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc | 14 +++++++++----- ..._flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index 57aa1b78bee..f91c0f59ceb 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -85,13 +85,15 @@ int run_mx_flatmm_with_layouts(int argc, for(size_t i = 0; i < a_buffer_bytes; ++i) { // auto row= i / 192; - auto col = i % 192; - random_bufA[i] = static_cast(col / 24); - // random_bufA[i] = static_cast(dis2(gen)); + // auto col = i % 192; + // random_bufA[i] = static_cast(col / 24); + random_bufA[i] = static_cast(dis(gen)); } for(size_t i = 0; i < b_buffer_bytes; ++i) { - random_bufB[i] = static_cast(dis2(gen)); + // auto col = i % 192; + // random_bufB[i] = static_cast(col / 24); + random_bufB[i] = static_cast(dis(gen)); } memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); @@ -118,7 +120,9 @@ int run_mx_flatmm_with_layouts(int argc, } for(size_t i = 0; i < b_buffer_bytes; ++i) { - random_bufB[i] = static_cast(dis2(gen)); + auto col = i % 128; + random_bufB[i] = static_cast(col / 16); + // random_bufB[i] = static_cast(dis2(gen)); } memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 4a2223d47b9..860acb4773b 100644 --- 
a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -289,14 +289,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tile_distribution_encoding< // sequence, tuple, // 4 2 - sequence>, // 2 1 64 12 tuple, sequence<2>>, - tuple, sequence<2>>, + tuple, sequence<1>>, sequence<2, 2>, - sequence<0, 3>>{}); + sequence<2, 3>>{}); else static_assert(false, "unsupported datatype"); } From e1531ba0904d2d046e39089f0a68d878d6428eb7 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Tue, 6 Jan 2026 13:28:17 +0000 Subject: [PATCH 09/14] update --- .../ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc | 10 +++++----- .../mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 17 ++++++++++------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index f91c0f59ceb..b3fc796286b 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -73,8 +73,8 @@ int run_mx_flatmm_with_layouts(int argc, { auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); - ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); std::vector random_bufA(a_buffer_bytes); std::vector random_bufB(b_buffer_bytes); std::random_device rd; @@ -87,13 +87,13 @@ int run_mx_flatmm_with_layouts(int argc, // auto row= i / 192; // auto col = i % 192; // random_bufA[i] = static_cast(col / 24); - random_bufA[i] = static_cast(dis(gen)); + random_bufA[i] = static_cast(dis2(gen)); } for(size_t i = 0; i < b_buffer_bytes; ++i) { // auto col = i % 192; // random_bufB[i] 
= static_cast(col / 24); - random_bufB[i] = static_cast(dis(gen)); + random_bufB[i] = static_cast(dis2(gen)); } memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); @@ -116,7 +116,7 @@ int run_mx_flatmm_with_layouts(int argc, // auto row= i / 192; auto col = i % 128; random_bufA[i] = static_cast(col / 16); - // random_bufA[i] = static_cast(dis2(gen)); + // random_bufA[i] = static_cast(dis(gen)); } for(size_t i = 0; i < b_buffer_bytes; ++i) { diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 3e15bfda7e3..861fa3dd7b8 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -628,7 +628,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { @@ -670,7 +670,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, MIterPerWarp> @@ -768,7 +768,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); HotLoopScheduler(); }; + printf("tailnum %d\n", static_cast(TailNum)); if constexpr(HasHotLoop) { index_t iCounter = (num_loop - 1) / 2; + printf("icounter %d\n", iCounter); do { main_body_implx2(); iCounter--; } while(iCounter > 0); } - // TAIL if constexpr(TailNum == TailNumber::Even) { @@ -955,7 +957,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); Last2ndHotLoopScheduler(); From 1f7fc554e366729f968541bc0ecddba0dfc79e09 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Tue, 6 Jan 2026 13:37:42 +0000 Subject: [PATCH 10/14] pass ut --- example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc | 8 ++++---- .../mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 10 ++-------- 2 files 
changed, 6 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index b3fc796286b..9e7d4e004e2 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -73,8 +73,8 @@ int run_mx_flatmm_with_layouts(int argc, { auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b); std::vector random_bufA(a_buffer_bytes); std::vector random_bufB(b_buffer_bytes); std::random_device rd; @@ -87,13 +87,13 @@ int run_mx_flatmm_with_layouts(int argc, // auto row= i / 192; // auto col = i % 192; // random_bufA[i] = static_cast(col / 24); - random_bufA[i] = static_cast(dis2(gen)); + random_bufA[i] = static_cast(dis(gen)); } for(size_t i = 0; i < b_buffer_bytes; ++i) { // auto col = i % 192; // random_bufB[i] = static_cast(col / 24); - random_bufB[i] = static_cast(dis2(gen)); + random_bufB[i] = static_cast(dis(gen)); } memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 861fa3dd7b8..23d7a9fca99 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -780,7 +780,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); HotLoopScheduler(); @@ -874,11 +875,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : 
FlatmmPipelineAGmemBGmemCRegV1(TailNum)); if constexpr(HasHotLoop) { index_t iCounter = (num_loop - 1) / 2; - printf("icounter %d\n", iCounter); do { main_body_implx2(); @@ -1001,11 +1000,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}(a_warp_tensor(number<2>{})); - } static_for_product, number, number, From 13749845d93e2c05959bf5cd9de2ab373f6531d7 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Wed, 7 Jan 2026 02:20:20 +0000 Subject: [PATCH 11/14] improve perf --- example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 12 +++--------- example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp | 2 +- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 2 +- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 596eacd7bac..be2de38ef59 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -115,15 +115,9 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, CDEElementWise, split_k_.value, has_hot_loop_v, - tail_num_v>(args, - ck_tile::stream_config{nullptr, - true, - 1, - n_warmup - n_warmup, - n_repeat - n_repeat + 1, - true, - true, - 50}); + tail_num_v>( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); }; return (args.k_batch == 1) ? 
invoke_splitk_path(std::false_type{}) : invoke_splitk_path(std::true_type{}); diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index f5f36840d36..d4922bb44c7 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -47,7 +47,7 @@ struct MXfp4_FlatmmConfig16 struct MXfp6_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 512; + static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 860acb4773b..ee5ea825eab 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -292,7 +292,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence>, // 2 1 64 12 + DWORDx3>>, // 64 1 2 12 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2, 2>, From 2db47f7ce6e77d50fe0920c20a26926f23c89af6 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Thu, 15 Jan 2026 02:35:37 +0000 Subject: [PATCH 12/14] update --- example/ck_tile/18_flatmm/CMakeLists.txt | 2 +- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 14 +++-- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 40 +------------ include/ck_tile/core/numeric/vector_type.hpp | 58 ++++++++++++------- include/ck_tile/core/tensor/buffer_view.hpp | 27 ++++++--- 5 files changed, 66 insertions(+), 75 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 7451ee25b02..c4e541e31e9 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -21,7 
+21,7 @@ if(has_supported_gpu) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") - + list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-gnu-line-marker) add_executable(tile_example_flatmm_basic flatmm_basic.cpp) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index be2de38ef59..1141717545c 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -296,12 +296,14 @@ int run_mx_flatmm_example(int argc, char* argv[]) } else if(mx_prec == "fp6" || mx_prec == "fp6xfp6") { - // throw std::runtime_error("fp6xfp6 is not supported."); - return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + if(persistent_opt == 0) + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + else + throw std::runtime_error("Only support non-persistent kernel now!"); } else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") { diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index 9e7d4e004e2..54c23e22662 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -69,7 +69,6 @@ int run_mx_flatmm_with_layouts(int argc, ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout))); if constexpr(std::is_same_v) - // if constexpr(std::is_same_v) { auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); @@ -80,50 +79,13 @@ int run_mx_flatmm_with_layouts(int argc, std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution dis(1, 4); - 
std::uniform_int_distribution dis2(1, 1); for(size_t i = 0; i < a_buffer_bytes; ++i) - { - // auto row= i / 192; - // auto col = i % 192; - // random_bufA[i] = static_cast(col / 24); random_bufA[i] = static_cast(dis(gen)); - } + for(size_t i = 0; i < b_buffer_bytes; ++i) - { - // auto col = i % 192; - // random_bufB[i] = static_cast(col / 24); random_bufB[i] = static_cast(dis(gen)); - } - memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); - memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); - } - else if constexpr(std::is_same_v) - { - auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); - auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); - ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b); - std::vector random_bufA(a_buffer_bytes); - std::vector random_bufB(b_buffer_bytes); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(1, 4); - std::uniform_int_distribution dis2(1, 1); - for(size_t i = 0; i < a_buffer_bytes; ++i) - { - // auto row= i / 192; - auto col = i % 128; - random_bufA[i] = static_cast(col / 16); - // random_bufA[i] = static_cast(dis(gen)); - } - for(size_t i = 0; i < b_buffer_bytes; ++i) - { - auto col = i % 128; - random_bufB[i] = static_cast(col / 16); - // random_bufB[i] = static_cast(dis2(gen)); - } memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); } diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index a572d40a622..3c2f90466b9 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -46,26 +46,10 @@ struct ext_vector -struct ext_vector -{ - static constexpr index_t N = 12; - using value_type = int32x3_t; - using type = int32x3_t; -}; - -template -struct ext_vector -{ - static constexpr index_t N = N_; - using 
value_type = pk_fp6_t; - using type = pk_fp6_t; // this is danguous -}; +// struct int32x3_t +// { +// int data[3]; +// }; template struct ext_vector::type>>> @@ -180,6 +164,40 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16))); using int32x32_t = int32_t __attribute__((ext_vector_type(32))); using int32x64_t = int32_t __attribute__((ext_vector_type(64))); +struct int32x3_tt +{ + int32_t data[3]; +}; + +struct int32x6_tt +{ + int32_t data[6]; +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 12; + using value_type = int32x3_tt; + using type = int32x3_tt; +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 1; + using value_type = int32x3_tt; + using type = int32x3_tt; // this is danguous +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 2; + using value_type = int32x6_tt; + using type = int32x6_tt; // this is danguous +}; + // u32 // using uint32_t = ... using uint32x2_t = uint32_t __attribute__((ext_vector_type(2))); diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 48ccadc2159..d98f362f137 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -303,7 +303,8 @@ struct buffer_view,t_per_x>{}; if constexpr(use_amd_buffer_addressing) { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; @@ -826,14 +827,22 @@ struct buffer_view, - ext_vector_t>::scalar_type, - scalar_per_t_vector * scalar_per_x_vector>>::type; - // using buf_t = ushort __attribute__((ext_vector_type(8))); - auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); - return bit_cast(rtn); + if constexpr(load_elts == 12 && sizeof(typename X::value_type) == 1) + { + auto rtn = reinterpret_cast(p_data_) + (i + linear_offset) / 4; + struct + { + int32_t x, y, z; + } tmp = {rtn[0], rtn[1], rtn[2]}; + return bit_cast(tmp); + } + else + { + using buf_t = 
ext_vector_t>::scalar_type, + scalar_per_t_vector * scalar_per_x_vector>; + auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); + return bit_cast(rtn); + } #endif } else From 6e5fa193f4088cc5352bba2356522f3e8e9922e9 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Mon, 19 Jan 2026 03:52:19 +0000 Subject: [PATCH 13/14] clean code --- example/ck_tile/18_flatmm/CMakeLists.txt | 1 - include/ck_tile/core/numeric/vector_type.hpp | 5 ----- include/ck_tile/core/tensor/buffer_view.hpp | 2 -- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 1 - ...tmm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 15 +++++++++------ test/ck_tile/memory_copy/test_copy.cpp | 18 +++++++++--------- test/ck_tile/memory_copy/test_copy.hpp | 10 +++++----- 7 files changed, 23 insertions(+), 29 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index c4e541e31e9..d77e3c93229 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -21,7 +21,6 @@ if(has_supported_gpu) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") - list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-gnu-line-marker) add_executable(tile_example_flatmm_basic flatmm_basic.cpp) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 3c2f90466b9..29178e956e5 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -46,11 +46,6 @@ struct ext_vector struct ext_vector::type>>> { diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index d98f362f137..59f82939b99 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ 
b/include/ck_tile/core/tensor/buffer_view.hpp @@ -303,8 +303,6 @@ struct buffer_view,t_per_x>{}; if constexpr(use_amd_buffer_addressing) { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index 8d813d3ab70..a58d71c7901 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -458,7 +458,6 @@ struct MXFlatmmKernel : FlatmmKernel(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, byte_tensor_desc); - auto&& origin_tmp = window_tmp.get_window_origin(); - constexpr index_t test2 = BPackedSize / sizeof(BDataType); - return make_tile_window(byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / test2}, - MakeMX_BFlatBytesDramTileDistribution()); + auto&& origin_tmp = window_tmp.get_window_origin(); + auto origin_n = origin_tmp[0]; + auto origin_k = static_cast(origin_tmp[1] * sizeof(BDataType) / BPackedSize); + return make_tile_window( + byte_tensor_view, + make_tuple(number{}, + number{}), + {origin_n, origin_k}, + MakeMX_BFlatBytesDramTileDistribution()); } CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() diff --git a/test/ck_tile/memory_copy/test_copy.cpp b/test/ck_tile/memory_copy/test_copy.cpp index 6b4a827f472..208b92e7026 100644 --- a/test/ck_tile/memory_copy/test_copy.cpp +++ b/test/ck_tile/memory_copy/test_copy.cpp @@ -52,8 +52,8 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam ? 1 : 0; + constexpr auto dword_bytes = 4; + const ck_tile::index_t CpyCfg = std::is_same_v ? 
1 : 0; ck_tile::HostTensor x_host({m, n}); ck_tile::HostTensor y_host_dev({m, n}); @@ -75,24 +75,24 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam, ck_tile::sequence<1, 24>>; using BlockWaves = ck_tile::sequence<2, 1>; - using BlockTile = type_at::type; - using WaveTile = type_at::type; - using Vector = type_at::type; + using BlockTile = type_at::type; + using WaveTile = type_at::type; + using Vector = type_at::type; ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{})); using Shape = ck_tile::TileCopyShape; - using Problem = ck_tile::TileCopyProblem; + using Problem = ck_tile::TileCopyProblem; using Kernel = ck_tile::TileCopy; constexpr ck_tile::index_t kBlockSize = 128; constexpr ck_tile::index_t kBlockPerCu = 1; // when copy fp6x16 buffer, tread it as int8 buffer and recompute n-dim size. ck_tile::index_t cpy_n = - cpy_cfg == 1 ? n * sizeof(DataType) / - (sizeof(int8_t) * ck_tile::numeric_traits::PackedSize) - : n; + CpyCfg == 1 ? 
n * sizeof(DataType) / + (sizeof(int8_t) * ck_tile::numeric_traits::PackedSize) + : n; auto ms = launch_kernel( ck_tile::stream_config{nullptr, true}, diff --git a/test/ck_tile/memory_copy/test_copy.hpp b/test/ck_tile/memory_copy/test_copy.hpp index 0583b0f3d86..2ce4982a04d 100644 --- a/test/ck_tile/memory_copy/test_copy.hpp +++ b/test/ck_tile/memory_copy/test_copy.hpp @@ -51,7 +51,7 @@ struct TileCopyShape "Inconsistent wave group size!"); }; -template +template struct TileCopyProblem { using XDataType = remove_cvref_t; @@ -59,7 +59,7 @@ struct TileCopyProblem static constexpr bool AsyncCopy = AsyncCopy_; // 0: copy 1, 2, 4 bytes data type // 1: copy dwordx3 bytes data type - static constexpr int cpy_cfg = cpy_cfg_; + static constexpr int CpyCfg = CpyCfg_; }; template @@ -70,7 +70,7 @@ struct TileCopy static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr bool AsyncCopy = Problem::AsyncCopy; - static constexpr int cpy_cfg = Problem::cpy_cfg; + static constexpr int CpyCfg = Problem::CpyCfg; template CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() @@ -312,11 +312,11 @@ struct TileCopy CK_TILE_DEVICE void operator()(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const { - if constexpr(cpy_cfg == 1) + if constexpr(CpyCfg == 1) { run_dwordx3_cpy(p_x, p_y, M, N, warp_id); } - else if constexpr(cpy_cfg == 0) + else if constexpr(CpyCfg == 0) { run_normal_cpy(p_x, p_y, M, N, warp_id); } From 41743df85deab167c71f577803a88ee1bb1f7363 Mon Sep 17 00:00:00 2001 From: illsilin_amdeng Date: Mon, 19 Jan 2026 08:02:00 -0800 Subject: [PATCH 14/14] fix clang format --- example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in index a7f65a12eae..e6d612f0d6e 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in +++ 
b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in @@ -19,7 +19,7 @@ using FP4 = ck_tile::pk_fp4_t; using FP8 = ck_tile::fp8_t; -using FP6 = ck_tile::pk_fp6x16_t; +using FP6 = ck_tile::pk_fp6x16_t; using FP16 = ck_tile::fp16_t; using BF16 = ck_tile::bf16_t;