diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 40547d07191..34b5cfe1f05 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a") +if(GPU_TARGETS MATCHES "gfx950") add_executable(tile_example_gemm_basic gemm_basic.cpp) add_executable(tile_example_gemm_universal universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle gemm_weight_preshuffle.cpp) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index e30ae8319f2..56a0cc69d3d 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -9,6 +9,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) { + std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); @@ -41,6 +42,20 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) return run_gemm_example_prec_type( a_layout, b_layout, arg_parser); } + else if(data_type == "fp32") + { + return run_gemm_example_prec_type( + a_layout, b_layout, arg_parser); + } + else if(data_type == "tf32") + { + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); + } else if(data_type == "fp8") { return run_gemm_example_prec_type + typename CDEElementWise, + typename ComputeDataType = ADataType> static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { if constexpr(Persistent) @@ -24,9 +25,13 @@ struct BasicInvoker std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; } + constexpr bool is_fp32_input = std::is_same_v; + [[maybe_unused]] constexpr bool is_tf32_compute = + std::is_same_v; + // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256; + constexpr ck_tile::index_t N_Tile = is_fp32_input ? 128 : 256; constexpr ck_tile::index_t K_Tile = 64; #if CK_TILE_USE_WMMA @@ -37,13 +42,24 @@ struct BasicInvoker constexpr ck_tile::index_t M_Warp_Tile = 16; constexpr ck_tile::index_t N_Warp_Tile = 16; constexpr ck_tile::index_t K_Warp_Tile = 16; +#elif defined(CK_GFX950_SUPPORT) + // gfx950: fp32 uses 16x16x16 tile (native MFMA) + // tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation) + constexpr ck_tile::index_t M_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2; + constexpr ck_tile::index_t N_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32; + constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; #else - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; + // Fallback or other architectures + constexpr ck_tile::index_t M_Warp = is_fp32_input ? 4 : 2; + constexpr ck_tile::index_t N_Warp = is_fp32_input ? 4 : 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t M_Warp_Tile = is_fp32_input ? 16 : 32; + constexpr ck_tile::index_t N_Warp_Tile = is_fp32_input ? 16 : 32; constexpr ck_tile::index_t K_Warp_Tile = 16; #endif @@ -61,11 +77,15 @@ struct BasicInvoker BLayout, CLayout>; - using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 8eff0e7469f..3b8ecde15e9 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -302,6 +302,24 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill
 struct GemmTypeConfig;
 
+template <>
+struct GemmTypeConfig
+{
+    using ADataType   = float;
+    using BDataType   = float;
+    using AccDataType = float;
+    using CDataType   = float;
+};
+
+template <>
+struct GemmTypeConfig
+{
+    using ADataType   = float;
+    using BDataType   = float;
+    using AccDataType = float;
+    using CDataType   = float;
+};
+
 template <>
 struct GemmTypeConfig
 {
@@ -446,7 +464,7 @@ inline auto create_args()
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
+        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index 78f3a9b0b3f..ef65b141529 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -109,7 +109,8 @@ template 
+          typename ComputeDataType = ADataType,
+          typename CDEElementWise  = ck_tile::element_wise::PassThrough>
 float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                   ck_tile::DeviceMem& b_k_n_dev_buf,
                   ck_tile::DeviceMem& c_m_n_dev_buf,
@@ -151,7 +152,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                                           DsLayout,
                                           CLayout,
                                           true,
-                                          CDEElementWise>(
+                                          CDEElementWise,
+                                          ComputeDataType>(
             args,
             ck_tile::stream_config{
                 nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
@@ -169,7 +171,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                                           DsLayout,
                                           CLayout,
                                           false,
-                                          CDEElementWise>(
+                                          CDEElementWise,
+                                          ComputeDataType>(
             args,
             ck_tile::stream_config{
                 nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
@@ -209,11 +212,12 @@ std::tuple inline parse_ge
 template 
+          typename BDataType       = ADataType,
+          typename CDataType       = ADataType,
+          typename ALayout         = ck_tile::tensor_layout::gemm::RowMajor,
+          typename BLayout         = ck_tile::tensor_layout::gemm::ColumnMajor,
+          typename CLayout         = ck_tile::tensor_layout::gemm::RowMajor,
+          typename ComputeDataType = ADataType>
 int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                   const ALayout a_layout                  = ALayout{},
                                   const BLayout b_layout                  = BLayout{},
@@ -349,21 +353,22 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                  ALayout,
                                  BLayout,
                                  ck_tile::tuple<>,
-                                 CLayout>(a_m_k_dev_buf,
-                                          b_k_n_dev_buf,
-                                          c_m_n_dev_buf,
-                                          M,
-                                          N,
-                                          K,
-                                          stride_A,
-                                          stride_B,
-                                          stride_C,
-                                          kbatch,
-                                          n_warmup,
-                                          n_repeat,
-                                          persistent,
-                                          flush_cache,
-                                          rotating_count);
+                                 CLayout,
+                                 ComputeDataType>(a_m_k_dev_buf,
+                                                  b_k_n_dev_buf,
+                                                  c_m_n_dev_buf,
+                                                  M,
+                                                  N,
+                                                  K,
+                                                  stride_A,
+                                                  stride_B,
+                                                  stride_C,
+                                                  kbatch,
+                                                  n_warmup,
+                                                  n_repeat,
+                                                  persistent,
+                                                  flush_cache,
+                                                  rotating_count);
 
     c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
 
@@ -393,7 +398,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
 
     if(arg_parser.get_int("v") == 1)
     {
-        ck_tile::reference_gemm(
+        ck_tile::reference_gemm(
             a_m_k, b_k_n, c_m_n_ref);
         const float max_accumulated_value =
             *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
@@ -427,7 +432,9 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                     CDataType,
                                     ALayout,
                                     BLayout,
-                                    CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
+                                    CLayout,
+                                    ComputeDataType>(
+            d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
 
         c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
 
diff --git a/example/ck_tile/03_gemm/run_gemm_example_common.hpp b/example/ck_tile/03_gemm/run_gemm_example_common.hpp
index e6a1c626e5a..e65328248ac 100644
--- a/example/ck_tile/03_gemm/run_gemm_example_common.hpp
+++ b/example/ck_tile/03_gemm/run_gemm_example_common.hpp
@@ -6,8 +6,9 @@
 template 
+          typename BPrecType       = APrecType,
+          typename CPrecType       = APrecType,
+          typename ComputeDataType = APrecType>
 int run_gemm_example_prec_type(std::string a_layout,
                                std::string b_layout,
                                ck_tile::ArgParser& arg_parser)
@@ -54,7 +55,11 @@ int run_gemm_example_prec_type(std::string a_layout,
                                                      Invoker,
                                                      APrecType,
                                                      BPrecType,
-                                                     CPrecType>(
+                                                     CPrecType,
+                                                     decltype(a_layout_type),
+                                                     decltype(b_layout_type),
+                                                     Row,
+                                                     ComputeDataType>(
                     arg_parser, a_layout_type, b_layout_type, Row{});
             }
         },
diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp
index b2bd6286851..e581cae0ab9 100644
--- a/include/ck_tile/core/numeric/numeric.hpp
+++ b/include/ck_tile/core/numeric/numeric.hpp
@@ -9,6 +9,8 @@
 
 namespace ck_tile {
 
+using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits
+
 // this struct has the information of
 // 1. limit of a certain type, simliar to std::numeric_limits
 // 2. some pre-defined value, zero, one...
@@ -101,6 +103,25 @@ struct numeric_traits
     using bitwise_type                  = uint32_t;
 };
 
+template <>
+struct numeric_traits
+{
+    static constexpr int exp            = 8;
+    static constexpr int mant           = 10;
+    static constexpr int bias           = 127;
+    static constexpr uint32_t nan_mask  = 0x7F800000;
+    static constexpr uint32_t head_mask = 0xFF800000;
+    static constexpr uint32_t mant_mask = 0x7FFFFF;
+    static constexpr uint32_t exp_mask  = 0xFF;
+    static constexpr uint32_t abs_mask  = 0x7FFFFFFF;
+    static constexpr uint32_t Inf       = 0x7F800000;
+    static constexpr uint32_t NegInf    = 0xFF800000;
+    static constexpr uint32_t NaN       = 0x7F800001;
+    static constexpr uint32_t Neg0      = 0x80000000;
+    static constexpr int PackedSize     = 1;
+    using bitwise_type                  = uint32_t;
+};
+
 } // namespace ck_tile
 
 #define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)                                       \
diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp
index deaa9e0bd90..b1a9ce6e58d 100644
--- a/include/ck_tile/core/numeric/type_convert.hpp
+++ b/include/ck_tile/core/numeric/type_convert.hpp
@@ -57,6 +57,33 @@ CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
 CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
 CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
 
+enum class tf32_rounding_mode
+{
+    truncate = 0,
+    standard = 1, // RTNE
+};
+
+template 
+CK_TILE_HOST_DEVICE constexpr float float_to_tf32(float x)
+{
+    uint32_t i = bit_cast(x);
+    if constexpr(rounding == tf32_rounding_mode::standard)
+    {
+        if((i & 0x7f800000) != 0x7f800000)
+        {
+            i += 0xfff + ((i >> 13) & 1);
+        }
+    }
+    i &= 0xFFFFE000u;
+    return bit_cast(i);
+}
+
+template , bool> = false>
+CK_TILE_HOST_DEVICE constexpr float type_convert(float x)
+{
+    return float_to_tf32(x);
+}
+
 CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
 CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
 CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp
index 9ad5af8264c..566dff7df9f 100644
--- a/include/ck_tile/host/reference/reference_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_gemm.hpp
@@ -8,6 +8,7 @@
 
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/host/device_prop.hpp"
 
 namespace ck_tile {
 
@@ -435,9 +436,10 @@ template 
+          typename ComputeDataType = ADataType,
+          typename AElementOp      = ck_tile::identity,
+          typename BElementOp      = ck_tile::identity,
+          typename ACCElementOp    = ck_tile::identity>
 CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
                                  const HostTensor& b_k_n,
                                  HostTensor& c_m_n,
@@ -449,6 +451,8 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
     const std::size_t N = b_k_n.get_length(1);
     const std::size_t K = a_m_k.get_length(1);
 
+    const std::string device_name = ck_tile::get_device_name();
+
     auto f_mn = [&](auto m, auto n) {
         AccDataType v_acc = 0;
 
@@ -482,7 +486,36 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k,
             {
                 v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n)));
             }
-            v_acc += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+                if(device_name == "gfx950")
+                {
+                    // gfx950: use 3x bf16 emulation
+                    bf16_t v_a_bf16_big   = ck_tile::type_convert(v_a);
+                    bf16_t v_a_bf16_small = ck_tile::type_convert(
+                        v_a - type_convert(v_a_bf16_big));
+                    bf16_t v_b_bf16_big   = ck_tile::type_convert(v_b);
+                    bf16_t v_b_bf16_small = ck_tile::type_convert(
+                        v_b - type_convert(v_b_bf16_big));
+
+                    v_acc += ck_tile::type_convert(v_a_bf16_big) *
+                                 ck_tile::type_convert(v_b_bf16_small) +
+                             ck_tile::type_convert(v_a_bf16_small) *
+                                 ck_tile::type_convert(v_b_bf16_big) +
+                             ck_tile::type_convert(v_a_bf16_big) *
+                                 ck_tile::type_convert(v_b_bf16_big);
+                }
+                else
+                {
+                    // Other architectures: tf32 not supported or handled via fp32 fallback
+                    v_acc += v_a * v_b;
+                }
+            }
+            else
+            {
+                v_acc += v_a * v_b;
+            }
         }
 
         c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc));
@@ -717,7 +750,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 __global__ void naive_gemm_kernel(ADataType* A,
                                   BDataType* B,
                                   CDataType* C,
@@ -789,7 +823,33 @@ __global__ void naive_gemm_kernel(ADataType* A,
             {
                 v_b = ck_tile::type_convert(B[b_index]);
             }
-            acc += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+#ifdef CK_GFX950_SUPPORT
+                // gfx950: use 3x bf16 emulation
+                bf16_t v_a_bf16_big = ck_tile::type_convert(v_a);
+                bf16_t v_a_bf16_small =
+                    ck_tile::type_convert(v_a - type_convert(v_a_bf16_big));
+                bf16_t v_b_bf16_big = ck_tile::type_convert(v_b);
+                bf16_t v_b_bf16_small =
+                    ck_tile::type_convert(v_b - type_convert(v_b_bf16_big));
+
+                acc += ck_tile::type_convert(v_a_bf16_big) *
+                           ck_tile::type_convert(v_b_bf16_small) +
+                       ck_tile::type_convert(v_a_bf16_small) *
+                           ck_tile::type_convert(v_b_bf16_big) +
+                       ck_tile::type_convert(v_a_bf16_big) *
+                           ck_tile::type_convert(v_b_bf16_big);
+#else
+                // Other architectures: use fp32 fallback
+                acc += v_a * v_b;
+#endif
+            }
+            else
+            {
+                acc += v_a * v_b;
+            }
         }
 
         int c_index = (std::is_same_v)
@@ -805,7 +865,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 __global__ void blockwise_gemm_kernel(ADataType* A,
                                       BDataType* B,
                                       CDataType* C,
@@ -902,7 +963,33 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
             {
                 v_b = ck_tile::type_convert(B[b_index]);
             }
-            acc_temp += v_a * v_b;
+
+            if constexpr(std::is_same_v)
+            {
+#ifdef CK_GFX950_SUPPORT
+                // gfx950: use 3x bf16 emulation
+                bf16_t v_a_bf16_big = ck_tile::type_convert(v_a);
+                bf16_t v_a_bf16_small =
+                    ck_tile::type_convert(v_a - type_convert(v_a_bf16_big));
+                bf16_t v_b_bf16_big = ck_tile::type_convert(v_b);
+                bf16_t v_b_bf16_small =
+                    ck_tile::type_convert(v_b - type_convert(v_b_bf16_big));
+
+                acc_temp += ck_tile::type_convert(v_a_bf16_big) *
+                                ck_tile::type_convert(v_b_bf16_small) +
+                            ck_tile::type_convert(v_a_bf16_small) *
+                                ck_tile::type_convert(v_b_bf16_big) +
+                            ck_tile::type_convert(v_a_bf16_big) *
+                                ck_tile::type_convert(v_b_bf16_big);
+#else
+                // Other architectures: use fp32 fallback
+                acc_temp += v_a * v_b;
+#endif
+            }
+            else
+            {
+                acc_temp += v_a * v_b;
+            }
         }
         // final accumulation
         acc += acc_temp * scale_A * scale_B;
@@ -920,7 +1007,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 void reference_gemm_gpu(ADataType* a_ptr,
                         BDataType* b_ptr,
                         CDataType* c_ptr,
@@ -935,9 +1023,15 @@ void reference_gemm_gpu(ADataType* a_ptr,
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
-    naive_gemm_kernel
-        <<>>(
-            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
+    naive_gemm_kernel<<>>(
+        a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
 
     return;
 }
@@ -948,7 +1042,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 void reference_blockwise_gemm_gpu(ADataType* a_ptr,
                                   BDataType* b_ptr,
                                   CDataType* c_ptr,
@@ -968,21 +1063,27 @@ void reference_blockwise_gemm_gpu(ADataType* a_ptr,
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
-    blockwise_gemm_kernel
-        <<>>(a_ptr,
-                                            b_ptr,
-                                            c_ptr,
-                                            M,
-                                            N,
-                                            K,
-                                            stride_a,
-                                            stride_b,
-                                            stride_c,
-                                            scale_granularity_m,
-                                            scale_granularity_n,
-                                            scale_granularity_k,
-                                            scale_A_ptr,
-                                            scale_B_ptr);
+    blockwise_gemm_kernel<<>>(a_ptr,
+                                                                              b_ptr,
+                                                                              c_ptr,
+                                                                              M,
+                                                                              N,
+                                                                              K,
+                                                                              stride_a,
+                                                                              stride_b,
+                                                                              stride_c,
+                                                                              scale_granularity_m,
+                                                                              scale_granularity_n,
+                                                                              scale_granularity_k,
+                                                                              scale_A_ptr,
+                                                                              scale_B_ptr);
 
     return;
 }
@@ -993,7 +1094,8 @@ template 
+          typename LayoutC,
+          typename ComputeDataType = ADataType>
 void reference_batched_gemm_gpu(ADataType* a_ptr,
                                 BDataType* b_ptr,
                                 CDataType* c_ptr,
@@ -1017,9 +1119,15 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
         ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
         BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
         CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
-        naive_gemm_kernel
-            <<>>(
-                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
+        naive_gemm_kernel<<>>(
+            d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
     }
 
     return;
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
index 6199142d986..1ecc616a90c 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
@@ -157,17 +157,18 @@ struct UniversalGemmBasePolicy
                 constexpr auto KThreadWrite     = TileEncodingPattern::Y0 * TileEncodingPattern::Y1;
                 constexpr auto K0PerThreadWrite = AK0 / KThreadWrite;
                 constexpr auto KThreadRead      = get_warp_size() / MPerXdl;
-                constexpr auto K0PerThreadRead  = AK0 / KThreadRead;
+                constexpr auto K0PerThreadRead  = (AK0 / KThreadRead) > 0 ? (AK0 / KThreadRead) : 1;
 
                 // check if we exceed all LDS banks
-                constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
-                constexpr auto kfold         = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
-                                                   ? 1
-                                                   : LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
+                constexpr auto LdsBanksWidth  = get_n_lds_banks() * get_n_words_per_128b();
+                constexpr auto kfold          = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth ||
+                                        (AK1 * M0 * sizeof(ADataType)) == 0)
+                                                    ? 1
+                                                    : LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
+                constexpr auto divisor        = (kfold * K0PerThreadWrite / K0PerThreadRead);
+                constexpr auto divisor_to_use = divisor > 0 ? divisor : 1;
                 constexpr auto KThreadReadPerm =
-                    (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
-                        ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
-                        : KThreadRead;
+                    (divisor > 1) ? KThreadRead / divisor_to_use : KThreadRead;
 
                 // 1<=mpair<=n0
                 constexpr auto mpair =
@@ -354,17 +355,18 @@ struct UniversalGemmBasePolicy
                 constexpr auto KThreadWrite     = TileEncodingPattern::Y0 * TileEncodingPattern::Y1;
                 constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
                 constexpr auto KThreadRead      = get_warp_size() / NPerXdl;
-                constexpr auto K0PerThreadRead  = BK0 / KThreadRead;
+                constexpr auto K0PerThreadRead  = (BK0 / KThreadRead) > 0 ? (BK0 / KThreadRead) : 1;
 
                 // check if we exceed all LDS banks
-                constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
-                constexpr auto kfold         = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
-                                                   ? 1
-                                                   : LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
+                constexpr auto LdsBanksWidth  = get_n_lds_banks() * get_n_words_per_128b();
+                constexpr auto kfold          = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth ||
+                                        (BK1 * N0 * sizeof(BDataType)) == 0)
+                                                    ? 1
+                                                    : LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
+                constexpr auto divisor        = (kfold * K0PerThreadWrite / K0PerThreadRead);
+                constexpr auto divisor_to_use = divisor > 0 ? divisor : 1;
                 constexpr auto KThreadReadPerm =
-                    (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
-                        ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
-                        : KThreadRead;
+                    (divisor > 1) ? KThreadRead / divisor_to_use : KThreadRead;
 
                 // 1<=npair<=n0
                 constexpr auto npair =
@@ -895,8 +897,10 @@ struct UniversalGemmPipelineAgBgCrPolicy
             : vector_size * 4 == thread_elements              ? WGAttrNumAccessEnum::Quad
                                                               : WGAttrNumAccessEnum::Invalid;
 
-        using ADataType = remove_cvref_t;
-        using BDataType = remove_cvref_t;
+        using ADataType       = remove_cvref_t;
+        using BDataType       = remove_cvref_t;
+        using ComputeDataType = remove_cvref_t;
+
         using ATypeToUse =
             std::conditional_t, BDataType, ADataType>;
         using BTypeToUse = std::conditional_t ||
@@ -904,8 +908,13 @@ struct UniversalGemmPipelineAgBgCrPolicy
                                               ADataType,
                                               BDataType>;
 
-        using WarpGemm = WarpGemmDispatcher, tf32_t, ATypeToUse>;
+        using BTypeForDispatcher =
+            std::conditional_t, tf32_t, BTypeToUse>;
+
+        using WarpGemm = WarpGemmDispatcher>;
 
+// tf32
+// On gfx950: uses 3x bf16 MFMA emulation (no native xf32 support)
+
+#if defined(CK_GFX950_SUPPORT)
+// gfx950: tf32 emulated using 3x bf16 MFMA
+using WarpGemmMfmaTf32Tf32F32M32N32K16Native = WarpGemmImpl>>;
+
+using WarpGemmMfmaTf32Tf32F32M16N16K32Native = WarpGemmImpl>>;
+
+template 
+using WarpGemmMfmaTf32Tf32F32M16N16K16 = WarpGemmImpl,
+    AttrNumAccess>>;
+
+template 
+using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl,
+    AttrNumAccess>>;
+#endif
+
 // fp16
 
 using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
index bd65f533839..982cb499672 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
@@ -61,6 +61,19 @@ enum class WGAttrCtlEnum
         DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a")     \
     }
 
+// Helper function to convert float to bf16 pairs for tf32 emulation on gfx950
+// This is used to simulate tf32 using 3x bf16 MFMA: big*big + small*big + big*small
+template 
+CK_TILE_DEVICE void convert_float_to_bf16_pairs(const thread_buffer& reg_f32,
+                                                thread_buffer& reg_bf16_big,
+                                                thread_buffer& reg_bf16_small)
+{
+    static_for<0, VecSize, 1>{}([&](auto k) {
+        reg_bf16_big(k)   = type_convert(reg_f32[k]);
+        reg_bf16_small(k) = type_convert(reg_f32[k] - type_convert(reg_bf16_big[k]));
+    });
+}
+
 // F32
 template 
 struct WarpGemmAttributeMfmaImplF32F32F32M16N16K4
@@ -190,6 +203,289 @@ struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2
     }
 };
 
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M32N32K4Tf32
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 32;
+    static constexpr index_t kN = 32;
+    static constexpr index_t kK = 4;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 32;
+    static constexpr index_t kBNLane     = 32;
+    static constexpr index_t kABKLane    = 2;
+    static constexpr index_t kABKPerLane = 2;
+
+    static constexpr index_t kCMLane     = 2;
+    static constexpr index_t kCNLane     = 32;
+    static constexpr index_t kCM0PerLane = 4;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x4_xf32", Ctrl)
+        else
+        {
+#if defined(__gfx942__)
+            c_vec = __builtin_amdgcn_mfma_f32_32x32x4_xf32(a_vec, b_vec, c_vec, 0, 0, 0);
+#else
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
+#endif
+        }
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+#if defined(__gfx942__)
+        return bit_cast(
+            __builtin_amdgcn_mfma_f32_32x32x4_xf32(a_vec, b_vec, CVecType{0.f}, 0, 0, 0));
+#else
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+        return CVecType{0.f};
+#endif
+    }
+};
+
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M16N16K8Tf32
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 16;
+    static constexpr index_t kN = 16;
+    static constexpr index_t kK = 8;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 16;
+    static constexpr index_t kBNLane     = 16;
+    static constexpr index_t kABKLane    = 4;
+    static constexpr index_t kABKPerLane = 2;
+
+    static constexpr index_t kCMLane     = 4;
+    static constexpr index_t kCNLane     = 16;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x8_xf32", Ctrl)
+        else
+        {
+#if defined(__gfx942__)
+            c_vec = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_vec, b_vec, c_vec, 0, 0, 0);
+#else
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
+#endif
+        }
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+#if defined(__gfx942__)
+        return bit_cast(
+            __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_vec, b_vec, CVecType{0.f}, 0, 0, 0));
+#else
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+        return CVecType{0.f};
+#endif
+    }
+};
+
+// tf32/xf32 emulation on gfx950 using 3x bf16 MFMA
+// Algorithm: split float into bf16_big and bf16_small, then compute:
+//   out = A_big * B_big + A_small * B_big + A_big * B_small
+// This provides tf32-like precision using bf16 hardware
+
+// V_MFMA_F32_32x32x16_XF32 emulated on gfx950 using 3x bf16 32x32x16
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    // Input: 8 floats for K=16 (each lane holds 8 elements, kABKPerLane=8)
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 32;
+    static constexpr index_t kN = 32;
+    static constexpr index_t kK = 16;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 32;
+    static constexpr index_t kBNLane     = 32;
+    static constexpr index_t kABKLane    = 2;
+    static constexpr index_t kABKPerLane = 8;
+
+    static constexpr index_t kCMLane     = 2;
+    static constexpr index_t kCNLane     = 32;
+    static constexpr index_t kCM0PerLane = 4;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+#if defined(__gfx950__)
+        // Convert ext_vector to thread_buffer for element access
+        const auto& a_f32 = reinterpret_cast&>(a_vec);
+        const auto& b_f32 = reinterpret_cast&>(b_vec);
+
+        // Convert float to bf16 pairs
+        thread_buffer a_bf16_big, a_bf16_small, b_bf16_big, b_bf16_small;
+        convert_float_to_bf16_pairs(a_f32, a_bf16_big, a_bf16_small);
+        convert_float_to_bf16_pairs(b_f32, b_bf16_big, b_bf16_small);
+
+        // Get bf16x8 vectors for MFMA
+        auto a_big   = bit_cast>(a_bf16_big);
+        auto a_small = bit_cast>(a_bf16_small);
+        auto b_big   = bit_cast>(b_bf16_big);
+        auto b_small = bit_cast>(b_bf16_small);
+
+        // Run 3 bf16 MFMAs: small*big, big*small, big*big
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_small, b_big, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_big, b_small, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_big, b_big, c_vec, 0, 0, 0);
+#else
+        ck_tile::ignore = c_vec;
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+#endif
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0.f};
+        (*this)(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
+// V_MFMA_F32_16x16x32_XF32 emulated on gfx950 using 3x bf16 16x16x32
+template 
+struct WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    // Input: 8 floats for K=32 (each lane holds 8 elements, kABKPerLane=8)
+    using AVecType = ext_vector_t;
+    using BVecType = ext_vector_t;
+    using CVecType = ext_vector_t;
+
+    static constexpr index_t kM = 16;
+    static constexpr index_t kN = 16;
+    static constexpr index_t kK = 32;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 16;
+    static constexpr index_t kBNLane     = 16;
+    static constexpr index_t kABKLane    = 4;
+    static constexpr index_t kABKPerLane = 8;
+
+    static constexpr index_t kCMLane     = 4;
+    static constexpr index_t kCNLane     = 16;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec
+    template 
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant = {}) const
+    {
+#if defined(__gfx950__)
+        // Convert ext_vector to thread_buffer for element access
+        const auto& a_f32 = reinterpret_cast&>(a_vec);
+        const auto& b_f32 = reinterpret_cast&>(b_vec);
+
+        // Convert float to bf16 pairs
+        thread_buffer a_bf16_big, a_bf16_small, b_bf16_big, b_bf16_small;
+        convert_float_to_bf16_pairs(a_f32, a_bf16_big, a_bf16_small);
+        convert_float_to_bf16_pairs(b_f32, b_bf16_big, b_bf16_small);
+
+        // Get bf16x8 vectors for MFMA
+        auto a_big   = bit_cast>(a_bf16_big);
+        auto a_small = bit_cast>(a_bf16_small);
+        auto b_big   = bit_cast>(b_bf16_big);
+        auto b_small = bit_cast>(b_bf16_small);
+
+        // Run 3 bf16 MFMAs: small*big, big*small, big*big
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_small, b_big, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_big, b_small, c_vec, 0, 0, 0);
+        c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_big, b_big, c_vec, 0, 0, 0);
+#else
+        ck_tile::ignore = c_vec;
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
+#endif
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0.f};
+        (*this)(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
 // V_MFMA_F32_16x16x32_BF16
 template 
 struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index d6c21e88b56..b1ad89cc9fd 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -35,6 +35,27 @@ struct Dispatcher;
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
 template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution; };
+// tf32
+// On gfx950: uses 3x bf16 MFMA emulation (no native xf32 support)
+#if defined(CK_GFX950_SUPPORT)
+// For tf32 on gfx950: epilogue uses float types but 32x32x16 tile
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+// On gfx950, tf32 uses bf16 emulation
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M16N16K16; };
+template<> struct Dispatcher { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16; };
+#endif
+// Note: For gfx11/gfx12 and other architectures that don't support tf32,
+// these dispatchers are not defined. Code using tf32 should be guarded
+// by CK_ENABLE_TF32 or CK_GFX950_SUPPORT macros.
 // fp16
 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
diff --git a/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
index 899221547f6..a1b43460c17 100644
--- a/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
+++ b/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
@@ -20,6 +20,12 @@ struct DataTypeTraits
     static constexpr const char* name = "fp32";
 };
 
+template <>
+struct DataTypeTraits
+{
+    static constexpr const char* name = "tf32";
+};
+
 template <>
 struct DataTypeTraits
 {