1 change: 0 additions & 1 deletion example/ck_tile/18_flatmm/CMakeLists.txt
@@ -21,7 +21,6 @@ if(has_supported_gpu)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")

add_executable(tile_example_flatmm_basic flatmm_basic.cpp)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

18 changes: 13 additions & 5 deletions example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
@@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
const int K = src_lengths[0];
const int N = src_lengths[1];
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
-    int KPack = 16 * packed_size; // fp4:32 or fp8:16
-    int NLane = N_Warp_Tile;
-    int KLane = 64 / NLane;
-    int K0 = K / (KLane * KPack);
+    int KPack =
+        std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
+    int NLane = N_Warp_Tile;
+    int KLane = 64 / NLane;
+    int K0 = K / (KLane * KPack);

ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));

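Note on the KPack arithmetic above: KPack counts elements per K-pack, so it depends on how many values the type packs. A standalone sketch of how the numbers work out, assuming PackedSize is 1 for fp8 and 2 for pk_fp4 (as the old comment implies), with illustrative N_Warp_Tile = 16 and K = 256; this is not CK code:

#include <cstdio>

// Standalone sketch (not CK code): the KPack / K0 values implied above.
int main()
{
    const int K     = 256;        // illustrative K extent
    const int NLane = 16;         // N_Warp_Tile
    const int KLane = 64 / NLane; // 4 lanes along K in a 64-lane wave

    const int kpack_fp8 = 16 * 1; // 16 elements per pack
    const int kpack_fp4 = 16 * 2; // 32 elements per pack
    const int kpack_fp6 = 32;     // special-cased for pk_fp6x16_t

    std::printf("fp8 K0 = %d\n", K / (KLane * kpack_fp8)); // 256 / 64  = 4
    std::printf("fp4 K0 = %d\n", K / (KLane * kpack_fp4)); // 256 / 128 = 2
    std::printf("fp6 K0 = %d\n", K / (KLane * kpack_fp6)); // 256 / 128 = 2
    return 0;
}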
@@ -295,7 +296,14 @@ int run_mx_flatmm_example(int argc, char* argv[])
}
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
{
throw std::runtime_error("fp6xfp6 is not supported.");
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp6x16_t,
ck_tile::pk_fp6x16_t,
ck_tile::fp16_t,
MXfp6_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
32 changes: 32 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
@@ -44,6 +44,38 @@ struct MXfp4_FlatmmConfig16
static constexpr bool TiledMMAPermuteN = false;
};

struct MXfp6_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;

static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;

static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;

static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;

static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};

struct MXfp8_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
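For orientation, the new MXfp6_FlatmmConfig16 decomposes a 128x256 output tile across one M-warp and four N-warps of 16x16x128 warp tiles. A standalone compile-time check of that arithmetic (illustrative, not part of the PR):

// Standalone sketch: tile decomposition implied by MXfp6_FlatmmConfig16.
constexpr int N_Tile = 256, N_Warp = 4, N_Warp_Tile = 16;
constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; // 256 / 16 / 4 = 4
static_assert(N_Repeat == 4, "each N-warp repeats its 16-wide tile 4 times");

constexpr int M_Tile = 128, M_Warp = 1, M_Warp_Tile = 16;
static_assert(M_Tile % (M_Warp * M_Warp_Tile) == 0, "M decomposes evenly");

constexpr int K_Tile = 256, K_Warp = 1, K_Warp_Tile = 128;
static_assert(K_Tile % (K_Warp * K_Warp_Tile) == 0, "K decomposes evenly");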
3 changes: 2 additions & 1 deletion example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake
@@ -8,13 +8,14 @@ function(mx_flatmm_instance_generate FILE_LIST)
set(C_LAYOUT ROW)
set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16")
+    set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16")
set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16")

# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
-        foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8)
+        foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
@@ -19,6 +19,7 @@

using FP4 = ck_tile::pk_fp4_t;
using FP8 = ck_tile::fp8_t;
+using FP6  = ck_tile::pk_fp6x16_t;
using FP16 = ck_tile::fp16_t;
using BF16 = ck_tile::bf16_t;

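The new FP6 alias is a 16-element pack. Assuming pk_fp6x16_t holds sixteen 6-bit values, as the name suggests, one pack spans 96 bits, i.e. 12 bytes, which is what motivates the N == 12 buffer-addressing cases added later in this diff. A standalone sketch of the arithmetic:

// Sketch (assumption: pk_fp6x16_t packs sixteen 6-bit values).
constexpr int bits_per_value  = 6;
constexpr int values_per_pack = 16;
constexpr int pack_bytes = values_per_pack * bits_per_value / 8; // 96 bits
static_assert(pack_bytes == 12, "one fp6x16 pack spans exactly three dwords");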
51 changes: 37 additions & 14 deletions example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
@@ -68,24 +68,47 @@ int run_mx_flatmm_with_layouts(int argc,
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));

-    if(init_method == 0)
-    {
-        ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
-        ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
-        ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
-        ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
-    }
-    else if(init_method == 1)
-    {
-        ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
-        ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
-        ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
-        ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
-    }
-    else
-    {
-        throw std::runtime_error("wrong! Unexpected init_method");
-    }
+    if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp6x16_t>)
+    {
+        auto a_buffer_bytes = a_host.get_element_space_size_in_bytes();
+        auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes();
+        ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a);
+        ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b);
+        std::vector<int8_t> random_bufA(a_buffer_bytes);
+        std::vector<int8_t> random_bufB(b_buffer_bytes);
+        std::random_device rd;
+        std::mt19937 gen(rd());
+        std::uniform_int_distribution<int> dis(1, 4);
+
+        for(size_t i = 0; i < a_buffer_bytes; ++i)
+            random_bufA[i] = static_cast<int8_t>(dis(gen));
+
+        for(size_t i = 0; i < b_buffer_bytes; ++i)
+            random_bufB[i] = static_cast<int8_t>(dis(gen));
+
+        memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes);
+        memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes);
+    }
+    else
+    {
+        if(init_method == 0)
+        {
+            ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
+            ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
+            ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
+            ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
+        }
+        else if(init_method == 1)
+        {
+            ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
+            ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
+            ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
+            ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
+        }
+        else
+        {
+            throw std::runtime_error("wrong! Unexpected init_method");
+        }
+    }

const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);
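The fp6 branch above initializes A and B at byte granularity because the packed fp6 elements have no per-value host representation for FillUniformDistribution to target. The same pattern as a generic helper, using only calls that appear in the diff; the helper name is hypothetical, a sketch rather than a CK utility:

#include <cstdint>
#include <cstring>
#include <random>
#include <vector>

// Hypothetical helper mirroring the byte-level init used for pk_fp6x16_t above.
template <typename HostTensor>
void fill_random_bytes(HostTensor& tensor, int lo = 1, int hi = 4)
{
    const std::size_t nbytes = tensor.get_element_space_size_in_bytes();
    std::vector<int8_t> buf(nbytes);
    std::mt19937 gen{std::random_device{}()};
    std::uniform_int_distribution<int> dis(lo, hi);
    for(auto& b : buf)
        b = static_cast<int8_t>(dis(gen)); // small positive bit patterns
    std::memcpy(tensor.data(), buf.data(), nbytes);
}

// Usage as in the branch above: fill_random_bytes(a_host); fill_random_bytes(b_origin_host);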
2 changes: 1 addition & 1 deletion include/ck/utility/amd_xdlops.hpp
@@ -639,7 +639,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
-            2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
+            2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 ; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
0, // OPSEL
0,
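The cbsz and blgp immediates in this intrinsic select the source-operand formats of the f8f6f4 MFMA; per the comment above, value 2 selects an FP6 encoding for both operands here. Spelled out as an enum for readability (illustrative naming, and the cbsz-for-A / blgp-for-B split is my reading of the ISA, not something stated in this diff):

// Format selector values from the comment above.
enum class MfmaF8F6F4Format : int
{
    fp8_e4m3 = 0,
    fp8_e5m2 = 1,
    fp6_e2m3 = 2, // the value passed twice in this FP6 path
    fp6_e3m2 = 3,
    fp4_e2m1 = 4,
};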
1 change: 1 addition & 0 deletions include/ck_tile/core.hpp
@@ -54,6 +54,7 @@
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/pk_fp6.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
11 changes: 10 additions & 1 deletion include/ck_tile/core/arch/amd_buffer_addressing.hpp
@@ -1417,7 +1417,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
-    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
+    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");

using rtn_type = thread_buffer<int8_t, N>;
@@ -1457,6 +1457,15 @@

return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 12)
{
auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));

return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 16)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
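The N == 12 case added here exists because a 12-byte thread buffer (one pk_fp6x16_t pack) is exactly three dwords, matching a v3i32 raw buffer load. A host-side stand-in for the size relationship (illustrative; the device path calls llvm.amdgcn.raw.buffer.load.v3i32):

#include <cstdint>
#include <cstring>

int main()
{
    int8_t  bytes[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    int32_t dwords[3];
    static_assert(sizeof(bytes) == sizeof(dwords), "12 bytes == 3 dwords");
    std::memcpy(dwords, bytes, sizeof(bytes)); // device: raw.buffer.load.v3i32
    return 0;
}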
56 changes: 50 additions & 6 deletions include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
@@ -1129,6 +1129,23 @@ llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");

CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x3_(int32x3_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v3i32");

CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x3(
dwordx3_union vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
{
int32x3_t v_reg;
v_reg[0] = vdata.as_i32[0];
v_reg[1] = vdata.as_i32[1];
v_reg[2] = vdata.as_i32[2];
llvm_amdgcn_raw_buffer_store_i32x3_(v_reg, rsrc, voffset, soffset, glc_slc);
};

CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
@@ -1285,7 +1302,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
-    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
+    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 | N == 16 || N == 32 || N == 64,
Copilot AI commented on Jan 19, 2026:

Incorrect use of bitwise OR operator | instead of logical OR ||. This should be N == 12 || N == 16 to properly check if N equals 12 or 16.

Suggested change:
-    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 | N == 16 || N == 32 || N == 64,
+    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");

using rtn_type = thread_buffer<int8_t, N>;
@@ -1325,6 +1342,18 @@

return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 12)
{
auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
dwordx3_union ret;
ret.as_i32[0] = tmp[0];
ret.as_i32[1] = tmp[1];
ret.as_i32[2] = tmp[2];
return bit_cast<rtn_type>(ret);
}
else if constexpr(N == 16)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
@@ -1406,15 +1435,19 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        (std::is_same<T, int8_t>::value &&
+         (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) ||
+        (std::is_same<T, uint8_t>::value &&
+         (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) ||
(std::is_same<T, e8m0_bexp_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_fp4_raw_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
-             (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
+             (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))) ||
+        (std::is_same<T, pk_fp6x16_t>::value && (N == 1)),
"wrong! not implemented");

using rtn_type = thread_buffer<T, N>;
@@ -1745,7 +1778,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
-    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
+    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");

if constexpr(N == 1)
@@ -1781,6 +1814,14 @@
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 12)
{
llvm_amdgcn_raw_buffer_store_i32x3(bit_cast<dwordx3_union>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 16)
{
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
@@ -1854,10 +1895,13 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        (std::is_same<T, int8_t>::value &&
+         (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) ||
(std::is_same<T, uint16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (std::is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
+        (std::is_same<T, uint8_t>::value &&
+         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        std::is_same<T, pk_fp6x16_t>::value && (N == 1),
"wrong! not implemented");

if constexpr(std::is_same<T, float>::value) // fp32
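On the store side, the new N == 12 branch routes the 12-byte thread buffer through dwordx3_union before calling the v3i32 intrinsic. A host-side sketch of that round-trip; the union layout is inferred from the as_i32[0..2] accesses in this diff, so treat it as an assumption:

#include <cstdint>
#include <cstring>

// Assumed layout, inferred from the as_i32 indexing above (not the CK definition).
union Dwordx3Sketch
{
    int32_t as_i32[3];
    int8_t  as_i8[12];
};

int main()
{
    Dwordx3Sketch u{};
    for(int i = 0; i < 12; ++i)
        u.as_i8[i] = static_cast<int8_t>(i);   // 12 bytes in ...
    int32_t regs[3];
    std::memcpy(regs, u.as_i32, sizeof(regs)); // ... three dwords out
    return regs[0] == 0x03020100 ? 0 : 1;      // little-endian packing check
}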