From 81187b7bc9915ae867175c684f126709e99b7421 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 21 Apr 2026 18:21:12 -0700 Subject: [PATCH] [CUDA] Separate main loop into a function in qmm --- mlx/backend/cuda/quantized/qmm/qmm_naive.cu | 419 +++---------------- mlx/backend/cuda/quantized/qmm/qmm_naive.cuh | 381 +++++++++++++++++ mlx/backend/cuda/quantized/qmm/qmm_sm80.cu | 391 +++-------------- mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh | 346 +++++++++++++++ 4 files changed, 842 insertions(+), 695 deletions(-) create mode 100644 mlx/backend/cuda/quantized/qmm/qmm_naive.cuh create mode 100644 mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh diff --git a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu index 5be75bd09f..0723b94930 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu @@ -1,9 +1,8 @@ // Copyright © 2026 Apple Inc. #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" -#include "mlx/dtype_utils.h" +#include "mlx/backend/cuda/quantized/qmm/qmm_naive.cuh" // clang-format off @@ -12,49 +11,26 @@ namespace cutlass_gemm { using namespace cute; -template -struct SharedStorage { - ArrayEngine> A; - ArrayEngine> B; -}; - -__device__ __forceinline__ void -cute_naive_dequant(auto w, auto s, auto z, auto out) { - using Element = typename decltype(out)::value_type; - using Quant = typename decltype(w)::value_type; - using Scale = typename decltype(s)::value_type; - transform(w, out, [](Quant q) { return Element(q); } ); - transform(out, s, out, [](Element e, Scale s) { return e * Element(s); }); - if constexpr (quant_has_bias_v) { - transform(out, z, out, plus{}); - } -} - -__device__ __forceinline__ void -cute_dequant(auto w, auto s, auto z, auto out) { - if constexpr (stride(coalesce(w.layout())) == Int<1>{} && - is_static_v) { - cute_vectorized_dequant(w, s, z, out); - } else { - cute_naive_dequant(w, s, z, out); - } -} - -template -__global__ void qmm_naive_kernel( - ProblemShape shape_MNKL, CtaTiler cta_tiler, - const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA copy_a, - const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB copy_b, - Element* C, StrideC dC, + typename ProblemShape, + typename CtaTiler, + typename StrideA, + typename StrideB, + typename LayoutS, + typename StrideC, + typename TiledMma> +__global__ +__launch_bounds__(decltype(size(TiledMma{}))::value) +void qmm_naive_kernel( + ProblemShape shape_MNKL, + CtaTiler cta_tiler, + const Element* A, StrideA dA, + const Quant* B, StrideB dB, const Scale* S, const Element* Z, LayoutS S_layout, const uint32_t* lhs_indices, const uint32_t* rhs_indices, + Element* C, StrideC dC, TiledMma mma) { - CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); - CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA)); CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB)); CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC)); @@ -62,20 +38,6 @@ __global__ void qmm_naive_kernel( int thread_idx = int(threadIdx.x); auto [m_coord, n_coord, l_coord] = static_cast(blockIdx); - auto m_max_coord = size<0>(shape_MNKL) - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord - auto n_max_coord = size<1>(shape_MNKL) - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord - - // Shift tensor so we handle residue of K in the 0th tile. 
- auto shape_K = size<2>(shape_MNKL); - auto bK = size<2>(cta_tiler); - auto k_residue = shape_K - bK * ceil_div(shape_K, bK); - if constexpr (HasKResidue) { - A += k_residue * get<1>(dA); - B += k_residue * get<1>(dB) * cuda::std::min(8, sizeof_bits_v) / 8; - S += k_residue * stride<1>(S_layout); - Z += k_residue * stride<1>(S_layout); - } - // Represent the full tensors. Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L) Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L) @@ -105,218 +67,24 @@ __global__ void qmm_naive_kernel( Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - // Shared memory buffers. - extern __shared__ char shared_memory[]; - using SharedStorage = SharedStorage; - SharedStorage& smem = *reinterpret_cast(shared_memory); - Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K) - Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K) - - // Partition the copying of A/B/C tiles across the threads. - ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); - Tensor tAgA = thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) - Tensor tAsA = thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) - Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) - - ThrCopy thr_copy_b = copy_b.get_slice(thread_idx); - Tensor tBgB = thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) - Tensor tBsB = thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) - Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_M,BCPY_K) - Tensor tBrB_dq = make_fragment_like(tBsB); // (BCPY,BCPY_M,BCPY_K) - Tensor tBgS = thr_copy_b.partition_S(gS); // (BCPY,BCPY_N,BCPY_K,k) - Tensor tBrS = make_fragment_like(tBgS(_,_,_,0)); // (BCPY,BCPY_N,BCPY_K) - Tensor tBgZ = thr_copy_b.partition_S(gZ); // (BCPY,BCPY_N,BCPY_K,k) - Tensor tBrZ = make_fragment_like(tBgZ(_,_,_,0)); // (BCPY,BCPY_N,BCPY_K) - - // MMA. - ThrMMA thr_mma = mma.get_slice(thread_idx); - Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) - Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) - Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) - Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) - - // Predicates for m/n bounds. - Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); // (CPY_M,CPY_K) - Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); // (CPY_N,CPY_K) - Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) - Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) - Tensor cC = make_identity_tensor(make_shape(size<0>(gC), size<1>(gC))); // (M,N) - Tensor tAcA = thr_copy_a.partition_S(cA); // (CPY,CPY_M,CPY_K) - Tensor tBcB = thr_copy_b.partition_S(cB); // (CPY,CPY_N,CPY_K) - Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) - CUTE_UNROLL - for (int m = 0; m < size<0>(tApA); ++m) { - tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord; - } - CUTE_UNROLL - for (int n = 0; n < size<0>(tBpB); ++n) { - tBpB(n,0) = get<0>(tBcB(0,n,0)) < n_max_coord; - } - - // GMEM => RMEM. - auto fetch_gmem = [&](int tile) { - copy_if(copy_a, tApA, tAgA(_,_,_,tile), tArA); - copy_if(copy_b, tBpB, tBgB(_,_,_,tile), tBrB); - copy(tBgS(_,_,_,tile), tBrS); - copy(tBgZ(_,_,_,tile), tBrZ); - }; - // RMEM => SMEM. 
- auto store_smem = [&]() { - __syncthreads(); - copy(tArA, tAsA); - CUTE_UNROLL - for (int k = 0; k < size<2>(tBrB); ++k) { - CUTE_UNROLL - for (int n = 0; n < size<1>(tBrB); ++n) { - cute_dequant(tBrB(_,n,k), tBrS(_,n,k), tBrZ(_,n,k), tBrB_dq(_,n,k)); - } - } - copy(tBrB_dq, tBsB); - __syncthreads(); - }; - - // Clear the rmem tiles to account for predicated off loads. - if constexpr (HasKResidue) { - clear(tArA); - clear(tBrB); - clear(tBrS); - clear(tBrZ); - } - - // Prefetch first tile. - if constexpr (HasKResidue) { - Tensor tAgA_k = tAgA(_,_,_,0); - CUTE_UNROLL - for (int k = 0; k < size<2>(tArA); ++k) { - if (get<1>(tAcA(0,0,k)) >= -k_residue) { - copy_if(copy_a, tApA(_,k), tAgA_k(_,_,k), tArA(_,_,k)); - } - } - Tensor tBgB_k = tBgB(_,_,_,0); - Tensor tBgS_k = tBgS(_,_,_,0); - Tensor tBgZ_k = tBgZ(_,_,_,0); - CUTE_UNROLL - for (int k = 0; k < size<2>(tBrB); ++k) { - if (get<1>(tBcB(0,0,k)) >= -k_residue) { - copy_if(copy_b, tBpB(_,k), tBgB_k(_,_,k), tBrB(_,_,k)); - copy(tBgS_k(_,_,k), tBrS(_,_,k)); - copy(tBgZ_k(_,_,k), tBrZ(_,_,k)); - } - } - } else { - fetch_gmem(0); - } - - // Clear accumulators. - clear(tCrC); - - // Loop over CTA tiles. - auto K_TILE_MAX = size<3>(tAgA); - for (int tile = 0; tile < K_TILE_MAX; ++tile) { - store_smem(); - if constexpr (HasKResidue) { - // Avoid fetching full 0th-tile when there is residue. - if (K_TILE_MAX > 1) { - fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); - } - } else { - fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); - } - gemm(mma, tCsA, tCsB, tCrC); - } - - // Epilogue. - CUTE_UNROLL - for (int i = 0; i < size(tCrC); ++i) { - if ((get<0>(tCcC(i)) < m_max_coord) && (get<1>(tCcC(i)) < n_max_coord)) { - tCgC(i) = Element(tCrC(i)); - } - } -} - -template -inline constexpr auto make_matrix_stride(auto m, auto k) { - if constexpr (KMajor) { - return cute::make_stride(k, cute::Int<1>{}, m * k); - } else { - return cute::make_stride(cute::Int<1>{}, m, m * k); - } -} - -template -inline constexpr auto make_smem_layout(auto bM, auto bK) { - // TODO: Calculate swizzle based on tile shape. - if constexpr (KMajor) { - auto swizzle = composition(Swizzle<3,3,3>{}, - Layout>, - Stride<_8,Stride<_1,_64>>>{}); - return tile_to_shape(swizzle, make_shape(bM, bK)); - } else { - auto swizzle = composition(Swizzle<3,3,3>{}, - Layout, Stride<_1,_64>>{}); - return tile_to_shape(swizzle, make_shape(bM, bK)); - } -} - -template -inline constexpr auto make_tiled_mma() { - using Atom = std::conditional_t< - SM80, - std::conditional_t< - std::is_same_v, - SM80_16x8x16_F32F16F16F32_TN, - std::conditional_t< - std::is_same_v, - SM80_16x8x16_F32BF16BF16F32_TN, - UniversalFMA - > - >, - UniversalFMA>; - if constexpr (!SM80 || std::is_same_v) { - return make_tiled_mma(Atom{}, Layout>{}); - } else { - if constexpr (TileM >= 32) { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); - } else { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); - } - } -} - -template -inline auto make_tiled_copy(auto num_threads, auto bM, auto bK) { - // TODO: Only do 1-element read for the tile of residue. 
- auto n_read = Int{}; - auto atom = Copy_Atom>>, T>{}; - if constexpr (KMajor) { - auto k_threads = bK / n_read; - return make_tiled_copy( - atom, - make_layout(make_shape(Int{}, k_threads), LayoutRight{}), - make_layout(make_shape(Int<1>{}, n_read))); - } else { - auto m_threads = bM / n_read; - return make_tiled_copy( - atom, - make_layout(make_shape(m_threads, Int{}), LayoutLeft{}), - make_layout(make_shape(n_read, Int<1>{}))); - } + // Compute tile residues for predication. + int m_max_coord = size<0>(shape_MNKL) - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord + int n_max_coord = size<1>(shape_MNKL) - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord + int k_residue = size<2>(shape_MNKL) - size<1>(gA) * size<2>(gA); + + qmm_naive_mainloop( + cta_tiler, + gA, + gB, + gS, + gZ, + gC, + mma, + m_max_coord, n_max_coord, k_residue, + thread_idx); } -template -inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size) { - if constexpr (KMajor) { - return make_layout( - make_shape(n, make_shape(group_size, k / group_size), l), - make_stride(k / group_size, Stride<_0,_1>{}, n * k / group_size)); - } else { - return make_layout( - make_shape(make_shape(group_size, n / group_size), k, l), - make_stride(Stride<_0,_1>{}, n / group_size, n * k / group_size)); - } -} - -template void qmm_naive( const Element* A, @@ -331,14 +99,12 @@ void qmm_naive( auto group_size, auto&& launch_kernel) { // Define shapes (dynamic). - auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L) + auto shape_MNKL = make_shape(m, n, k, l); // (M,N,K,L) - // Define TN strides (mixed). + // Define layouts (mixed). auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL) auto dB = make_matrix_stride(n, k); // (dN,dK,dL) auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) - - // Define layout of scales/biases (mixed). auto S_layout = make_scales_layout(n, k, l, group_size); // Handle broadcasting. @@ -347,45 +113,41 @@ void qmm_naive( get<2>(stride(S_layout)) = 0; } - // Define CTA tile sizes (static). - auto bM = Int{}; - auto bN = Int<(!SM80 && group_size > 64) ? 64 : 128>{}; - auto bK = Int{}; - auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M,BLK_N,BLK_K) + // Define CTA tile size (static). + auto cta_tiler = make_cta_tiler(group_size); // Define MMA. - TiledMMA mma = make_tiled_mma(); + auto mma = make_tiled_mma(cta_tiler); auto num_threads = size(mma); - // Define the A/B smem layouts (static). - auto sA_layout = make_smem_layout(bM, bK); - auto sB_layout = make_smem_layout(bN, bK); - - // Atoms. - TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); - TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); + // Shared memory size. + auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); + size_t smem_bytes = sizeof(SharedStorage); auto* kernel = &qmm_naive_kernel< - HasKResidue, decltype(prob_shape), decltype(cta_tiler), + KMajor, HasKResidue, SM80, Element, Quant, Scale, - decltype(dA), decltype(sA_layout), decltype(copy_a), - decltype(dB), decltype(sB_layout), decltype(copy_b), - decltype(dC), decltype(S_layout), decltype(mma)>; - - // Set L1 to be SMEM only. 
- size_t smem_bytes = sizeof(SharedStorage); + decltype(shape_MNKL), + decltype(cta_tiler), + decltype(dA), + decltype(dB), + decltype(S_layout), + decltype(dC), + decltype(mma)>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100); - dim3 num_blocks(size(ceil_div(m, bM)), size(ceil_div(n, bN)), l); - dim3 block_dims(num_threads); + dim3 num_blocks{uint32_t(ceil_div(m, size<0>(cta_tiler))), + uint32_t(ceil_div(n, size<1>(cta_tiler))), + uint32_t(l)}; + dim3 block_dims{num_threads}; void* args[] = { - &prob_shape, &cta_tiler, - &A, &dA, &sA_layout, ©_a, - &B, &dB, &sB_layout, ©_b, - &C, &dC, + &shape_MNKL, + &cta_tiler, + &A, &dA, + &B, &dB, &S, &Z, &S_layout, &lhs_indices, &rhs_indices, + &C, &dC, &mma}; launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); } @@ -396,69 +158,6 @@ void qmm_naive( namespace mlx::core { -template -inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { - if (dtype == float32) { - f.template operator()(); - } else if (dtype == float16) { - f.template operator()(); - } else if (dtype == bfloat16) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype))); - } -} - -template -inline void dispatch_groups(int group_size, const char* tag, F&& f) { - if (group_size == 32) { - f.template operator()<32>(); - } else if (group_size == 64) { - f.template operator()<64>(); - } else if (group_size == 128) { - f.template operator()<128>(); - } else { - throw std::invalid_argument( - fmt::format("{} Group size {} is not supported.", tag, group_size)); - } -} - -template -inline void dispatch_quant_types( - int bits, - int group_size, - QuantizationMode mode, - const char* tag, - F&& f) { - if (mode == QuantizationMode::Mxfp4) { - f.template operator()(); - } else if (mode == QuantizationMode::Mxfp8) { - f.template operator()(); - } else if (mode == QuantizationMode::Nvfp4) { - f.template operator()(); - } else { - dispatch_groups(group_size, tag, [&]() { - if (bits == 2) { - f.template operator()(); - } else if (bits == 3) { - f.template operator()(); - } else if (bits == 4) { - f.template operator()(); - } else if (bits == 5) { - f.template operator()(); - } else if (bits == 6) { - f.template operator()(); - } else if (bits == 8) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} {}-bit quantization is not supported.", tag, bits)); - } - }); - } -} - template void qmm_naive_impl( const array& x, @@ -516,7 +215,7 @@ void qmm_naive_impl( [&](auto* kernel, dim3 num_blocks, dim3 block_dims, - uint32_t smem_bytes, + size_t smem_bytes, void** args) { encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, smem_bytes, args); diff --git a/mlx/backend/cuda/quantized/qmm/qmm_naive.cuh b/mlx/backend/cuda/quantized/qmm/qmm_naive.cuh new file mode 100644 index 0000000000..d27097bcfc --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_naive.cuh @@ -0,0 +1,381 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" +#include "mlx/dtype_utils.h" + +// clang-format off + +// We can't put kernel code in mlx::core due to name conflicts of "Shape". 
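+// Everything device-side (smem layouts, tiled copies, dequant helpers and the
+// mainloop) therefore lives in cutlass_gemm below, while the host-side
+// dispatch helpers at the end of this header stay in mlx::core.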
+namespace cutlass_gemm { + +using namespace cute; + +template +struct SharedStorage { + ArrayEngine> A; + ArrayEngine> B; +}; + +template +inline constexpr auto make_smem_layout(auto bM, auto bK) { + // TODO: Calculate swizzle based on tile shape. + if constexpr (KMajor) { + auto swizzle = composition(Swizzle<3,3,3>{}, + Layout>, + Stride<_8,Stride<_1,_64>>>{}); + return tile_to_shape(swizzle, make_shape(bM, bK)); + } else { + auto swizzle = composition(Swizzle<3,3,3>{}, + Layout, Stride<_1,_64>>{}); + return tile_to_shape(swizzle, make_shape(bM, bK)); + } +} + +template +inline constexpr auto make_smem_layouts(auto cta_tiler) { + auto [bM, bN, bK] = cta_tiler; + auto sA_layout = make_smem_layout(bM, bK); + auto sB_layout = make_smem_layout(bN, bK); + return std::make_tuple(sA_layout, sB_layout); +} + +template +inline constexpr auto make_tiled_copy(auto num_threads, auto bM, auto bK) { + // TODO: Only do 1-element read for the tile of residue. + auto n_read = Int{}; + auto atom = Copy_Atom>>, T>{}; + if constexpr (KMajor) { + auto k_threads = bK / n_read; + return make_tiled_copy( + atom, + make_layout(make_shape(Int{}, k_threads), LayoutRight{}), + make_layout(make_shape(Int<1>{}, n_read))); + } else { + auto m_threads = bM / n_read; + return make_tiled_copy( + atom, + make_layout(make_shape(m_threads, Int{}), LayoutLeft{}), + make_layout(make_shape(n_read, Int<1>{}))); + } +} + + +__device__ __forceinline__ void +cute_naive_dequant(auto w, auto s, auto z, auto out) { + using Element = typename decltype(out)::value_type; + using Quant = typename decltype(w)::value_type; + using Scale = typename decltype(s)::value_type; + transform(w, out, [](Quant q) { return Element(q); } ); + transform(out, s, out, [](Element e, Scale s) { return e * Element(s); }); + if constexpr (quant_has_bias_v) { + transform(out, z, out, plus{}); + } +} + +__device__ __forceinline__ void +cute_dequant(auto w, auto s, auto z, auto out) { + if constexpr (stride(coalesce(w.layout())) == Int<1>{} && + is_static_v) { + cute_vectorized_dequant(w, s, z, out); + } else { + cute_naive_dequant(w, s, z, out); + } +} + +template +CUTE_DEVICE void qmm_naive_mainloop( + CtaTiler cta_tiler, + TensorA gA, + TensorB gB, + TensorS gS, + TensorZ gZ, + TensorC gC, + TiledMma mma, + int m_max_coord, + int n_max_coord, + int k_residue, + int thread_idx) { + // Get the types of operands. + using Element = decltype(gA)::value_type; + using Quant = decltype(gB)::value_type; + + // Shift tensor so we handle residue of K in the 0th tile. + gA = domain_offset(make_coord(0, k_residue, 0), gA); + if constexpr (sizeof_bits_v % 8 == 0) { + gB = domain_offset(make_coord(0, k_residue, 0), gB); + } else { + gB.data() = recast_ptr(raw_pointer_cast(gB.data()) + gB.layout()(0, k_residue, 0) * cuda::std::min(8, sizeof_bits_v) / 8); + } + gS = domain_offset(make_coord(0, k_residue, 0), gS); + gZ = domain_offset(make_coord(0, k_residue, 0), gZ); + + // Define smem layouts. + auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); + + // Shared memory buffer. + extern __shared__ char smem_buf[]; + using SharedStorage = SharedStorage; + SharedStorage& smem = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K) + Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K) + + // Define copy atoms. 
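+  // The A/B tiled copies are rebuilt here from the MMA's thread count and the
+  // CTA tile shape, rather than being passed into the kernel as template
+  // parameters as the pre-refactor kernel did.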
+ auto num_threads = size(mma); + auto [bM, bN, bK] = cta_tiler; + TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); + TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); + + // Partition the copying of A/B/C tiles across the threads. + ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); + Tensor tAgA = thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) + Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) + + ThrCopy thr_copy_b = copy_b.get_slice(thread_idx); + Tensor tBgB = thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) + Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_M,BCPY_K) + Tensor tBrB_dq = make_fragment_like(tBsB); // (BCPY,BCPY_M,BCPY_K) + Tensor tBgS = thr_copy_b.partition_S(gS); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBrS = make_fragment_like(tBgS(_,_,_,0)); // (BCPY,BCPY_N,BCPY_K) + Tensor tBgZ = thr_copy_b.partition_S(gZ); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBrZ = make_fragment_like(tBgZ(_,_,_,0)); // (BCPY,BCPY_N,BCPY_K) + + // MMA. + ThrMMA thr_mma = mma.get_slice(thread_idx); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + + // Predicates for m/n bounds. + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); // (CPY_M,CPY_K) + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); // (CPY_N,CPY_K) + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) + Tensor cC = make_identity_tensor(make_shape(size<0>(gC), size<1>(gC))); // (M,N) + Tensor tAcA = thr_copy_a.partition_S(cA); // (CPY,CPY_M,CPY_K) + Tensor tBcB = thr_copy_b.partition_S(cB); // (CPY,CPY_N,CPY_K) + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) + CUTE_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord; + } + CUTE_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < n_max_coord; + } + + // GMEM => RMEM. + auto fetch_gmem = [&](int tile) { + copy_if(copy_a, tApA, tAgA(_,_,_,tile), tArA); + copy_if(copy_b, tBpB, tBgB(_,_,_,tile), tBrB); + copy(tBgS(_,_,_,tile), tBrS); + copy(tBgZ(_,_,_,tile), tBrZ); + }; + // RMEM => SMEM. + auto store_smem = [&]() { + __syncthreads(); + copy(tArA, tAsA); + CUTE_UNROLL + for (int k = 0; k < size<2>(tBrB); ++k) { + CUTE_UNROLL + for (int n = 0; n < size<1>(tBrB); ++n) { + cute_dequant(tBrB(_,n,k), tBrS(_,n,k), tBrZ(_,n,k), tBrB_dq(_,n,k)); + } + } + copy(tBrB_dq, tBsB); + __syncthreads(); + }; + + // Clear the rmem tiles to account for predicated off loads. + if constexpr (HasKResidue) { + clear(tArA); + clear(tBrB); + clear(tBrS); + clear(tBrZ); + } + + // Prefetch first tile. 
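+  // With K residue the tensors were shifted above so the ragged tile is tile 0:
+  // loads whose in-tile K index falls in the padding before the real data are
+  // skipped, and those fragments keep the zeros written by clear() above.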
+ if constexpr (HasKResidue) { + Tensor tAgA_k = tAgA(_,_,_,0); + CUTE_UNROLL + for (int k = 0; k < size<2>(tArA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -k_residue) { + copy_if(copy_a, tApA(_,k), tAgA_k(_,_,k), tArA(_,_,k)); + } + } + Tensor tBgB_k = tBgB(_,_,_,0); + Tensor tBgS_k = tBgS(_,_,_,0); + Tensor tBgZ_k = tBgZ(_,_,_,0); + CUTE_UNROLL + for (int k = 0; k < size<2>(tBrB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -k_residue) { + copy_if(copy_b, tBpB(_,k), tBgB_k(_,_,k), tBrB(_,_,k)); + copy(tBgS_k(_,_,k), tBrS(_,_,k)); + copy(tBgZ_k(_,_,k), tBrZ(_,_,k)); + } + } + } else { + fetch_gmem(0); + } + + // Clear accumulators. + clear(tCrC); + + // Loop over CTA tiles. + auto K_TILE_MAX = size<3>(tAgA); + for (int tile = 0; tile < K_TILE_MAX; ++tile) { + store_smem(); + if constexpr (HasKResidue) { + // Avoid fetching full 0th-tile when there is residue. + if (K_TILE_MAX > 1) { + fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); + } + } else { + fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); + } + gemm(mma, tCsA, tCsB, tCrC); + } + + // Epilogue. + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) { + if ((get<0>(tCcC(i)) < m_max_coord) && (get<1>(tCcC(i)) < n_max_coord)) { + tCgC(i) = Element(tCrC(i)); + } + } +} + +template +inline constexpr auto make_matrix_stride(auto m, auto k) { + if constexpr (KMajor) { + return cute::make_stride(k, cute::Int<1>{}, m * k); + } else { + return cute::make_stride(cute::Int<1>{}, m, m * k); + } +} + +template +inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size) { + if constexpr (KMajor) { + return make_layout( + make_shape(n, make_shape(group_size, k / group_size), l), + make_stride(k / group_size, Stride<_0,_1>{}, n * k / group_size)); + } else { + return make_layout( + make_shape(make_shape(group_size, n / group_size), k, l), + make_stride(Stride<_0,_1>{}, n / group_size, n * k / group_size)); + } +} + +template +inline constexpr auto make_cta_tiler(auto group_size) { + auto bM = Int{}; + auto bN = Int<(!SM80 && group_size > 64) ? 
64 : 128>{}; + auto bK = Int{}; + return make_shape(bM, bN, bK); +} + +template +inline constexpr auto make_tiled_mma(auto cta_tiler) { + using Atom = std::conditional_t< + SM80, + std::conditional_t< + std::is_same_v, + SM80_16x8x16_F32F16F16F32_TN, + std::conditional_t< + std::is_same_v, + SM80_16x8x16_F32BF16BF16F32_TN, + UniversalFMA + > + >, + UniversalFMA>; + if constexpr (!SM80 || std::is_same_v) { + return make_tiled_mma(Atom{}, Layout>{}); + } else { + if constexpr (size<0>(cta_tiler) >= 32) { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); + } else { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); + } + } +} + +} // namespace cutlass_gemm + +// clang-format on + +namespace mlx::core { + +template +inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { + if (dtype == float32) { + f.template operator()(); + } else if (dtype == float16) { + f.template operator()(); + } else if (dtype == bfloat16) { + f.template operator()(); + } else { + throw std::invalid_argument( + fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype))); + } +} + +template +inline void dispatch_groups(int group_size, const char* tag, F&& f) { + if (group_size == 32) { + f.template operator()<32>(); + } else if (group_size == 64) { + f.template operator()<64>(); + } else if (group_size == 128) { + f.template operator()<128>(); + } else { + throw std::invalid_argument( + fmt::format("{} Group size {} is not supported.", tag, group_size)); + } +} + +template +inline void dispatch_quant_types( + int bits, + int group_size, + QuantizationMode mode, + const char* tag, + F&& f) { + if (mode == QuantizationMode::Mxfp4) { + f.template operator()(); + } else if (mode == QuantizationMode::Mxfp8) { + f.template operator()(); + } else if (mode == QuantizationMode::Nvfp4) { + f.template operator()(); + } else { + dispatch_groups(group_size, tag, [&]() { + if (bits == 2) { + f.template operator()(); + } else if (bits == 3) { + f.template operator()(); + } else if (bits == 4) { + f.template operator()(); + } else if (bits == 5) { + f.template operator()(); + } else if (bits == 6) { + f.template operator()(); + } else if (bits == 8) { + f.template operator()(); + } else { + throw std::invalid_argument( + fmt::format("{} {}-bit quantization is not supported.", tag, bits)); + } + }); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cu b/mlx/backend/cuda/quantized/qmm/qmm_sm80.cu index 028f18cd52..b1fa28dba1 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm_sm80.cu @@ -1,8 +1,7 @@ // Copyright © 2026 Apple Inc. 
-#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" -#include "mlx/dtype_utils.h" +#include "mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh" // clang-format off @@ -11,38 +10,24 @@ namespace cutlass_gemm { using namespace cute; -template -union SharedStorage { - struct { - ArrayEngine> A; - ArrayEngine> B; - } mainloop; - struct { - ArrayEngine> C; - } epilogue; -}; - -template -__global__ void qmm_sm80_kernel( +template +__global__ +__launch_bounds__(decltype(size(TiledMma{}))::value) +void qmm_sm80_kernel( ProblemShape shape_MNKL, CtaTiler cta_tiler, - const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA g2s_copy_a, S2RAtomA s2r_atom_a, - const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB g2s_copy_b, S2RAtomB s2r_atom_b, - Element* C, StrideC dC, SmemLayoutC sC_layout, TiledCopyC s2g_copy_c, R2SAtomC r2s_atom_c, - const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, + const Element* A, StrideA dA, + const Quant* B, StrideB dB, + const Scale* S, const Element* Z, LayoutS S_layout, const uint32_t* lhs_indices, const uint32_t* rhs_indices, + Element* C, StrideC dC, TiledMma mma) { - CUTE_STATIC_ASSERT_V(size(g2s_copy_a) == size(mma)); - CUTE_STATIC_ASSERT_V(size(g2s_copy_b) == size(mma)); - CUTE_STATIC_ASSERT_V(size(s2g_copy_c) == size(mma)); CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA)); CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB)); CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC)); @@ -79,201 +64,23 @@ __global__ void qmm_sm80_kernel( Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - // Shared memory buffers. - extern __shared__ char shared_memory[]; - using SharedStorage = SharedStorage; - SharedStorage& smem = *reinterpret_cast(shared_memory); - Tensor sA = make_tensor(make_smem_ptr(smem.mainloop.A.begin()), sA_layout); // (BLK_M,BLK_K) - Tensor sB = make_tensor(make_smem_ptr(smem.mainloop.B.begin()), sB_layout); // (BLK_N,BLK_K) - Tensor sC = make_tensor(make_smem_ptr(smem.epilogue.C.begin()), sC_layout); // (BLK_M,BLK_N) - - // Partition the copying of A/B/C tiles across the threads. - ThrCopy g2s_thr_copy_a = g2s_copy_a.get_slice(thread_idx); - Tensor tAgA = g2s_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) - Tensor tAsA = g2s_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) - - ThrCopy g2s_thr_copy_b = g2s_copy_b.get_slice(thread_idx); - Tensor tBgB = g2s_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) - Tensor tBsB = g2s_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) - - ThrCopy s2g_thr_copy_c = s2g_copy_c.get_slice(thread_idx); - Tensor s2g_tCsC = s2g_thr_copy_c.partition_S(sC); // (CCPY,CCPY_M,CCPY_N) - Tensor s2g_tCgC = s2g_thr_copy_c.partition_D(gC); // (CCPY,CCPY_M,CCPY_N) - - // MMA. 
- ThrMMA thr_mma = mma.get_slice(thread_idx); - Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) - Tensor tCsB = thr_mma.partition_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) - Tensor tCrB = make_fragment_like(tCsB); // (MMA,MMA_N,MMA_K) - Tensor tCrB_dq = make_fragment_like(tCsB); // (MMA,MMA_N,MMA_K) - Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) - Tensor tCrC_accu = make_fragment_like(tCgC); // (MMA,MMA_M,MMA_N) - Tensor tCrC = make_fragment_like(tCgC); // (MMA,MMA_M,MMA_N) - - Tensor tCgS = thr_mma.partition_B(gS); // (MMA,MMA_N,MMA_K,k) - Tensor tCrS = make_tensor_like(tCgS(_,_,_,0)); // (MMA,MMA_N,MMA_K) - Tensor tCgZ = thr_mma.partition_B(gZ); // (MMA,MMA_N,MMA_K,k) - Tensor tCrZ = make_tensor_like(tCgZ(_,_,_,0)); // (MMA,MMA_N,MMA_K) - - // Copy Atom retiling. - TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma); - ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(thread_idx); - Tensor s2r_tCsA = s2r_thr_copy_a.partition_S(sA); // (ACPY,MMA_M,MMA_K,PIPE) - Tensor s2r_tCrA = s2r_thr_copy_a.retile_D(tCrA); // (ACPY,MMA_M,MMA_K) - - TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma); - ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(thread_idx); - Tensor s2r_tCsB = s2r_thr_copy_b.partition_S(sB); // (BCPY,MMA_N,MMA_K,PIPE) - Tensor s2r_tCrB = s2r_thr_copy_b.retile_D(tCrB); // (BCPY,MMA_N,MMA_K) - - TiledCopy r2s_copy_c = make_tiled_copy_C(r2s_atom_c, mma); - ThrCopy r2s_thr_copy_c = r2s_copy_c.get_slice(thread_idx); - Tensor r2s_tCrC = r2s_thr_copy_c.retile_S(tCrC); // (CCPY,MMA_M,MMA_N) - Tensor r2s_tCsC = r2s_thr_copy_c.partition_D(sC); // (CCPY,MMA_M,MMA_N) - - TiledCopy g2r_copy_s = make_tiled_copy_B(g2r_atom_s, mma); - ThrCopy g2r_thr_copy_s = g2r_copy_s.get_slice(thread_idx); - Tensor g2r_tCgS = g2r_thr_copy_s.partition_S(gS); // (BCPY,MMA_N,MMA_K,k) - Tensor g2r_tCrS = g2r_thr_copy_s.retile_D(tCrS); // (BCPY,MMA_N,MMA_K) - Tensor g2r_tCgZ = g2r_thr_copy_s.partition_S(gZ); // (BCPY,MMA_N,MMA_K,k) - Tensor g2r_tCrZ = g2r_thr_copy_s.retile_D(tCrZ); // (BCPY,MMA_N,MMA_K) - - // Predicates for m bound. + // Compute tile residues for predication. auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord - Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); // (CPY_M,CPY_K) - Tensor tCpC = make_tensor(make_shape(size<1>(s2g_tCsC), size<2>(s2g_tCsC)), Stride<_1,_0>{}); // (CPY_M,CPY_N) - Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) - Tensor cC = make_identity_tensor(make_shape(size<0>(sC), size<1>(sC))); // (BLK_M,BLK_N) - Tensor tAcA = g2s_thr_copy_a.partition_D(cA); // (CPY,CPY_M,CPY_K) - Tensor tCcC = s2g_thr_copy_c.partition_D(cC); // (CPY,CPY_M,CPY_N) - CUTE_UNROLL - for (int m = 0; m < size<0>(tApA); ++m) { - tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord; - } - CUTE_UNROLL - for (int m = 0; m < size<0>(tCpC); ++m) { - tCpC(m,0) = get<0>(tCcC(0,m,0)) < m_max_coord; - } - - auto K_PIPE_MAX = size<3>(tAsA); - int smem_pipe_read = 0; - int smem_pipe_write = 0; - - // Copy A/B: GMEM => SMEM. - auto fetch_gmem = [&](int tile) { - copy_if(g2s_copy_a, tApA, tAgA(_,_,_,tile), tAsA(_,_,_,smem_pipe_write)); - copy(g2s_copy_b, tBgB(_,_,_,tile), tBsB(_,_,_,smem_pipe_write)); - cp_async_fence(); - smem_pipe_write = (smem_pipe_write + 1) % K_PIPE_MAX; - }; - // Copy S/Z: GMEM => RMEM. 
- auto fetch_scales = [&](int tile) { - copy(g2r_copy_s, g2r_tCgS(_,_,_,tile), g2r_tCrS); - if constexpr (quant_has_bias_v) { - copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ); - } - }; - // Copy A/B: SMEM => RMEM. - auto fetch_smem = [&](auto block) { - copy(s2r_atom_a, s2r_tCsA(_,_,block,smem_pipe_read), s2r_tCrA(_,_,block)); - copy(s2r_atom_b, s2r_tCsB(_,_,block,smem_pipe_read), s2r_tCrB(_,_,block)); - CUTE_UNROLL - for (int n = 0; n < size<1>(tCrB); ++n) { - cute_vectorized_dequant( - tCrB(_,n,block), - tCrS(_,n,block), - tCrZ(_,n,block), - tCrB_dq(_,n,block)); - } - }; - auto K_TILE_MAX = size<3>(tAgA); - auto K_BLOCK_MAX = size<2>(tCrA); - - // Prefetch beginning tiles. - int tile_pipe = 0; - CUTE_UNROLL - for (; tile_pipe < K_PIPE_MAX - 1; ++tile_pipe) { - fetch_gmem(tile_pipe); - } - - // Clear accumulators. - clear(tCrC_accu); - - // Prefetch first block. - if constexpr (K_BLOCK_MAX > 1) { - cp_async_wait(); - __syncthreads(); - fetch_scales(0); - fetch_smem(Int<0>{}); - } - - // Loop over CTA tiles. - for (int tile = 0; tile < K_TILE_MAX; ++tile) { - // Unroll MMA blocks. - CUTE_UNROLL - for (int block = 0; block < K_BLOCK_MAX; ++block) { - // Wait for last tile. - if (block == K_BLOCK_MAX - 1) { - smem_pipe_read = (smem_pipe_read + 1) % K_PIPE_MAX; - cp_async_wait(); - __syncthreads(); - fetch_scales((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); - } - // Prefetch next block. - fetch_smem((block + 1) % K_BLOCK_MAX); - // Prefetch next tile. - if (block == 0) { - fetch_gmem(tile_pipe); - tile_pipe = (tile_pipe + 1 < K_TILE_MAX) ? tile_pipe + 1 : tile_pipe; - } - // MMA. - gemm(mma, tCrA(_,_,block), tCrB_dq(_,_,block), tCrC_accu); - } - } - - // Epilogue. - CUTE_UNROLL - for (int i = 0; i < size(tCrC_accu); i++) { - tCrC(i) = Element(tCrC_accu(i)); - } - copy(r2s_copy_c, r2s_tCrC, r2s_tCsC); - __syncthreads(); - copy_if(s2g_copy_c, tCpC, s2g_tCsC, s2g_tCgC); + qmm_sm80_mainloop( + cta_tiler, + gA, + gB, + gS, + gZ, + gC, + mma, + m_max_coord, + thread_idx); } -template -inline constexpr auto make_mma_atom() { - if constexpr (std::is_same_v) { - return SM80_16x8x16_F32F16F16F32_TN{}; - } - if constexpr (std::is_same_v) { - return SM80_16x8x16_F32BF16BF16F32_TN{}; - } -} - -template -inline constexpr auto make_tiled_mma() { - constexpr auto atom = make_mma_atom(); - if constexpr (TileM >= 32) { - return make_tiled_mma(atom, Layout>{}, Tile<_32,_32,_16>{}); - } else { - return make_tiled_mma(atom, Layout>{}, Tile<_16,_32,_16>{}); - } -} - -template typename Atom, typename NumThreads> -inline auto make_tiled_copy(NumThreads num_threads) { - return make_tiled_copy( - Copy_Atom>, T>{}, - make_layout(make_shape(Int{}, Int<8>{}), LayoutRight{}), - make_layout(make_shape(Int<1>{}, Int>{}))); -} - -template +template void qmm_sm80( const Element* A, const Quant* B, @@ -284,20 +91,16 @@ void qmm_sm80( Element* C, int m, int n, int k, int l, bool broadcast_b, - GroupSize group_size, + auto group_size, auto&& launch_kernel) { // Define shapes (dynamic). - auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L) + auto shape_MNKL = make_shape(m, n, k, l); // (M,N,K,L) - // Define TN strides (mixed). + // Define layouts (mixed). auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL) auto dB = make_stride(k, Int<1>{}, n * k); // (dN,dK,dL) auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) - - // Define layout of scales/biases (mixed). 
- auto S_layout = make_layout( - make_shape(n, make_shape(group_size, k / group_size), l), - make_stride(k / group_size, Stride<_0, _1>{}, n * k / group_size)); + auto S_layout = make_scales_layout(n, k, l, group_size); // Handle broadcasting. if (broadcast_b) { @@ -306,70 +109,41 @@ void qmm_sm80( } // Define CTA tile sizes (static). - auto bM = Int{}; - auto bN = Int<128>{}; - auto bK = Int{}; - auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M,BLK_N,BLK_K) + auto cta_tiler = make_cta_tiler(group_size); // Define MMA. TiledMMA mma = make_tiled_mma(); auto num_threads = size(mma); - // Define the A/B smem layouts (static). - auto swizzle_ab = composition(Swizzle<3,3,3>{}, - Layout>, - Stride<_8,Stride<_1,_64>>>{}); - auto bP = Int<3>{}; // pipeline - auto sA_layout = tile_to_shape(swizzle_ab, make_shape(bM, bK, bP)); - auto sB_layout = tile_to_shape(swizzle_ab, make_shape(bN, bK, bP)); - - // Define the C smem layouts (static). - // TODO: Find a better swizzle. - auto sC_layout = tile_to_shape(swizzle_ab, make_shape(bM, bN)); - - // Define the scales/biases smem layouts (static). - auto bS = ceil_div(bK, group_size); - auto sS_layout = make_layout(make_shape(bN, make_shape(group_size, bS)), - make_stride(bS, Stride<_0, _1>{})); - - // Atoms. - constexpr int element_bits = sizeof_bits_v; - constexpr int quant_bits = sizeof_bits_v; - constexpr int qload = 128 / (element_bits / quant_bits); - TiledCopy g2s_copy_a = make_tiled_copy(num_threads); - TiledCopy g2s_copy_b = make_tiled_copy(num_threads); - TiledCopy s2g_copy_c = make_tiled_copy(num_threads); - - Copy_Atom s2r_atom_a; - Copy_Atom>, Quant> s2r_atom_b; - Copy_Atom>, Element> r2s_atom_c; - Copy_Atom, Scale> g2r_atom_s; - - auto* kernel = &qmm_sm80_kernel< - decltype(prob_shape), decltype(cta_tiler), - Element, Quant, Scale, - decltype(dA), decltype(sA_layout), decltype(g2s_copy_a), decltype(s2r_atom_a), - decltype(dB), decltype(sB_layout), decltype(g2s_copy_b), decltype(s2r_atom_b), - decltype(dC), decltype(sC_layout), decltype(s2g_copy_c), decltype(r2s_atom_c), - decltype(S_layout), decltype(g2r_atom_s), decltype(mma)>; - - // Set L1 to be SMEM only. + // Shared memory size. 
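+  // The host only needs the byte count to size the dynamic smem for the launch;
+  // the same layouts are rebuilt on device inside qmm_sm80_mainloop() from the
+  // same cta_tiler.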
+ auto [sA_layout, sB_layout, sC_layout] = make_smem_layouts(cta_tiler); size_t smem_bytes = sizeof(SharedStorage); + + auto* kernel = &qmm_sm80_kernel< + Element, Quant, Scale, + decltype(shape_MNKL), + decltype(cta_tiler), + decltype(dA), + decltype(dB), + decltype(S_layout), + decltype(dC), + decltype(mma)>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100); - dim3 num_blocks(size(ceil_div(m, bM)), size(ceil_div(n, bN)), l); - dim3 block_dims(num_threads); + dim3 num_blocks{uint32_t(ceil_div(m, size<0>(cta_tiler))), + uint32_t(ceil_div(n, size<1>(cta_tiler))), + uint32_t(l)}; + dim3 block_dims{num_threads}; void* args[] = { - &prob_shape, &cta_tiler, - &A, &dA, &sA_layout, &g2s_copy_a, &s2r_atom_a, - &B, &dB, &sB_layout, &g2s_copy_b, &s2r_atom_b, - &C, &dC, &sC_layout, &s2g_copy_c, &r2s_atom_c, - &S, &Z, &S_layout, &g2r_atom_s, + &shape_MNKL, &cta_tiler, + &A, &dA, + &B, &dB, + &S, &Z, &S_layout, &lhs_indices, &rhs_indices, + &C, &dC, &mma}; launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); } @@ -380,59 +154,6 @@ void qmm_sm80( namespace mlx::core { -template -inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { - if (dtype == float16) { - f.template operator()(); - } else if (dtype == bfloat16) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype))); - } -} - -template -inline void dispatch_groups(int group_size, const char* tag, F&& f) { - if (group_size == 32) { - f.template operator()<32>(); - } else if (group_size == 64) { - f.template operator()<64>(); - } else if (group_size == 128) { - f.template operator()<128>(); - } else { - throw std::invalid_argument( - fmt::format("{} Group size {} is not supported.", tag, group_size)); - } -} - -template -inline void dispatch_quant_types( - int bits, - int group_size, - QuantizationMode mode, - const char* tag, - F&& f) { - if (mode == QuantizationMode::Mxfp4) { - f.template operator()(); - } else if (mode == QuantizationMode::Mxfp8) { - f.template operator()(); - } else if (mode == QuantizationMode::Nvfp4) { - f.template operator()(); - } else { - dispatch_groups(group_size, tag, [&]() { - if (bits == 4) { - f.template operator()(); - } else if (bits == 8) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} {}-bit quantization is not supported.", tag, bits)); - } - }); - } -} - template void qmm_sm80_impl( const array& x, @@ -490,7 +211,7 @@ void qmm_sm80_impl( [&](auto* kernel, dim3 num_blocks, dim3 block_dims, - uint32_t smem_bytes, + size_t smem_bytes, void** args) { encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, smem_bytes, args); diff --git a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh new file mode 100644 index 0000000000..fdeceaab5f --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh @@ -0,0 +1,346 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" +#include "mlx/dtype_utils.h" + +// clang-format off + +// We can't put kernel code in mlx::core due to name conflicts of "Shape". 
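+// `using namespace cute` brings in cute::Shape, which would clash with
+// mlx::core::Shape, so all device-side helpers stay inside cutlass_gemm.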
+namespace cutlass_gemm { + +using namespace cute; + +template +union SharedStorage { + struct { + ArrayEngine> A; + ArrayEngine> B; + } mainloop; + struct { + ArrayEngine> C; + } epilogue; +}; + +inline constexpr auto make_smem_layouts(auto cta_tiler) { + // Define the A/B smem layouts (static). + auto swizzle_ab = composition(Swizzle<3,3,3>{}, + Layout>, + Stride<_8,Stride<_1,_64>>>{}); + auto [bM, bN, bK] = cta_tiler; + auto bP = Int<3>{}; // pipeline + auto sA_layout = tile_to_shape(swizzle_ab, make_shape(bM, bK, bP)); + auto sB_layout = tile_to_shape(swizzle_ab, make_shape(bN, bK, bP)); + + // Define the C smem layouts (static). + // TODO: Find a better swizzle. + auto sC_layout = tile_to_shape(swizzle_ab, make_shape(bM, bN)); + + return std::make_tuple(sA_layout, sB_layout, sC_layout); +} + +template typename Atom> +inline constexpr auto make_tiled_copy(auto num_threads) { + return make_tiled_copy( + Copy_Atom>, T>{}, + make_layout(make_shape(Int{}, Int<8>{}), LayoutRight{}), + make_layout(make_shape(Int<1>{}, Int>{}))); +} + +template +CUTE_DEVICE void qmm_sm80_mainloop( + CtaTiler cta_tiler, + TensorA gA, + TensorB gB, + TensorS gS, + TensorZ gZ, + TensorC gC, + TiledMma mma, + int m_max_coord, + int thread_idx) { + // Get the types of operands. + using Element = decltype(gA)::value_type; + using Quant = decltype(gB)::value_type; + using Scale = decltype(gS)::value_type; + + // Define smem layouts. + auto [sA_layout, sB_layout, sC_layout] = make_smem_layouts(cta_tiler); + + // Shared memory buffer. + extern __shared__ char smem_buf[]; + using SharedStorage = SharedStorage; + SharedStorage& smem = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(smem.mainloop.A.begin()), sA_layout); // (BLK_M,BLK_K) + Tensor sB = make_tensor(make_smem_ptr(smem.mainloop.B.begin()), sB_layout); // (BLK_N,BLK_K) + Tensor sC = make_tensor(make_smem_ptr(smem.epilogue.C.begin()), sC_layout); // (BLK_M,BLK_N) + + // Define copy atoms. + constexpr int element_bits = sizeof_bits_v; + constexpr int quant_bits = sizeof_bits_v; + constexpr int qload = 128 / (element_bits / quant_bits); + auto num_threads = size(mma); + TiledCopy g2s_copy_a = make_tiled_copy(num_threads); + TiledCopy g2s_copy_b = make_tiled_copy(num_threads); + TiledCopy s2g_copy_c = make_tiled_copy(num_threads); + + Copy_Atom s2r_atom_a; + Copy_Atom>, Quant> s2r_atom_b; + Copy_Atom>, Element> r2s_atom_c; + Copy_Atom, Scale> g2r_atom_s; + + // Partition the copying of A/B/C tiles across the threads. + ThrCopy g2s_thr_copy_a = g2s_copy_a.get_slice(thread_idx); + Tensor tAgA = g2s_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = g2s_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + + ThrCopy g2s_thr_copy_b = g2s_copy_b.get_slice(thread_idx); + Tensor tBgB = g2s_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = g2s_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + ThrCopy s2g_thr_copy_c = s2g_copy_c.get_slice(thread_idx); + Tensor s2g_tCsC = s2g_thr_copy_c.partition_S(sC); // (CCPY,CCPY_M,CCPY_N) + Tensor s2g_tCgC = s2g_thr_copy_c.partition_D(gC); // (CCPY,CCPY_M,CCPY_N) + + // MMA. 
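+  // B lives in registers twice: tCrB holds the raw quantized values copied from
+  // smem and tCrB_dq the dequantized values that actually feed gemm(); the
+  // scales/biases are staged straight from gmem into tCrS/tCrZ.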
+ ThrMMA thr_mma = mma.get_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + Tensor tCrB = make_fragment_like(tCsB); // (MMA,MMA_N,MMA_K) + Tensor tCrB_dq = make_fragment_like(tCsB); // (MMA,MMA_N,MMA_K) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + Tensor tCrC_accu = make_fragment_like(tCgC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = make_fragment_like(tCgC); // (MMA,MMA_M,MMA_N) + + Tensor tCgS = thr_mma.partition_B(gS); // (MMA,MMA_N,MMA_K,k) + Tensor tCrS = make_tensor_like(tCgS(_,_,_,0)); // (MMA,MMA_N,MMA_K) + Tensor tCgZ = thr_mma.partition_B(gZ); // (MMA,MMA_N,MMA_K,k) + Tensor tCrZ = make_tensor_like(tCgZ(_,_,_,0)); // (MMA,MMA_N,MMA_K) + + // Copy Atom retiling. + TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma); + ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(thread_idx); + Tensor s2r_tCsA = s2r_thr_copy_a.partition_S(sA); // (ACPY,MMA_M,MMA_K,PIPE) + Tensor s2r_tCrA = s2r_thr_copy_a.retile_D(tCrA); // (ACPY,MMA_M,MMA_K) + + TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma); + ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(thread_idx); + Tensor s2r_tCsB = s2r_thr_copy_b.partition_S(sB); // (BCPY,MMA_N,MMA_K,PIPE) + Tensor s2r_tCrB = s2r_thr_copy_b.retile_D(tCrB); // (BCPY,MMA_N,MMA_K) + + TiledCopy r2s_copy_c = make_tiled_copy_C(r2s_atom_c, mma); + ThrCopy r2s_thr_copy_c = r2s_copy_c.get_slice(thread_idx); + Tensor r2s_tCrC = r2s_thr_copy_c.retile_S(tCrC); // (CCPY,MMA_M,MMA_N) + Tensor r2s_tCsC = r2s_thr_copy_c.partition_D(sC); // (CCPY,MMA_M,MMA_N) + + TiledCopy g2r_copy_s = make_tiled_copy_B(g2r_atom_s, mma); + ThrCopy g2r_thr_copy_s = g2r_copy_s.get_slice(thread_idx); + Tensor g2r_tCgS = g2r_thr_copy_s.partition_S(gS); // (BCPY,MMA_N,MMA_K,k) + Tensor g2r_tCrS = g2r_thr_copy_s.retile_D(tCrS); // (BCPY,MMA_N,MMA_K) + Tensor g2r_tCgZ = g2r_thr_copy_s.partition_S(gZ); // (BCPY,MMA_N,MMA_K,k) + Tensor g2r_tCrZ = g2r_thr_copy_s.retile_D(tCrZ); // (BCPY,MMA_N,MMA_K) + + // Predicates for m bound. + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); // (CPY_M,CPY_K) + Tensor tCpC = make_tensor(make_shape(size<1>(s2g_tCsC), size<2>(s2g_tCsC)), Stride<_1,_0>{}); // (CPY_M,CPY_N) + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) + Tensor cC = make_identity_tensor(make_shape(size<0>(sC), size<1>(sC))); // (BLK_M,BLK_N) + Tensor tAcA = g2s_thr_copy_a.partition_D(cA); // (CPY,CPY_M,CPY_K) + Tensor tCcC = s2g_thr_copy_c.partition_D(cC); // (CPY,CPY_M,CPY_N) + CUTE_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord; + } + CUTE_UNROLL + for (int m = 0; m < size<0>(tCpC); ++m) { + tCpC(m,0) = get<0>(tCcC(0,m,0)) < m_max_coord; + } + + auto K_PIPE_MAX = size<3>(tAsA); + int smem_pipe_read = 0; + int smem_pipe_write = 0; + + // Copy A/B: GMEM => SMEM. + auto fetch_gmem = [&](int tile) { + copy_if(g2s_copy_a, tApA, tAgA(_,_,_,tile), tAsA(_,_,_,smem_pipe_write)); + copy(g2s_copy_b, tBgB(_,_,_,tile), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + smem_pipe_write = (smem_pipe_write + 1) % K_PIPE_MAX; + }; + // Copy S/Z: GMEM => RMEM. + auto fetch_scales = [&](int tile) { + copy(g2r_copy_s, g2r_tCgS(_,_,_,tile), g2r_tCrS); + if constexpr (quant_has_bias_v) { + copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ); + } + }; + // Copy A/B: SMEM => RMEM. 
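+  // fetch_smem() pulls one K block of A and B from the current smem pipeline
+  // stage into registers and dequantizes B on the fly using the scales/biases
+  // already resident in tCrS/tCrZ.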
+ auto fetch_smem = [&](auto block) { + copy(s2r_atom_a, s2r_tCsA(_,_,block,smem_pipe_read), s2r_tCrA(_,_,block)); + copy(s2r_atom_b, s2r_tCsB(_,_,block,smem_pipe_read), s2r_tCrB(_,_,block)); + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { + cute_vectorized_dequant( + tCrB(_,n,block), + tCrS(_,n,block), + tCrZ(_,n,block), + tCrB_dq(_,n,block)); + } + }; + + auto K_TILE_MAX = size<3>(tAgA); + auto K_BLOCK_MAX = size<2>(tCrA); + + // Prefetch beginning tiles. + int tile_pipe = 0; + CUTE_UNROLL + for (; tile_pipe < K_PIPE_MAX - 1; ++tile_pipe) { + fetch_gmem(tile_pipe); + } + + // Clear accumulators. + clear(tCrC_accu); + + // Prefetch first block. + if constexpr (K_BLOCK_MAX > 1) { + cp_async_wait(); + __syncthreads(); + fetch_scales(0); + fetch_smem(Int<0>{}); + } + + // Loop over CTA tiles. + for (int tile = 0; tile < K_TILE_MAX; ++tile) { + // Unroll MMA blocks. + CUTE_UNROLL + for (int block = 0; block < K_BLOCK_MAX; ++block) { + // Wait for last tile. + if (block == K_BLOCK_MAX - 1) { + smem_pipe_read = (smem_pipe_read + 1) % K_PIPE_MAX; + cp_async_wait(); + __syncthreads(); + fetch_scales((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); + } + // Prefetch next block. + fetch_smem((block + 1) % K_BLOCK_MAX); + // Prefetch next tile. + if (block == 0) { + fetch_gmem(tile_pipe); + tile_pipe = (tile_pipe + 1 < K_TILE_MAX) ? tile_pipe + 1 : tile_pipe; + } + // MMA. + gemm(mma, tCrA(_,_,block), tCrB_dq(_,_,block), tCrC_accu); + } + } + + // Epilogue. + CUTE_UNROLL + for (int i = 0; i < size(tCrC_accu); i++) { + tCrC(i) = Element(tCrC_accu(i)); + } + copy(r2s_copy_c, r2s_tCrC, r2s_tCsC); + __syncthreads(); + copy_if(s2g_copy_c, tCpC, s2g_tCsC, s2g_tCgC); +} + +inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size) { + return make_layout( + make_shape(n, make_shape(group_size, k / group_size), l), + make_stride(k / group_size, Stride<_0,_1>{}, n * k / group_size)); +} + +template +inline constexpr auto make_cta_tiler(auto group_size) { + auto bM = Int{}; + auto bN = Int<128>{}; + auto bK = Int{}; + return make_shape(bM, bN, bK); +} + +template +inline constexpr auto make_tiled_mma() { + using Atom = std::conditional_t< + std::is_same_v, + SM80_16x8x16_F32F16F16F32_TN, + std::conditional_t< + std::is_same_v, + SM80_16x8x16_F32BF16BF16F32_TN, + UniversalFMA>>; + if constexpr (TileM >= 32) { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); + } else { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); + } +} + +} // namespace cutlass_gemm + +// clang-format on + +namespace mlx::core { + +template +inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { + if (dtype == float16) { + f.template operator()(); + } else if (dtype == bfloat16) { + f.template operator()(); + } else { + throw std::invalid_argument( + fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype))); + } +} + +template +inline void dispatch_groups(int group_size, const char* tag, F&& f) { + if (group_size == 32) { + f.template operator()<32>(); + } else if (group_size == 64) { + f.template operator()<64>(); + } else if (group_size == 128) { + f.template operator()<128>(); + } else { + throw std::invalid_argument( + fmt::format("{} Group size {} is not supported.", tag, group_size)); + } +} + +template +inline void dispatch_quant_types( + int bits, + int group_size, + QuantizationMode mode, + const char* tag, + F&& f) { + if (mode == QuantizationMode::Mxfp4) { + f.template operator()(); + } else if (mode == QuantizationMode::Mxfp8) { + 
f.template operator()(); + } else if (mode == QuantizationMode::Nvfp4) { + f.template operator()(); + } else { + dispatch_groups(group_size, tag, [&]() { + if (bits == 4) { + f.template operator()(); + } else if (bits == 8) { + f.template operator()(); + } else { + throw std::invalid_argument( + fmt::format("{} {}-bit quantization is not supported.", tag, bits)); + } + }); + } +} + +} // namespace mlx::core