From 03d845aaabb0a105c07cbc26564fda719f90ac8c Mon Sep 17 00:00:00 2001 From: yingluAMD Date: Wed, 7 Jan 2026 13:58:46 +0800 Subject: [PATCH 1/7] ck_tile:tf32:add fp32 example --- example/ck_tile/03_gemm/gemm_basic.cpp | 7 +++++++ example/ck_tile/03_gemm/gemm_basic_invoker.hpp | 14 ++++++++------ example/ck_tile/03_gemm/gemm_utils.hpp | 11 ++++++++++- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index e30ae8319f2..c9cfeaf37c8 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -9,6 +9,8 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) { + arg_parser.print(); + std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); @@ -41,6 +43,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) return run_gemm_example_prec_type( a_layout, b_layout, arg_parser); } + else if(data_type == "fp32") + { + return run_gemm_example_prec_type( + a_layout, b_layout, arg_parser); + } else if(data_type == "fp8") { return run_gemm_example_prec_type; + // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t M_Tile = is_fp32 ? 128 : 256; + constexpr ck_tile::index_t N_Tile = is_fp32 ? 128 : 256; constexpr ck_tile::index_t K_Tile = 64; #if CK_TILE_USE_WMMA @@ -38,12 +40,12 @@ struct BasicInvoker constexpr ck_tile::index_t N_Warp_Tile = 16; constexpr ck_tile::index_t K_Warp_Tile = 16; #else - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t M_Warp = is_fp32 ? 4 : 2; + constexpr ck_tile::index_t N_Warp = is_fp32 ? 
4 : 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t M_Warp_Tile = is_fp32 ? 16 : 32; + constexpr ck_tile::index_t N_Warp_Tile = is_fp32 ? 16 : 32; constexpr ck_tile::index_t K_Warp_Tile = 16; #endif diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 8eff0e7469f..0094ddf250a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -302,6 +302,15 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill
 struct GemmTypeConfig;
 
+template <>
+struct GemmTypeConfig
+{
+    using ADataType   = float;
+    using BDataType   = float;
+    using AccDataType = float;
+    using CDataType   = float;
+};
+
 template <>
 struct GemmTypeConfig
 {
@@ -446,7 +455,7 @@ inline auto create_args()
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
+        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

From 2ae90340052464853791051513cdeaea3db3f81b Mon Sep 17 00:00:00 2001
From: yingluAMD 
Date: Wed, 7 Jan 2026 17:26:35 +0800
Subject: [PATCH 2/7] add tf32 example on gfx942

---
 example/ck_tile/03_gemm/gemm_basic.cpp        |   9 ++
 .../ck_tile/03_gemm/gemm_basic_invoker.hpp    |  31 +++--
 example/ck_tile/03_gemm/gemm_utils.hpp        |  11 +-
 example/ck_tile/03_gemm/run_gemm_example.inc  |  56 ++++----
 .../03_gemm/run_gemm_example_common.hpp       |  11 +-
 include/ck_tile/core/numeric/numeric.hpp      |  21 +++
 include/ck_tile/core/numeric/type_convert.hpp |  27 ++++
 .../ck_tile/host/reference/reference_gemm.hpp | 115 +++++++++++-----
 ...emm_universal_pipeline_ag_bg_cr_policy.hpp |  41 +++---
 include/ck_tile/ops/gemm/warp/warp_gemm.hpp   |  20 +++
 .../warp/warp_gemm_attribute_mfma_impl.hpp    | 128 ++++++++++++++++++
 .../ops/gemm/warp/warp_gemm_dispatcher.hpp    |   5 +
 .../ops/gemm/gemm_universal/gemm_common.hpp   |   6 +
 13 files changed, 391 insertions(+), 90 deletions(-)

diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp
index c9cfeaf37c8..2ca76623134 100644
--- a/example/ck_tile/03_gemm/gemm_basic.cpp
+++ b/example/ck_tile/03_gemm/gemm_basic.cpp
@@ -48,6 +48,15 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
         return run_gemm_example_prec_type(
             a_layout, b_layout, arg_parser);
     }
+    else if(data_type == "tf32")
+    {
+        return run_gemm_example_prec_type(a_layout, b_layout, arg_parser);
+    }
     else if(data_type == "fp8")
     {
         return run_gemm_example_prec_type
+              typename CDEElementWise,
+              typename ComputeDataType = ADataType>
     static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
     {
         if constexpr(Persistent)
@@ -24,11 +25,11 @@ struct BasicInvoker
             std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
         }
 
-        constexpr bool is_fp32 = std::is_same_v;
+        constexpr bool is_fp32_input = std::is_same_v;
 
         // This part comes from the Codegen
-        constexpr ck_tile::index_t M_Tile = is_fp32 ? 128 : 256;
-        constexpr ck_tile::index_t N_Tile = is_fp32 ? 128 : 256;
+        constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256;
+        constexpr ck_tile::index_t N_Tile = is_fp32_input ? 128 : 256;
         constexpr ck_tile::index_t K_Tile = 64;
 
 #if CK_TILE_USE_WMMA
@@ -40,12 +41,12 @@ struct BasicInvoker
         constexpr ck_tile::index_t N_Warp_Tile = 16;
         constexpr ck_tile::index_t K_Warp_Tile = 16;
 #else
-        constexpr ck_tile::index_t M_Warp = is_fp32 ? 4 : 2;
-        constexpr ck_tile::index_t N_Warp = is_fp32 ? 4 : 2;
+        constexpr ck_tile::index_t M_Warp = is_fp32_input ? 4 : 2;
+        constexpr ck_tile::index_t N_Warp = is_fp32_input ? 4 : 2;
         constexpr ck_tile::index_t K_Warp = 1;
 
-        constexpr ck_tile::index_t M_Warp_Tile = is_fp32 ? 16 : 32;
-        constexpr ck_tile::index_t N_Warp_Tile = is_fp32 ? 16 : 32;
+        constexpr ck_tile::index_t M_Warp_Tile = is_fp32_input ? 16 : 32;
+        constexpr ck_tile::index_t N_Warp_Tile = is_fp32_input ? 16 : 32;
         constexpr ck_tile::index_t K_Warp_Tile = 16;
 #endif
 
@@ -63,11 +64,15 @@ struct BasicInvoker
                                                           BLayout,
                                                           CLayout>;
 
-        using CodegenPipelineProblem = ck_tile::GemmPipelineProblem;
+        using CodegenPipelineProblem =
+            ck_tile::GemmPipelineProblem;
 
         using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1;
 
diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp
index 0094ddf250a..3b8ecde15e9 100644
--- a/example/ck_tile/03_gemm/gemm_utils.hpp
+++ b/example/ck_tile/03_gemm/gemm_utils.hpp
@@ -311,6 +311,15 @@ struct GemmTypeConfig
     using CDataType   = float;
 };
 
+template <>
+struct GemmTypeConfig
+{
+    using ADataType   = float;
+    using BDataType   = float;
+    using AccDataType = float;
+    using CDataType   = float;
+};
+
 template <>
 struct GemmTypeConfig
 {
@@ -455,7 +464,7 @@ inline auto create_args()
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t")
+        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index 78f3a9b0b3f..cb55a6bce19 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -109,7 +109,8 @@ template 
+          typename ComputeDataType = ADataType,
+          typename CDEElementWise  = ck_tile::element_wise::PassThrough>
 float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                   ck_tile::DeviceMem& b_k_n_dev_buf,
                   ck_tile::DeviceMem& c_m_n_dev_buf,
@@ -151,7 +152,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                                           DsLayout,
                                           CLayout,
                                           true,
-                                          CDEElementWise>(
+                                          CDEElementWise,
+                                          ComputeDataType>(
             args,
             ck_tile::stream_config{
                 nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
@@ -169,7 +171,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                                           DsLayout,
                                           CLayout,
                                           false,
-                                          CDEElementWise>(
+                                          CDEElementWise,
+                                          ComputeDataType>(
             args,
             ck_tile::stream_config{
                 nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
@@ -209,11 +212,12 @@ std::tuple inline parse_ge
 template 
+          typename BDataType       = ADataType,
+          typename CDataType       = ADataType,
+          typename ALayout         = ck_tile::tensor_layout::gemm::RowMajor,
+          typename BLayout         = ck_tile::tensor_layout::gemm::ColumnMajor,
+          typename CLayout         = ck_tile::tensor_layout::gemm::RowMajor,
+          typename ComputeDataType = ADataType>
 int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                   const ALayout a_layout                  = ALayout{},
                                   const BLayout b_layout                  = BLayout{},
@@ -349,21 +353,22 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                  ALayout,
                                  BLayout,
                                  ck_tile::tuple<>,
-                                 CLayout>(a_m_k_dev_buf,
-                                          b_k_n_dev_buf,
-                                          c_m_n_dev_buf,
-                                          M,
-                                          N,
-                                          K,
-                                          stride_A,
-                                          stride_B,
-                                          stride_C,
-                                          kbatch,
-                                          n_warmup,
-                                          n_repeat,
-                                          persistent,
-                                          flush_cache,
-                                          rotating_count);
+                                 CLayout,
+                                 ComputeDataType>(a_m_k_dev_buf,
+                                                  b_k_n_dev_buf,
+                                                  c_m_n_dev_buf,
+                                                  M,
+                                                  N,
+                                                  K,
+                                                  stride_A,
+                                                  stride_B,
+                                                  stride_C,
+                                                  kbatch,
+                                                  n_warmup,
+                                                  n_repeat,
+                                                  persistent,
+                                                  flush_cache,
+                                                  rotating_count);
 
     c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
 
@@ -393,7 +398,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
 
     if(arg_parser.get_int("v") == 1)
     {
-        ck_tile::reference_gemm(
+        ck_tile::reference_gemm(
             a_m_k, b_k_n, c_m_n_ref);
         const float max_accumulated_value =
             *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
@@ -427,7 +432,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                     CDataType,
                                     ALayout,
                                     BLayout,
-                                    CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
+                                    CLayout,
+                                    ComputeDataType>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
 
         c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
 
diff --git a/example/ck_tile/03_gemm/run_gemm_example_common.hpp b/example/ck_tile/03_gemm/run_gemm_example_common.hpp
index e6a1c626e5a..e65328248ac 100644
--- a/example/ck_tile/03_gemm/run_gemm_example_common.hpp
+++ b/example/ck_tile/03_gemm/run_gemm_example_common.hpp
@@ -6,8 +6,9 @@
 template 
+          typename BPrecType       = APrecType,
+          typename CPrecType       = APrecType,
+          typename ComputeDataType = APrecType>
 int run_gemm_example_prec_type(std::string a_layout,
                                std::string b_layout,
                                ck_tile::ArgParser& arg_parser)
@@ -54,7 +55,11 @@ int run_gemm_example_prec_type(std::string a_layout,
                                                      Invoker,
                                                      APrecType,
                                                      BPrecType,
-                                                     CPrecType>(
+                                                     CPrecType,
+                                                     decltype(a_layout_type),
+                                                     decltype(b_layout_type),
+                                                     Row,
+                                                     ComputeDataType>(
                     arg_parser, a_layout_type, b_layout_type, Row{});
             }
         },
diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp
index b2bd6286851..e581cae0ab9 100644
--- a/include/ck_tile/core/numeric/numeric.hpp
+++ b/include/ck_tile/core/numeric/numeric.hpp
@@ -9,6 +9,8 @@
 
 namespace ck_tile {
 
+using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits
+
 // this struct has the information of
 // 1. limit of a certain type, simliar to std::numeric_limits
 // 2. some pre-defined value, zero, one...
@@ -101,6 +103,25 @@ struct numeric_traits
     using bitwise_type                  = uint32_t;
 };
 
+template <>
+struct numeric_traits
+{
+    static constexpr int exp            = 8;
+    static constexpr int mant           = 10;
+    static constexpr int bias           = 127;
+    static constexpr uint32_t nan_mask  = 0x7F800000;
+    static constexpr uint32_t head_mask = 0xFF800000;
+    static constexpr uint32_t mant_mask = 0x7FFFFF;
+    static constexpr uint32_t exp_mask  = 0xFF;
+    static constexpr uint32_t abs_mask  = 0x7FFFFFFF;
+    static constexpr uint32_t Inf       = 0x7F800000;
+    static constexpr uint32_t NegInf    = 0xFF800000;
+    static constexpr uint32_t NaN       = 0x7F800001;
+    static constexpr uint32_t Neg0      = 0x80000000;
+    static constexpr int PackedSize     = 1;
+    using bitwise_type                  = uint32_t;
+};
+
 } // namespace ck_tile
 
 #define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)                                       \
diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp
index deaa9e0bd90..b1a9ce6e58d 100644
--- a/include/ck_tile/core/numeric/type_convert.hpp
+++ b/include/ck_tile/core/numeric/type_convert.hpp
@@ -57,6 +57,33 @@ CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
 CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
 CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
 
+enum class tf32_rounding_mode
+{
+    truncate = 0,
+    standard = 1, // RTNE
+};
+
+template 
+CK_TILE_HOST_DEVICE constexpr float float_to_tf32(float x)
+{
+    uint32_t i = bit_cast(x);
+    if constexpr(rounding == tf32_rounding_mode::standard)
+    {
+        if((i & 0x7f800000) != 0x7f800000)
+        {
+            i += 0xfff + ((i >> 13) & 1);
+        }
+    }
+    i &= 0xFFFFE000u;
+    return bit_cast(i);
+}
+
+template , bool> = false>
+CK_TILE_HOST_DEVICE constexpr float type_convert(float x)
+{
+    return float_to_tf32(x);
+}
+
 CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
 CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
 CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp
index 9ad5af8264c..0074e0c5952 100644
--- a/include/ck_tile/host/reference/reference_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_gemm.hpp
@@ -435,9 +435,10 @@ template 
+          typename ComputeDataType = ADataType,
+          typename AElementOp      = ck_tile::identity,
+          typename BElementOp      = ck_tile::identity,
+          typename ACCElementOp    = ck_tile::identity>
 CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
                                  const HostTensor& b_k_n,
                                  HostTensor& c_m_n,
@@ -482,7 +483,16 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
             {
                 v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n)));
             }
-            v_acc += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+                v_acc += ck_tile::type_convert(ck_tile::type_convert(v_a) *
+                                                            ck_tile::type_convert(v_b));
+            }
+            else
+            {
+                v_acc += v_a * v_b;
+            }
         }
 
         c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc));
@@ -717,7 +727,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 __global__ void naive_gemm_kernel(ADataType* A,
                                   BDataType* B,
                                   CDataType* C,
@@ -789,7 +800,16 @@ __global__ void naive_gemm_kernel(ADataType* A,
             {
                 v_b = ck_tile::type_convert(B[b_index]);
             }
-            acc += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+                acc += ck_tile::type_convert(ck_tile::type_convert(v_a) *
+                                                          ck_tile::type_convert(v_b));
+            }
+            else
+            {
+                acc += v_a * v_b;
+            }
         }
 
         int c_index = (std::is_same_v)
@@ -805,7 +825,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 __global__ void blockwise_gemm_kernel(ADataType* A,
                                       BDataType* B,
                                       CDataType* C,
@@ -902,7 +923,16 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
             {
                 v_b = ck_tile::type_convert(B[b_index]);
             }
-            acc_temp += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+                acc_temp += ck_tile::type_convert(ck_tile::type_convert(v_a) *
+                                                               ck_tile::type_convert(v_b));
+            }
+            else
+            {
+                acc_temp += v_a * v_b;
+            }
         }
         // final accumulation
         acc += acc_temp * scale_A * scale_B;
@@ -920,7 +950,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 void reference_gemm_gpu(ADataType* a_ptr,
                         BDataType* b_ptr,
                         CDataType* c_ptr,
@@ -935,9 +966,15 @@ void reference_gemm_gpu(ADataType* a_ptr,
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
-    naive_gemm_kernel
-        <<>>(
-            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
+    naive_gemm_kernel<<>>(
+        a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
 
     return;
 }
@@ -948,7 +985,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 void reference_blockwise_gemm_gpu(ADataType* a_ptr,
                                   BDataType* b_ptr,
                                   CDataType* c_ptr,
@@ -968,21 +1006,27 @@ void reference_blockwise_gemm_gpu(ADataType* a_ptr,
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
-    blockwise_gemm_kernel
-        <<>>(a_ptr,
-                                            b_ptr,
-                                            c_ptr,
-                                            M,
-                                            N,
-                                            K,
-                                            stride_a,
-                                            stride_b,
-                                            stride_c,
-                                            scale_granularity_m,
-                                            scale_granularity_n,
-                                            scale_granularity_k,
-                                            scale_A_ptr,
-                                            scale_B_ptr);
+    blockwise_gemm_kernel<<>>(a_ptr,
+                                                                              b_ptr,
+                                                                              c_ptr,
+                                                                              M,
+                                                                              N,
+                                                                              K,
+                                                                              stride_a,
+                                                                              stride_b,
+                                                                              stride_c,
+                                                                              scale_granularity_m,
+                                                                              scale_granularity_n,
+                                                                              scale_granularity_k,
+                                                                              scale_A_ptr,
+                                                                              scale_B_ptr);
 
     return;
 }
@@ -993,7 +1037,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 void reference_batched_gemm_gpu(ADataType* a_ptr,
                                 BDataType* b_ptr,
                                 CDataType* c_ptr,
@@ -1017,9 +1062,15 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
         ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
         BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
         CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
-        naive_gemm_kernel
-            <<>>(
-                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
+        naive_gemm_kernel<<>>(
+            d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
     }
 
     return;
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
index 6199142d986..b3531496fe2 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
@@ -157,17 +157,18 @@ struct UniversalGemmBasePolicy
                 constexpr auto KThreadWrite     = TileEncodingPattern::Y0 * TileEncodingPattern::Y1;
                 constexpr auto K0PerThreadWrite = AK0 / KThreadWrite;
                 constexpr auto KThreadRead      = get_warp_size() / MPerXdl;
-                constexpr auto K0PerThreadRead  = AK0 / KThreadRead;
+                constexpr auto K0PerThreadRead  = (AK0 / KThreadRead) > 0 ? (AK0 / KThreadRead) : 1;
 
                 // check if we exceed all LDS banks
                 constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
-                constexpr auto kfold         = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
-                                                   ? 1
-                                                   : LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
+                constexpr auto kfold          = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth ||
+                                        (AK1 * M0 * sizeof(ADataType)) == 0)
+                                                    ? 1
+                                                    : LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
+                constexpr auto divisor        = (kfold * K0PerThreadWrite / K0PerThreadRead);
+                constexpr auto divisor_to_use = divisor > 0 ? divisor : 1;
                 constexpr auto KThreadReadPerm =
-                    (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
-                        ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
-                        : KThreadRead;
+                    (divisor > 1) ? KThreadRead / divisor_to_use : KThreadRead;
 
                 // 1<=mpair<=n0
                 constexpr auto mpair =
@@ -354,17 +355,18 @@ struct UniversalGemmBasePolicy
                 constexpr auto KThreadWrite     = TileEncodingPattern::Y0 * TileEncodingPattern::Y1;
                 constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
                 constexpr auto KThreadRead      = get_warp_size() / NPerXdl;
-                constexpr auto K0PerThreadRead  = BK0 / KThreadRead;
+                constexpr auto K0PerThreadRead  = (BK0 / KThreadRead) > 0 ? (BK0 / KThreadRead) : 1;
 
                 // check if we exceed all LDS banks
                 constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
-                constexpr auto kfold         = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
-                                                   ? 1
-                                                   : LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
+                constexpr auto kfold          = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth ||
+                                        (BK1 * N0 * sizeof(BDataType)) == 0)
+                                                    ? 1
+                                                    : LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
+                constexpr auto divisor        = (kfold * K0PerThreadWrite / K0PerThreadRead);
+                constexpr auto divisor_to_use = divisor > 0 ? divisor : 1;
                 constexpr auto KThreadReadPerm =
-                    (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
-                        ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
-                        : KThreadRead;
+                    (divisor > 1) ? KThreadRead / divisor_to_use : KThreadRead;
 
                 // 1<=npair<=n0
                 constexpr auto npair =
@@ -897,6 +899,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
 
         using ADataType = remove_cvref_t;
         using BDataType = remove_cvref_t;
+        using ComputeDataType = remove_cvref_t;
+
         using ATypeToUse =
             std::conditional_t, BDataType, ADataType>;
         using BTypeToUse = std::conditional_t ||
@@ -904,8 +908,13 @@ struct UniversalGemmPipelineAgBgCrPolicy
                                               ADataType,
                                               BDataType>;
 
-        using WarpGemm = WarpGemmDispatcher, tf32_t, ATypeToUse>;
+        using BTypeForDispatcher =
+            std::conditional_t, tf32_t, BTypeToUse>;
+
+        using WarpGemm = WarpGemmDispatcher>;
 
+// tf32
+
+using WarpGemmMfmaTf32Tf32F32M16N16K8 = WarpGemmImpl<
+    WarpGemmAttributeMfma>>;
+
+using WarpGemmMfmaTf32Tf32F32M32N32K4 = WarpGemmImpl<
+    WarpGemmAttributeMfma>>;
+
+template 
+using WarpGemmMfmaTf32Tf32F32M16N16K16 = WarpGemmImpl,
+    2,
+    AttrNumAccess>>;
+
+template 
+using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl,
+    4,
+    AttrNumAccess>>;
+
 // fp16
 
 using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
index bd65f533839..a79011899a2 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
@@ -190,6 +190,134 @@ struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2
     }
 };
 
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M32N32K4Tf32
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 32;
+    static constexpr index_t kN = 32;
+    static constexpr index_t kK = 4;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 32;
+    static constexpr index_t kBNLane     = 32;
+    static constexpr index_t kABKLane    = 2;
+    static constexpr index_t kABKPerLane = 2;
+
+    static constexpr index_t kCMLane     = 2;
+    static constexpr index_t kCNLane     = 32;
+    static constexpr index_t kCM0PerLane = 4;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec; the builtin path is compiled only when __gfx942__ is defined
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x4_xf32", Ctrl)
+        else
+        {
+#if defined(__gfx942__)
+            c_vec = __builtin_amdgcn_mfma_f32_32x32x4_xf32(a_vec, b_vec, c_vec, 0, 0, 0);
+#else
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
+#endif
+        }
+    }
+
+    // c_vec = a_vec * b_vec; returns CVecType{0.f} when __gfx942__ is not defined
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+#if defined(__gfx942__)
+        return bit_cast(
+            __builtin_amdgcn_mfma_f32_32x32x4_xf32(a_vec, b_vec, CVecType{0.f}, 0, 0, 0));
+#else
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+        return CVecType{0.f};
+#endif
+    }
+};
+
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M16N16K8Tf32
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 16;
+    static constexpr index_t kN = 16;
+    static constexpr index_t kK = 8;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 16;
+    static constexpr index_t kBNLane     = 16;
+    static constexpr index_t kABKLane    = 4;
+    static constexpr index_t kABKPerLane = 2;
+
+    static constexpr index_t kCMLane     = 4;
+    static constexpr index_t kCNLane     = 16;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec; the builtin path is compiled only when __gfx942__ is defined
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x8_xf32", Ctrl)
+        else
+        {
+#if defined(__gfx942__)
+            c_vec = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_vec, b_vec, c_vec, 0, 0, 0);
+#else
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
+#endif
+        }
+    }
+
+    // c_vec = a_vec * b_vec; returns CVecType{0.f} when __gfx942__ is not defined
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+#if defined(__gfx942__)
+        return bit_cast(
+            __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_vec, b_vec, CVecType{0.f}, 0, 0, 0));
+#else
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+        return CVecType{0.f};
+#endif
+    }
+};
+
 // V_MFMA_F32_16x16x32_BF16
 template 
 struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index d6c21e88b56..cdaad992720 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -35,6 +35,11 @@ struct Dispatcher;
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
+// tf32
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K8; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K4; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
 // fp16
 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
diff --git a/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
index 899221547f6..a1b43460c17 100644
--- a/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
+++ b/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
@@ -20,6 +20,12 @@ struct DataTypeTraits
     static constexpr const char* name = "fp32";
 };
 
+template <>
+struct DataTypeTraits
+{
+    static constexpr const char* name = "tf32";
+};
+
 template <>
 struct DataTypeTraits
 {

From 73ee4a4ecc1a72daea18192f8dafee70a8d78c0c Mon Sep 17 00:00:00 2001
From: yingluAMD 
Date: Fri, 9 Jan 2026 15:13:14 +0800
Subject: [PATCH 3/7] add tf32 support on gfx950

---
 .../ck_tile/03_gemm/gemm_basic_invoker.hpp    |  11 ++
 .../ck_tile/host/reference/reference_gemm.hpp |  69 ++++++-
 include/ck_tile/ops/gemm/warp/warp_gemm.hpp   |  23 +++
 .../warp/warp_gemm_attribute_mfma_impl.hpp    | 168 ++++++++++++++++++
 .../ops/gemm/warp/warp_gemm_dispatcher.hpp    |  28 +++
 5 files changed, 297 insertions(+), 2 deletions(-)

diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
index c9a412784e1..8636357e462 100644
--- a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
+++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
@@ -26,6 +26,7 @@ struct BasicInvoker
         }
 
         constexpr bool is_fp32_input = std::is_same_v;
+        [[maybe_unused]] constexpr bool is_tf32_compute = std::is_same_v;
 
         // This part comes from the Codegen
         constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256;
@@ -40,6 +41,16 @@ struct BasicInvoker
         constexpr ck_tile::index_t M_Warp_Tile = 16;
         constexpr ck_tile::index_t N_Warp_Tile = 16;
         constexpr ck_tile::index_t K_Warp_Tile = 16;
+#elif defined(CK_GFX950_SUPPORT)
+        // gfx950: fp32 uses 16x16x16 tile (native MFMA)
+        //         tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation)
+        constexpr ck_tile::index_t M_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
+        constexpr ck_tile::index_t N_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
+        constexpr ck_tile::index_t K_Warp = 1;
+
+        constexpr ck_tile::index_t M_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
+        constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
+        constexpr ck_tile::index_t K_Warp_Tile = 16;
 #else
         constexpr ck_tile::index_t M_Warp = is_fp32_input ? 4 : 2;
         constexpr ck_tile::index_t N_Warp = is_fp32_input ? 4 : 2;
diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp
index 0074e0c5952..03ce42eb87f 100644
--- a/include/ck_tile/host/reference/reference_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_gemm.hpp
@@ -8,6 +8,7 @@
 
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/host/device_prop.hpp"
 
 namespace ck_tile {
 
@@ -450,6 +451,8 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
     const std::size_t N = b_k_n.get_length(1);
     const std::size_t K = a_m_k.get_length(1);
 
+    const std::string device_name = ck_tile::get_device_name();
+
     auto f_mn = [&](auto m, auto n) {
         AccDataType v_acc = 0;
 
@@ -486,8 +489,34 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
 
             if constexpr(std::is_same_v)
             {
-                v_acc += ck_tile::type_convert(ck_tile::type_convert(v_a) *
-                                                            ck_tile::type_convert(v_b));
+                if(device_name == "gfx950")
+                {
+                    // gfx950: use 3x bf16 emulation
+                    bf16_t v_a_bf16_big = ck_tile::type_convert(v_a);
+                    bf16_t v_a_bf16_small =
+                        ck_tile::type_convert(v_a - type_convert(v_a_bf16_big));
+                    bf16_t v_b_bf16_big = ck_tile::type_convert(v_b);
+                    bf16_t v_b_bf16_small =
+                        ck_tile::type_convert(v_b - type_convert(v_b_bf16_big));
+
+                    v_acc += ck_tile::type_convert(v_a_bf16_big) *
+                                 ck_tile::type_convert(v_b_bf16_small) +
+                             ck_tile::type_convert(v_a_bf16_small) *
+                                 ck_tile::type_convert(v_b_bf16_big) +
+                             ck_tile::type_convert(v_a_bf16_big) *
+                                 ck_tile::type_convert(v_b_bf16_big);
+                }
+                else if(device_name == "gfx942" || device_name == "gfx940" || device_name == "gfx941")
+                {
+                    // gfx94x: use native tf32 (xf32)
+                    v_acc += ck_tile::type_convert(ck_tile::type_convert(v_a) *
+                                                                ck_tile::type_convert(v_b));
+                }
+                else
+                {
+                    // gfx11/gfx12 and others: tf32 not supported, use fp32 fallback
+                    v_acc += v_a * v_b;
+                }
             }
             else
             {
@@ -803,8 +832,26 @@ __global__ void naive_gemm_kernel(ADataType* A,
 
             if constexpr(std::is_same_v)
             {
+#ifdef CK_GFX950_SUPPORT
+                // gfx950: use 3x bf16 emulation
+                bf16_t v_a_bf16_big = ck_tile::type_convert(v_a);
+                bf16_t v_a_bf16_small =
+                    ck_tile::type_convert(v_a - type_convert(v_a_bf16_big));
+                bf16_t v_b_bf16_big = ck_tile::type_convert(v_b);
+                bf16_t v_b_bf16_small =
+                    ck_tile::type_convert(v_b - type_convert(v_b_bf16_big));
+
+                acc += ck_tile::type_convert(v_a_bf16_big) *
+                           ck_tile::type_convert(v_b_bf16_small) +
+                       ck_tile::type_convert(v_a_bf16_small) *
+                           ck_tile::type_convert(v_b_bf16_big) +
+                       ck_tile::type_convert(v_a_bf16_big) *
+                           ck_tile::type_convert(v_b_bf16_big);
+#else
+                // gfx942 and others: use native tf32
                 acc += ck_tile::type_convert(ck_tile::type_convert(v_a) *
                                                           ck_tile::type_convert(v_b));
+#endif
             }
             else
             {
@@ -926,8 +973,26 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
 
             if constexpr(std::is_same_v)
             {
+#ifdef CK_GFX950_SUPPORT
+                // gfx950: use 3x bf16 emulation
+                bf16_t v_a_bf16_big = ck_tile::type_convert(v_a);
+                bf16_t v_a_bf16_small =
+                    ck_tile::type_convert(v_a - type_convert(v_a_bf16_big));
+                bf16_t v_b_bf16_big = ck_tile::type_convert(v_b);
+                bf16_t v_b_bf16_small =
+                    ck_tile::type_convert(v_b - type_convert(v_b_bf16_big));
+
+                acc_temp += ck_tile::type_convert(v_a_bf16_big) *
+                                ck_tile::type_convert(v_b_bf16_small) +
+                            ck_tile::type_convert(v_a_bf16_small) *
+                                ck_tile::type_convert(v_b_bf16_big) +
+                            ck_tile::type_convert(v_a_bf16_big) *
+                                ck_tile::type_convert(v_b_bf16_big);
+#else
+                // gfx942 and others: use native tf32
                 acc_temp += ck_tile::type_convert(ck_tile::type_convert(v_a) *
                                                                ck_tile::type_convert(v_b));
+#endif
             }
             else
             {
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
index 35f1b1acfef..dd4b4968db2 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -31,7 +31,29 @@ using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
         AttrNumAccess>>;
 
 // tf32
+// On gfx942: uses native xf32 instructions
+// On gfx950: uses 3x bf16 MFMA emulation (no native xf32 support)
+// On gfx11/gfx12: tf32 not supported, these types are not defined
 
+#if defined(CK_GFX950_SUPPORT)
+// gfx950: tf32 emulated using 3x bf16 MFMA
+using WarpGemmMfmaTf32Tf32F32M32N32K16Native = WarpGemmImpl<
+    WarpGemmAttributeMfma>>;
+
+using WarpGemmMfmaTf32Tf32F32M16N16K32Native = WarpGemmImpl<
+    WarpGemmAttributeMfma>>;
+
+template 
+using WarpGemmMfmaTf32Tf32F32M16N16K16 = WarpGemmImpl<
+    WarpGemmAttributeMfma,
+                          AttrNumAccess>>;
+
+template 
+using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl<
+    WarpGemmAttributeMfma,
+                          AttrNumAccess>>;
+#else
+// gfx942: uses native xf32 instructions
 using WarpGemmMfmaTf32Tf32F32M16N16K8 = WarpGemmImpl<
     WarpGemmAttributeMfma>>;
 
@@ -49,6 +71,7 @@ using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl,
     4,
     AttrNumAccess>>;
+#endif
 
 // fp16
 
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
index a79011899a2..9ed2a043b5c 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
@@ -61,6 +61,19 @@ enum class WGAttrCtlEnum
         DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a")     \
     }
 
+// Split each float x into bf16 (big, small) for gfx950 tf32 emulation; small = bf16(x - big)
+// The pairs feed the 3x bf16 MFMA emulation: big*big + small*big + big*small
+template 
+CK_TILE_DEVICE void convert_float_to_bf16_pairs(const thread_buffer& reg_f32,
+                                                thread_buffer& reg_bf16_big,
+                                                thread_buffer& reg_bf16_small)
+{
+    static_for<0, VecSize, 1>{}([&](auto k) {
+        reg_bf16_big(k) = type_convert(reg_f32[k]);
+        reg_bf16_small(k) = type_convert(reg_f32[k] - type_convert(reg_bf16_big[k]));
+    });
+}
+
 // F32
 template 
 struct WarpGemmAttributeMfmaImplF32F32F32M16N16K4
@@ -318,6 +331,161 @@ struct WarpGemmAttributeMfmaImplF32F32F32M16N16K8Tf32
     }
 };
 
+// tf32/xf32 emulation on gfx950 using 3x bf16 MFMA
+// Algorithm: split float x into big = bf16(x) and small = bf16(x - big), then compute:
+//   out = A_big * B_big + A_small * B_big + A_big * B_small   (A_small * B_small is dropped)
+// This provides tf32-like precision using bf16 hardware
+
+// 32x32x16 tf32 (xf32) MFMA emulated on gfx950 with three 32x32x16 bf16 MFMAs
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    // Input: 8 floats per lane (kABKPerLane = 8); kABKLane (2) * kABKPerLane (8) = kK (16)
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 32;
+    static constexpr index_t kN = 32;
+    static constexpr index_t kK = 16;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 32;
+    static constexpr index_t kBNLane     = 32;
+    static constexpr index_t kABKLane    = 2;
+    static constexpr index_t kABKPerLane = 8;
+
+    static constexpr index_t kCMLane     = 2;
+    static constexpr index_t kCNLane     = 32;
+    static constexpr index_t kCM0PerLane = 4;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec; compiled to a no-op unless __gfx950__ is defined (Ctrl unused)
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+#if defined(__gfx950__)
+        // View the ext_vector inputs as thread_buffer so elements can be indexed
+        const auto& a_f32 = reinterpret_cast&>(a_vec);
+        const auto& b_f32 = reinterpret_cast&>(b_vec);
+
+        // Split each float into a (big, small) bf16 pair
+        thread_buffer a_bf16_big, a_bf16_small, b_bf16_big, b_bf16_small;
+        convert_float_to_bf16_pairs(a_f32, a_bf16_big, a_bf16_small);
+        convert_float_to_bf16_pairs(b_f32, b_bf16_big, b_bf16_small);
+
+        // Get bf16x8 vectors for MFMA
+        auto a_big = bit_cast>(a_bf16_big);
+        auto a_small = bit_cast>(a_bf16_small);
+        auto b_big = bit_cast>(b_bf16_big);
+        auto b_small = bit_cast>(b_bf16_small);
+
+        // Run 3 bf16 MFMAs: small*big, big*small, big*big
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_small, b_big, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_big, b_small, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_big, b_big, c_vec, 0, 0, 0);
+#else
+        ck_tile::ignore = c_vec;
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+#endif
+    }
+
+    // c_vec = a_vec * b_vec, computed by accumulating into a zero-initialized c_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0.f};
+        (*this)(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
+// 16x16x32 tf32 (xf32) MFMA emulated on gfx950 with three 16x16x32 bf16 MFMAs
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    // Input: 8 floats per lane (kABKPerLane = 8); kABKLane (4) * kABKPerLane (8) = kK (32)
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 16;
+    static constexpr index_t kN = 16;
+    static constexpr index_t kK = 32;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 16;
+    static constexpr index_t kBNLane     = 16;
+    static constexpr index_t kABKLane    = 4;
+    static constexpr index_t kABKPerLane = 8;
+
+    static constexpr index_t kCMLane     = 4;
+    static constexpr index_t kCNLane     = 16;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec; compiled to a no-op unless __gfx950__ is defined (Ctrl unused)
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+#if defined(__gfx950__)
+        // View the ext_vector inputs as thread_buffer so elements can be indexed
+        const auto& a_f32 = reinterpret_cast&>(a_vec);
+        const auto& b_f32 = reinterpret_cast&>(b_vec);
+
+        // Split each float into a (big, small) bf16 pair
+        thread_buffer a_bf16_big, a_bf16_small, b_bf16_big, b_bf16_small;
+        convert_float_to_bf16_pairs(a_f32, a_bf16_big, a_bf16_small);
+        convert_float_to_bf16_pairs(b_f32, b_bf16_big, b_bf16_small);
+
+        // Get bf16x8 vectors for MFMA
+        auto a_big = bit_cast>(a_bf16_big);
+        auto a_small = bit_cast>(a_bf16_small);
+        auto b_big = bit_cast>(b_bf16_big);
+        auto b_small = bit_cast>(b_bf16_small);
+
+        // Run 3 bf16 MFMAs: small*big, big*small, big*big
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_small, b_big, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_big, b_small, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_big, b_big, c_vec, 0, 0, 0);
+#else
+        ck_tile::ignore = c_vec;
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+#endif
+    }
+
+    // c_vec = a_vec * b_vec, computed by accumulating into a zero-initialized c_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0.f};
+        (*this)(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
 // V_MFMA_F32_16x16x32_BF16
 template 
 struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index cdaad992720..515e51ae7e6 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -35,11 +35,39 @@ struct Dispatcher;
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution; };
 // tf32
+// gfx942: uses native xf32 instructions (16x16x8, 32x32x4)
+// gfx950: uses 3x bf16 MFMA emulation (16x16x32, 32x32x16 native bf16 sizes)
+// gfx11/gfx12: tf32 not supported, dispatchers not defined
+#if defined(CK_GFX950_SUPPORT)
+// For tf32 on gfx950: epilogue uses float types but 32x32x16 tile
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+// On gfx950, tf32 uses bf16 emulation
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+#else
+// gfx942: uses native xf32 instructions
 template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K8; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K4; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16<>; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+#endif
+// Note: For gfx11/gfx12 and other architectures that don't support tf32,
+// these dispatchers are not defined. Code using tf32 should be guarded
+// by CK_ENABLE_TF32 or CK_GFX950_SUPPORT macros.
 // fp16
 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; };

From 633e4c188b94b0727369e694e288707adb855013 Mon Sep 17 00:00:00 2001
From: yingluAMD 
Date: Tue, 20 Jan 2026 14:14:44 +0800
Subject: [PATCH 4/7] remove gfx942 tf32 support

---
 example/ck_tile/03_gemm/CMakeLists.txt        |  2 +-
 example/ck_tile/03_gemm/gemm_basic.cpp        |  1 -
 .../ck_tile/03_gemm/gemm_basic_invoker.hpp    |  1 +
 .../ck_tile/host/reference/reference_gemm.hpp | 18 +++++-----------
 include/ck_tile/ops/gemm/warp/warp_gemm.hpp   | 21 -------------------
 .../ops/gemm/warp/warp_gemm_dispatcher.hpp    | 14 +------------
 6 files changed, 8 insertions(+), 49 deletions(-)

diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt
index 40547d07191..34b5cfe1f05 100644
--- a/example/ck_tile/03_gemm/CMakeLists.txt
+++ b/example/ck_tile/03_gemm/CMakeLists.txt
@@ -1,7 +1,7 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a")
+if(GPU_TARGETS MATCHES "gfx950")
   add_executable(tile_example_gemm_basic gemm_basic.cpp)
   add_executable(tile_example_gemm_universal universal_gemm.cpp)
   add_executable(tile_example_gemm_weight_preshuffle gemm_weight_preshuffle.cpp)
diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp
index 2ca76623134..56a0cc69d3d 100644
--- a/example/ck_tile/03_gemm/gemm_basic.cpp
+++ b/example/ck_tile/03_gemm/gemm_basic.cpp
@@ -9,7 +9,6 @@
 
 int run_gemm_example(ck_tile::ArgParser& arg_parser)
 {
-    arg_parser.print();
 
     std::string data_type = arg_parser.get_str("prec");
     std::string a_layout  = arg_parser.get_str("a_layout");
diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
index 8636357e462..a164eb0f322 100644
--- a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
+++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
@@ -52,6 +52,7 @@ struct BasicInvoker
         constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
         constexpr ck_tile::index_t K_Warp_Tile = 16;
 #else
+        // Fallback or other architectures
         constexpr ck_tile::index_t M_Warp = is_fp32_input ? 4 : 2;
         constexpr ck_tile::index_t N_Warp = is_fp32_input ? 4 : 2;
         constexpr ck_tile::index_t K_Warp = 1;
diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp
index 03ce42eb87f..7e4c6efd8da 100644
--- a/include/ck_tile/host/reference/reference_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_gemm.hpp
@@ -506,15 +506,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
                              ck_tile::type_convert(v_a_bf16_big) *
                                  ck_tile::type_convert(v_b_bf16_big);
                 }
-                else if(device_name == "gfx942" || device_name == "gfx940" || device_name == "gfx941")
-                {
-                    // gfx94x: use native tf32 (xf32)
-                    v_acc += ck_tile::type_convert(ck_tile::type_convert(v_a) *
-                                                                ck_tile::type_convert(v_b));
-                }
                 else
                 {
-                    // gfx11/gfx12 and others: tf32 not supported, use fp32 fallback
+                    // Other architectures: tf32 not supported or handled via fp32 fallback
                     v_acc += v_a * v_b;
                 }
             }
@@ -848,9 +842,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
                        ck_tile::type_convert(v_a_bf16_big) *
                            ck_tile::type_convert(v_b_bf16_big);
 #else
-                // gfx942 and others: use native tf32
-                acc += ck_tile::type_convert(ck_tile::type_convert(v_a) *
-                                                          ck_tile::type_convert(v_b));
+                // Other architectures: use fp32 fallback
+                acc += v_a * v_b;
 #endif
             }
             else
@@ -989,9 +982,8 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
                             ck_tile::type_convert(v_a_bf16_big) *
                                 ck_tile::type_convert(v_b_bf16_big);
 #else
-                // gfx942 and others: use native tf32
-                acc_temp += ck_tile::type_convert(ck_tile::type_convert(v_a) *
-                                                               ck_tile::type_convert(v_b));
+                // Other architectures: use fp32 fallback
+                acc_temp += v_a * v_b;
 #endif
             }
             else
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
index dd4b4968db2..760705e9a87 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -31,9 +31,7 @@ using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
         AttrNumAccess>>;
 
 // tf32
-// On gfx942: uses native xf32 instructions
 // On gfx950: uses 3x bf16 MFMA emulation (no native xf32 support)
-// On gfx11/gfx12: tf32 not supported, these types are not defined
 
 #if defined(CK_GFX950_SUPPORT)
 // gfx950: tf32 emulated using 3x bf16 MFMA
@@ -52,25 +50,6 @@ template 
 using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl<
     WarpGemmAttributeMfma,
                           AttrNumAccess>>;
-#else
-// gfx942: uses native xf32 instructions
-using WarpGemmMfmaTf32Tf32F32M16N16K8 = WarpGemmImpl<
-    WarpGemmAttributeMfma>>;
-
-using WarpGemmMfmaTf32Tf32F32M32N32K4 = WarpGemmImpl<
-    WarpGemmAttributeMfma>>;
-
-template 
-using WarpGemmMfmaTf32Tf32F32M16N16K16 = WarpGemmImpl,
-    2,
-    AttrNumAccess>>;
-
-template 
-using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl,
-    4,
-    AttrNumAccess>>;
 #endif
 
 // fp16
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index 515e51ae7e6..b1ad89cc9fd 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -38,9 +38,7 @@ template<> struct Dispatcher { using Typ
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution; };
 // tf32
-// gfx942: uses native xf32 instructions (16x16x8, 32x32x4)
-// gfx950: uses 3x bf16 MFMA emulation (16x16x32, 32x32x16 native bf16 sizes)
-// gfx11/gfx12: tf32 not supported, dispatchers not defined
+// On gfx950: uses 3x bf16 MFMA emulation (no native xf32 support)
 #if defined(CK_GFX950_SUPPORT)
 // For tf32 on gfx950: epilogue uses float types but 32x32x16 tile
 template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
@@ -54,16 +52,6 @@ template<> struct Dispatcher struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
-#else
-// gfx942: uses native xf32 instructions
-template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K8; };
-template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K4; };
-template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16<>; };
-template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
-template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
-template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
-template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
-template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
 #endif
 // Note: For gfx11/gfx12 and other architectures that don't support tf32,
 // these dispatchers are not defined. Code using tf32 should be guarded

From 6ad40d7216a832e75ec501ba64317ea7ce612c99 Mon Sep 17 00:00:00 2001
From: yingluAMD 
Date: Tue, 20 Jan 2026 14:27:46 +0800
Subject: [PATCH 5/7] fix clang-format failures

---
 .../ck_tile/03_gemm/gemm_basic_invoker.hpp    |  3 ++-
 example/ck_tile/03_gemm/run_gemm_example.inc  |  3 ++-
 .../ck_tile/host/reference/reference_gemm.hpp | 12 +++++------
 ...emm_universal_pipeline_ag_bg_cr_policy.hpp |  8 ++++----
 include/ck_tile/ops/gemm/warp/warp_gemm.hpp   | 20 +++++++++----------
 .../warp/warp_gemm_attribute_mfma_impl.hpp    | 10 +++++-----
 6 files changed, 29 insertions(+), 27 deletions(-)

diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
index a164eb0f322..2352bcf5a86 100644
--- a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
+++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp
@@ -26,7 +26,8 @@ struct BasicInvoker
         }
 
         constexpr bool is_fp32_input = std::is_same_v;
-        [[maybe_unused]] constexpr bool is_tf32_compute = std::is_same_v;
+        [[maybe_unused]] constexpr bool is_tf32_compute =
+            std::is_same_v;
 
         // This part comes from the Codegen
         constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256;
diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index cb55a6bce19..ef65b141529 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -433,7 +433,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                     ALayout,
                                     BLayout,
                                     CLayout,
-                                    ComputeDataType>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
+                                    ComputeDataType>(
+            d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
 
         c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
 
diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp
index 7e4c6efd8da..566dff7df9f 100644
--- a/include/ck_tile/host/reference/reference_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_gemm.hpp
@@ -492,12 +492,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
                 if(device_name == "gfx950")
                 {
                     // gfx950: use 3x bf16 emulation
-                    bf16_t v_a_bf16_big = ck_tile::type_convert(v_a);
-                    bf16_t v_a_bf16_small =
-                        ck_tile::type_convert(v_a - type_convert(v_a_bf16_big));
-                    bf16_t v_b_bf16_big = ck_tile::type_convert(v_b);
-                    bf16_t v_b_bf16_small =
-                        ck_tile::type_convert(v_b - type_convert(v_b_bf16_big));
+                    bf16_t v_a_bf16_big   = ck_tile::type_convert(v_a);
+                    bf16_t v_a_bf16_small = ck_tile::type_convert(
+                        v_a - type_convert(v_a_bf16_big));
+                    bf16_t v_b_bf16_big   = ck_tile::type_convert(v_b);
+                    bf16_t v_b_bf16_small = ck_tile::type_convert(
+                        v_b - type_convert(v_b_bf16_big));
 
                     v_acc += ck_tile::type_convert(v_a_bf16_big) *
                                  ck_tile::type_convert(v_b_bf16_small) +
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
index b3531496fe2..1ecc616a90c 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
@@ -160,7 +160,7 @@ struct UniversalGemmBasePolicy
                 constexpr auto K0PerThreadRead  = (AK0 / KThreadRead) > 0 ? (AK0 / KThreadRead) : 1;
 
                 // check if we exceed all LDS banks
-                constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
+                constexpr auto LdsBanksWidth  = get_n_lds_banks() * get_n_words_per_128b();
                 constexpr auto kfold          = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth ||
                                         (AK1 * M0 * sizeof(ADataType)) == 0)
                                                     ? 1
@@ -358,7 +358,7 @@ struct UniversalGemmBasePolicy
                 constexpr auto K0PerThreadRead  = (BK0 / KThreadRead) > 0 ? (BK0 / KThreadRead) : 1;
 
                 // check if we exceed all LDS banks
-                constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
+                constexpr auto LdsBanksWidth  = get_n_lds_banks() * get_n_words_per_128b();
                 constexpr auto kfold          = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth ||
                                         (BK1 * N0 * sizeof(BDataType)) == 0)
                                                     ? 1
@@ -897,8 +897,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
             : vector_size * 4 == thread_elements              ? WGAttrNumAccessEnum::Quad
                                                               : WGAttrNumAccessEnum::Invalid;
 
-        using ADataType = remove_cvref_t;
-        using BDataType = remove_cvref_t;
+        using ADataType       = remove_cvref_t;
+        using BDataType       = remove_cvref_t;
         using ComputeDataType = remove_cvref_t;
 
         using ATypeToUse =
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
index 760705e9a87..52fc853bc58 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -35,21 +35,21 @@ using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
 
 #if defined(CK_GFX950_SUPPORT)
 // gfx950: tf32 emulated using 3x bf16 MFMA
-using WarpGemmMfmaTf32Tf32F32M32N32K16Native = WarpGemmImpl<
-    WarpGemmAttributeMfma>>;
+using WarpGemmMfmaTf32Tf32F32M32N32K16Native = WarpGemmImpl>>;
 
-using WarpGemmMfmaTf32Tf32F32M16N16K32Native = WarpGemmImpl<
-    WarpGemmAttributeMfma>>;
+using WarpGemmMfmaTf32Tf32F32M16N16K32Native = WarpGemmImpl>>;
 
 template 
-using WarpGemmMfmaTf32Tf32F32M16N16K16 = WarpGemmImpl<
-    WarpGemmAttributeMfma,
-                          AttrNumAccess>>;
+using WarpGemmMfmaTf32Tf32F32M16N16K16 = WarpGemmImpl,
+    AttrNumAccess>>;
 
 template 
-using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl<
-    WarpGemmAttributeMfma,
-                          AttrNumAccess>>;
+using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl,
+    AttrNumAccess>>;
 #endif
 
 // fp16
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
index 9ed2a043b5c..982cb499672 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
@@ -69,7 +69,7 @@ CK_TILE_DEVICE void convert_float_to_bf16_pairs(const thread_buffer& reg_bf16_small)
 {
     static_for<0, VecSize, 1>{}([&](auto k) {
-        reg_bf16_big(k) = type_convert(reg_f32[k]);
+        reg_bf16_big(k)   = type_convert(reg_f32[k]);
         reg_bf16_small(k) = type_convert(reg_f32[k] - type_convert(reg_bf16_big[k]));
     });
 }
@@ -386,9 +386,9 @@ struct WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950
         convert_float_to_bf16_pairs(b_f32, b_bf16_big, b_bf16_small);
 
         // Get bf16x8 vectors for MFMA
-        auto a_big = bit_cast>(a_bf16_big);
+        auto a_big   = bit_cast>(a_bf16_big);
         auto a_small = bit_cast>(a_bf16_small);
-        auto b_big = bit_cast>(b_bf16_big);
+        auto b_big   = bit_cast>(b_bf16_big);
         auto b_small = bit_cast>(b_bf16_small);
 
         // Run 3 bf16 MFMAs: small*big, big*small, big*big
@@ -461,9 +461,9 @@ struct WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950
         convert_float_to_bf16_pairs(b_f32, b_bf16_big, b_bf16_small);
 
         // Get bf16x8 vectors for MFMA
-        auto a_big = bit_cast>(a_bf16_big);
+        auto a_big   = bit_cast>(a_bf16_big);
         auto a_small = bit_cast>(a_bf16_small);
-        auto b_big = bit_cast>(b_bf16_big);
+        auto b_big   = bit_cast>(b_bf16_big);
         auto b_small = bit_cast>(b_bf16_small);
 
         // Run 3 bf16 MFMAs: small*big, big*small, big*big

From e584468613c11f214c3958f99a8e2c5e8ea4f0a2 Mon Sep 17 00:00:00 2001
From: yingluAMD 
Date: Wed, 21 Jan 2026 15:28:54 +0800
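Subject: [PATCH 6/7] ck_tile:tf32: fix example build guards and ComputeDataType plumbing

* CMakeLists.txt: build the 03_gemm examples when GPU_TARGETS matches
  gfx94, gfx95 or gfx90a instead of gfx950 only.
* gemm_basic.cpp: compile the "tf32" precision branch only when
  CK_GFX950_SUPPORT is defined.
* split-K two-stage, weight-preshuffle and universal invokers: add a
  ComputeDataType template parameter (defaulting to ADataType) to gemm()
  and update their UniversalGemmProblem aliases accordingly.
* type_convert.hpp: rename the tf32_rounding_mode enumerators
  truncate/standard to trunc/rne.
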
---
 example/ck_tile/03_gemm/CMakeLists.txt        |  2 +-
 example/ck_tile/03_gemm/gemm_basic.cpp        |  3 ++-
 .../03_gemm/gemm_splitk_two_stage_invoker.hpp | 19 ++++++++++++-------
 .../gemm_weight_preshuffle_invoker.hpp        | 19 ++++++++++++-------
 .../03_gemm/universal_gemm_invoker.hpp        | 19 ++++++++++++-------
 include/ck_tile/core/numeric/type_convert.hpp |  8 ++++----
 6 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt
index 34b5cfe1f05..40547d07191 100644
--- a/example/ck_tile/03_gemm/CMakeLists.txt
+++ b/example/ck_tile/03_gemm/CMakeLists.txt
@@ -1,7 +1,7 @@
 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
-if(GPU_TARGETS MATCHES "gfx950")
+if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a")
   add_executable(tile_example_gemm_basic gemm_basic.cpp)
   add_executable(tile_example_gemm_universal universal_gemm.cpp)
   add_executable(tile_example_gemm_weight_preshuffle gemm_weight_preshuffle.cpp)
diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp
index 56a0cc69d3d..285161cc816 100644
--- a/example/ck_tile/03_gemm/gemm_basic.cpp
+++ b/example/ck_tile/03_gemm/gemm_basic.cpp
@@ -9,7 +9,6 @@
 
 int run_gemm_example(ck_tile::ArgParser& arg_parser)
 {
-
     std::string data_type = arg_parser.get_str("prec");
     std::string a_layout  = arg_parser.get_str("a_layout");
     std::string b_layout  = arg_parser.get_str("b_layout");
@@ -47,6 +46,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
         return run_gemm_example_prec_type(
             a_layout, b_layout, arg_parser);
     }
+#ifdef CK_GFX950_SUPPORT
     else if(data_type == "tf32")
     {
         return run_gemm_example_prec_type(a_layout, b_layout, arg_parser);
     }
+#endif
     else if(data_type == "fp8")
     {
         return run_gemm_example_prec_type
+              typename CDEElementWise,
+              typename ComputeDataType = ADataType>
     static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 
     {
@@ -61,12 +62,16 @@ struct SplitKTwoStageInvoker
                                              GemmConfig::Preshuffle>;
         constexpr auto scheduler = GemmConfig::Scheduler;
 
-        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
+        using UniversalGemmProblem =
+            ck_tile::UniversalGemmPipelineProblem;
         using WorkspaceType        = ck_tile::remove_cvref_t;
 
         using GemmPipeline = typename PipelineTypeTraits<
diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
index 1deafb97a17..b33067482b4 100644
--- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
+++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
@@ -16,7 +16,8 @@ struct WeightPreshuffleInvoker
               typename DsLayout,
               typename ELayout,
               bool Persistent,
-              typename CDEElementWise>
+              typename CDEElementWise,
+              typename ComputeDataType = ADataType>
     static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 
     {
@@ -48,12 +49,16 @@ struct WeightPreshuffleInvoker
                                              GemmConfig::Preshuffle>;
         constexpr auto scheduler = GemmConfig::Scheduler;
 
-        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
+        using UniversalGemmProblem =
+            ck_tile::UniversalGemmPipelineProblem;
 
         using GemmPipeline = typename PipelineTypeTraits<
             GemmConfig::Pipeline>::template GemmPipeline;
diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp
index fb89e6b4cc4..e230db38510 100644
--- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp
+++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp
@@ -17,7 +17,8 @@ struct UniversalInvoker
               typename DsLayout,
               typename ELayout,
               bool Persistent,
-              typename CDEElementWise>
+              typename CDEElementWise,
+              typename ComputeDataType = ADataType>
     static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 
     {
@@ -50,12 +51,16 @@ struct UniversalInvoker
 
         constexpr auto scheduler = GemmConfig::Scheduler;
 
-        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
+        using UniversalGemmProblem =
+            ck_tile::UniversalGemmPipelineProblem;
 
         using GemmPipeline = typename PipelineTypeTraits<
             GemmConfig::Pipeline>::template GemmPipeline;
diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp
index b1a9ce6e58d..7268f52e79e 100644
--- a/include/ck_tile/core/numeric/type_convert.hpp
+++ b/include/ck_tile/core/numeric/type_convert.hpp
@@ -59,15 +59,15 @@ CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
 
 enum class tf32_rounding_mode
 {
-    truncate = 0,
-    standard = 1, // RTNE
+    trunc = 0, // truncate
+    rne   = 1, // round to nearest even (RTNE)
 };
 
-template 
+template 
 CK_TILE_HOST_DEVICE constexpr float float_to_tf32(float x)
 {
     uint32_t i = bit_cast(x);
-    if constexpr(rounding == tf32_rounding_mode::standard)
+    if constexpr(rounding == tf32_rounding_mode::rne)
     {
         if((i & 0x7f800000) != 0x7f800000)
         {

From c67a7587a486d7814260ca11bfb639ccc0729623 Mon Sep 17 00:00:00 2001
From: yingluAMD 
Date: Wed, 21 Jan 2026 15:32:21 +0800
Subject: [PATCH 7/7] fix clang-format failures

---
 example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp
index 4184a93ffdd..247a6d5f1c5 100644
--- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp
+++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp
@@ -72,7 +72,7 @@ struct SplitKTwoStageInvoker
                                                   ck_tile::element_wise::PassThrough,
                                                   ck_tile::element_wise::PassThrough,
                                                   ComputeDataType>;
-        using WorkspaceType        = ck_tile::remove_cvref_t;
+        using WorkspaceType = ck_tile::remove_cvref_t;
 
         using GemmPipeline = typename PipelineTypeTraits<
             GemmConfig::Pipeline>::template GemmPipeline;