From 7c3501a55bbd98c15af7583d17d5737892f34224 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 8 Apr 2026 00:55:31 +0000 Subject: [PATCH 1/4] [cuda] initial Q1_0 backend --- ggml/src/ggml-cuda/common.cuh | 7 ++ ggml/src/ggml-cuda/convert.cu | 10 ++ ggml/src/ggml-cuda/dequantize.cuh | 23 +++++ ggml/src/ggml-cuda/getrows.cu | 4 + ggml/src/ggml-cuda/ggml-cuda.cu | 2 + ggml/src/ggml-cuda/mmq.cu | 9 ++ ggml/src/ggml-cuda/mmq.cuh | 93 +++++++++++++++++++ ggml/src/ggml-cuda/mmvq.cu | 8 ++ .../template-instances/generate_cu_files.py | 1 + .../template-instances/mmq-instance-q1_0.cu | 5 + ggml/src/ggml-cuda/vecdotq.cuh | 86 +++++++++++++++++ 11 files changed, 248 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9affe023403..6fa4383dedc 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -918,6 +918,13 @@ struct ggml_cuda_type_traits { static constexpr int qr = 1; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK1_0; + static constexpr int qr = QR1_0; + static constexpr int qi = QI1_0; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK4_0; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 79ccfe568a2..61630a35a29 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -711,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -767,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -822,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: @@ -843,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: @@ -864,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index e060fb29fdc..05f6ffb47b7 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -1,5 +1,28 @@ #include "common.cuh" +static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q1_0 * x = (const block_q1_0 *) vx; + + const float d = x[ib].d; + const float neg_d = -d; + + const int bit_index_0 = iqs; + const int bit_index_1 = iqs + 1; + + const int byte_index_0 = bit_index_0 / 8; + const int bit_offset_0 = bit_index_0 % 8; + + const int byte_index_1 = bit_index_1 / 8; + const int bit_offset_1 = bit_index_1 % 8; + + // Extract bits: 1 = +d, 0 = -d + const uint8_t bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1; + const uint8_t bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1; + + v.x = bit_0 ? d : neg_d; + v.y = bit_1 ? d : neg_d; +} + static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 2fab33243dd..e99cba63d34 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_Q1_0: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_Q4_0: get_rows_cuda_q(src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75b62129ade..88ee9977be0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4785,6 +4785,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (a->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4822,6 +4823,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_F32: case GGML_TYPE_BF16: case GGML_TYPE_I32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 27b4145ac9a..d04d836a618 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -5,6 +5,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { + case GGML_TYPE_Q1_0: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q4_0: mul_mat_q_case(ctx, args, stream); break; @@ -270,6 +273,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t bool mmq_supported; switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -301,6 +305,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t return false; } + // Q1_0 requires MMA (Turing+) — no DP4A fallback path + if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc)) { + return false; + } + if (turing_mma_available(cc)) { return true; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 51e8dad4ce7..911bff93034 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -11,6 +11,7 @@ using namespace ggml_cuda_mma; #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. #define MMQ_ITER_K 256 +#define MMQ_ITER_K_Q1_0 128 // For Q1_0: QK1_0=128, QI1_0=4, so threads_per_row = 128/(4*4) = 8 #define MMQ_ITER_K_MXFP4_FP4 512 #define MMQ_NWARPS 8 @@ -57,6 +58,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { switch (type_x) { + case GGML_TYPE_Q1_0: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: return MMQ_Q8_1_DS_LAYOUT_DS4; @@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; @@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() { // ------------------------------------------------------------ +template static __device__ __forceinline__ void load_tiles_q1_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) + GGML_UNUSED_VARS(x, x_tile, kbx0, i_max, stride, mmq_y, need_check); + NO_DEVICE_CODE; +#else + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); + + constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0; + constexpr int threads_per_row = blocks_per_iter * QI1_0; + constexpr int nrows = warp_size / threads_per_row; + constexpr int scale_entries_per_block = QK1_0 / QK8_1; + constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block; + + const int txi = threadIdx.x % threads_per_row; + const int kbx = txi / QI1_0; + const int kqsx = txi % QI1_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx; + const int qs_offset = 4*kqsx; + const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) | + (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24); + + int unpacked_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (qs0 >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0; +#pragma unroll + for (int j = 0; j < 8; ++j) { + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j]; + } + } + + const int ksx = threadIdx.x % scale_entries_per_row; + const int scale_block = ksx / scale_entries_per_block; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block; + + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q1_mmq_dp4a_disabled( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + // Q1_0 intentionally targets the MMA path only. + // If DP4A support is needed later for older GPUs, it should be reintroduced and validated separately. + GGML_UNUSED_VARS(x, y, sum, k00, mmq_x, mmq_y); + NO_DEVICE_CODE; +} + template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); @@ -3274,6 +3359,14 @@ static __device__ __forceinline__ void mmq_write_back_mma( template struct mmq_type_traits; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q1_mmq_dp4a_disabled; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 07b10167bc4..8f55cace1a1 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -9,6 +9,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1; case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; @@ -36,6 +37,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ; case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; @@ -886,6 +888,12 @@ static void mul_mat_vec_q_switch_type( const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int ids_stride, cudaStream_t stream) { switch (type_x) { + case GGML_TYPE_Q1_0: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 40d51f93fa4..841059c15b5 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -32,6 +32,7 @@ SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ + "GGML_TYPE_Q1_0", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu new file mode 100644 index 00000000000..f0686b0d0d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q1_0); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 40b2b41e7e8..2a661fc29d8 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -106,6 +106,47 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism +#define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block + +template static __device__ __forceinline__ float vec_dot_q1_0_q8_1_impl( + const int * v, const int * u, const float & d1, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi = v[i]; + + // Unpack 32 bits into 32 signed values (-1 or +1) + // Each bit: 0 -> -1, 1 -> +1 + int vi_bytes[8]; + +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (vi >> shift) & 0x0F; + + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + + vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + +#pragma unroll + for (int j = 0; j < 8; ++j) { + sumi = ggml_cuda_dp4a(vi_bytes[j], u[8*i + j], sumi); + } + } + + const float2 ds8f = __half22float2(ds8); + + // Q1_0 is symmetric (no offset), so we just multiply by scales + return d1 * ds8f.x * sumi; +} + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 @@ -669,6 +710,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( return d6 * sumf_d; } +static __device__ __forceinline__ float vec_dot_q1_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx; + + // Q1_0: 128 elements with ONE scale + // Q8_1: 32 elements per block with individual scales + // iqs selects which of the 4 chunks of 32 elements to process (0-3) + + const float d1 = bq1_0->d; + + // Process only the chunk specified by iqs + const block_q8_1 * bq8_1_chunk = bq8_1 + iqs; + + // Load 32 bits (4 bytes) for this chunk from Q1_0 + const int offset = iqs * 4; + const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) | + (bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24); + + // Unpack 32 bits into 32 signed values (-1 or +1) + int vi_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (v >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + // Compute dot product for this 32-element chunk + int sumi = 0; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int u = get_int_b4(bq8_1_chunk->qs, j); + sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi); + } + + // Apply Q1_0's single scale and this chunk's Q8_1 scale + const float2 ds8f = __half22float2(bq8_1_chunk->ds); + return d1 * ds8f.x * sumi; +} + static __device__ __forceinline__ float vec_dot_q4_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { From 84ab75f5e656fbe9ef50427fe35cce0e5ee5c0e8 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 8 Apr 2026 08:20:00 +0000 Subject: [PATCH 2/4] remove unused code, fix AMD MMA guard --- ggml/src/ggml-cuda/dequantize.cuh | 10 ++++---- ggml/src/ggml-cuda/mmq.cu | 4 ++-- ggml/src/ggml-cuda/mmq.cuh | 1 - ggml/src/ggml-cuda/vecdotq.cuh | 38 ------------------------------- 4 files changed, 7 insertions(+), 46 deletions(-) diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index 05f6ffb47b7..5f89ac401ec 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -15,12 +15,12 @@ static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const in const int byte_index_1 = bit_index_1 / 8; const int bit_offset_1 = bit_index_1 % 8; - // Extract bits: 1 = +d, 0 = -d - const uint8_t bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1; - const uint8_t bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1; + // Extract bits: 1 = +d, 0 = -d (branchless) + const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1; + const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1; - v.x = bit_0 ? d : neg_d; - v.y = bit_1 ? d : neg_d; + v.x = (2*bit_0 - 1) * d; + v.y = (2*bit_1 - 1) * d; } static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index d04d836a618..bf2f7e19aff 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -305,8 +305,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t return false; } - // Q1_0 requires MMA (Turing+) — no DP4A fallback path - if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc)) { + // Q1_0 requires MMA — no DP4A fallback path + if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc) && !amd_mfma_available(cc) && !amd_wmma_available(cc)) { return false; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 911bff93034..b3b58a94467 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -11,7 +11,6 @@ using namespace ggml_cuda_mma; #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. #define MMQ_ITER_K 256 -#define MMQ_ITER_K_Q1_0 128 // For Q1_0: QK1_0=128, QI1_0=4, so threads_per_row = 128/(4*4) = 8 #define MMQ_ITER_K_MXFP4_FP4 512 #define MMQ_NWARPS 8 diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 2a661fc29d8..9a1ffa87b69 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -109,44 +109,6 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { #define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism #define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block -template static __device__ __forceinline__ float vec_dot_q1_0_q8_1_impl( - const int * v, const int * u, const float & d1, const half2 & ds8) { - - int sumi = 0; - -#pragma unroll - for (int i = 0; i < vdr; ++i) { - const int vi = v[i]; - - // Unpack 32 bits into 32 signed values (-1 or +1) - // Each bit: 0 -> -1, 1 -> +1 - int vi_bytes[8]; - -#pragma unroll - for (int j = 0; j < 8; ++j) { - const int shift = j * 4; - const int bits4 = (vi >> shift) & 0x0F; - - const int b0 = (bits4 & 0x01) ? 1 : -1; - const int b1 = (bits4 & 0x02) ? 1 : -1; - const int b2 = (bits4 & 0x04) ? 1 : -1; - const int b3 = (bits4 & 0x08) ? 1 : -1; - - vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); - } - -#pragma unroll - for (int j = 0; j < 8; ++j) { - sumi = ggml_cuda_dp4a(vi_bytes[j], u[8*i + j], sumi); - } - } - - const float2 ds8f = __half22float2(ds8); - - // Q1_0 is symmetric (no offset), so we just multiply by scales - return d1 * ds8f.x * sumi; -} - #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 From bca0c0b89508fa06ce03a87f6a258fe211eec9e6 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Thu, 9 Apr 2026 18:42:03 +0000 Subject: [PATCH 3/4] attempt to support dp4a --- ggml/src/ggml-cuda/dequantize.cuh | 1 - ggml/src/ggml-cuda/mmq.cu | 5 ----- ggml/src/ggml-cuda/mmq.cuh | 31 ++++++++++++++++--------------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index 5f89ac401ec..9ae1342fc0e 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -4,7 +4,6 @@ static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const in const block_q1_0 * x = (const block_q1_0 *) vx; const float d = x[ib].d; - const float neg_d = -d; const int bit_index_0 = iqs; const int bit_index_1 = iqs + 1; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index bf2f7e19aff..3f01ff5bfb0 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -305,11 +305,6 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t return false; } - // Q1_0 requires MMA — no DP4A fallback path - if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc) && !amd_mfma_available(cc) && !amd_wmma_available(cc)) { - return false; - } - if (turing_mma_available(cc)) { return true; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index b3b58a94467..fbd33505082 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -187,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() { static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; @@ -307,15 +308,17 @@ static constexpr __device__ int mmq_get_nwarps_device() { template static __device__ __forceinline__ void load_tiles_q1_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { -#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) - GGML_UNUSED_VARS(x, x_tile, kbx0, i_max, stride, mmq_y, need_check); - NO_DEVICE_CODE; -#else constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0; constexpr int threads_per_row = blocks_per_iter * QI1_0; @@ -355,7 +358,11 @@ template static __device__ __forceinline__ void loa const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0; #pragma unroll for (int j = 0; j < 8; ++j) { +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j]; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j]; +#endif } } @@ -372,18 +379,12 @@ template static __device__ __forceinline__ void loa const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d; +#endif } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) -} - -template -static __device__ __forceinline__ void vec_dot_q1_mmq_dp4a_disabled( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { - // Q1_0 intentionally targets the MMA path only. - // If DP4A support is needed later for older GPUs, it should be reintroduced and validated separately. - GGML_UNUSED_VARS(x, y, sum, k00, mmq_x, mmq_y); - NO_DEVICE_CODE; } template static __device__ __forceinline__ void load_tiles_q4_0( @@ -3363,7 +3364,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q1_mmq_dp4a_disabled; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; template From 05b0c84e65074f643ca752fa24bec51ab21fc8c5 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Fri, 10 Apr 2026 13:50:56 -0700 Subject: [PATCH 4/4] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/mmq.cuh | 4 ++-- ggml/src/ggml-cuda/vecdotq.cuh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index fbd33505082..2026b45bab2 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -362,7 +362,7 @@ template static __device__ __forceinline__ void loa x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j]; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j]; -#endif +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -383,7 +383,7 @@ template static __device__ __forceinline__ void loa x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d; #else x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d; -#endif +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 9a1ffa87b69..d1741cc8d7b 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -713,8 +713,8 @@ static __device__ __forceinline__ float vec_dot_q1_0_q8_1( } // Apply Q1_0's single scale and this chunk's Q8_1 scale - const float2 ds8f = __half22float2(bq8_1_chunk->ds); - return d1 * ds8f.x * sumi; + const float d8 = __low2float(bq8_1_chunk->ds); + return d1 * d8 * sumi; } static __device__ __forceinline__ float vec_dot_q4_0_q8_1(