diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 9affe023403..6fa4383dedc 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -918,6 +918,13 @@ struct ggml_cuda_type_traits {
     static constexpr int qr = 1;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q1_0> {
+    static constexpr int qk = QK1_0;
+    static constexpr int qr = QR1_0;
+    static constexpr int qi = QI1_0;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
     static constexpr int qk = QK4_0;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 79ccfe568a2..61630a35a29 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -711,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
 
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
         case GGML_TYPE_Q4_1:
@@ -767,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
 
 to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
         case GGML_TYPE_Q4_1:
@@ -822,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cuda<GGML_TYPE_Q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_block_cuda<GGML_TYPE_Q4_0>;
         case GGML_TYPE_Q4_1:
@@ -843,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cuda<GGML_TYPE_Q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_block_cuda<GGML_TYPE_Q4_0>;
         case GGML_TYPE_Q4_1:
@@ -864,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cuda<GGML_TYPE_Q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_block_cuda<GGML_TYPE_Q4_0>;
         case GGML_TYPE_Q4_1:
diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh
index e060fb29fdc..9ae1342fc0e 100644
--- a/ggml/src/ggml-cuda/dequantize.cuh
+++ b/ggml/src/ggml-cuda/dequantize.cuh
@@ -1,5 +1,27 @@
 #include "common.cuh"
 
+static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+    const block_q1_0 * x = (const block_q1_0 *) vx;
+
+    const float d = x[ib].d;
+
+    const int bit_index_0 = iqs;
+    const int bit_index_1 = iqs + 1;
+
+    const int byte_index_0 = bit_index_0 / 8;
+    const int bit_offset_0 = bit_index_0 % 8;
+
+    const int byte_index_1 = bit_index_1 / 8;
+    const int bit_offset_1 = bit_index_1 % 8;
+
+    // Extract bits: 1 = +d, 0 = -d (branchless)
+    const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1;
+    const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1;
+
+    v.x = (2*bit_0 - 1) * d;
+    v.y = (2*bit_1 - 1) * d;
+}
+
 static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index 2fab33243dd..e99cba63d34 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
             get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
+        case GGML_TYPE_Q1_0:
+            get_rows_cuda_q<QK1_0, QR1_0, dequantize_q1_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
         case GGML_TYPE_Q4_0:
             get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 75b62129ade..88ee9977be0 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4785,6 +4785,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             switch (a->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_Q1_0:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
@@ -4822,6 +4823,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_TYPE_F32:
                 case GGML_TYPE_BF16:
                 case GGML_TYPE_I32:
+                case GGML_TYPE_Q1_0:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 27b4145ac9a..3f01ff5bfb0 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -5,6 +5,9 @@
 
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
+        case GGML_TYPE_Q1_0:
+            mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
+            break;
         case GGML_TYPE_Q4_0:
             mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
@@ -270,6 +273,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
     bool mmq_supported;
 
     switch (type) {
+        case GGML_TYPE_Q1_0:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 51e8dad4ce7..2026b45bab2 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -57,6 +57,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
+        case GGML_TYPE_Q1_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
             return MMQ_Q8_1_DS_LAYOUT_DS4;
@@ -185,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() {
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
         case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
         case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
 
 // ------------------------------------------------------------
 
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
+    constexpr int threads_per_row = blocks_per_iter * QI1_0;
+    constexpr int nrows = warp_size / threads_per_row;
+    constexpr int scale_entries_per_block = QK1_0 / QK8_1;
+    constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
+
+    const int txi = threadIdx.x % threads_per_row;
+    const int kbx = txi / QI1_0;
+    const int kqsx = txi % QI1_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
+        const int qs_offset = 4*kqsx;
+        const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
+                        (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
+
+        int unpacked_bytes[8];
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+            const int shift = j * 4;
+            const int bits4 = (qs0 >> shift) & 0x0F;
+            const int b0 = (bits4 & 0x01) ? 1 : -1;
+            const int b1 = (bits4 & 0x02) ? 1 : -1;
+            const int b2 = (bits4 & 0x04) ? 1 : -1;
+            const int b3 = (bits4 & 0x08) ? 1 : -1;
+            unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
+        }
+
+        const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
+#else
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        }
+    }
+
+    const int ksx = threadIdx.x % scale_entries_per_row;
+    const int scale_block = ksx / scale_entries_per_block;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
+#else
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    }
+}
+
 template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = mmq_get_nwarps_device();
@@ -3274,6 +3359,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 
 template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
 struct mmq_type_traits;
 
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
+    static constexpr int              vdr          = VDR_Q1_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q1_0<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
    static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 07b10167bc4..8f55cace1a1 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -9,6 +9,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
 
 static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1;
         case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
         case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
         case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
@@ -36,6 +37,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
 
 static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ;
         case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
         case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
         case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
@@ -886,6 +888,12 @@ static void mul_mat_vec_q_switch_type(
         const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const int ids_stride, cudaStream_t stream) {
     switch (type_x) {
+        case GGML_TYPE_Q1_0:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q1_0>
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+            break;
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index 40d51f93fa4..841059c15b5 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -32,6 +32,7 @@
 SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
 
 TYPES_MMQ = [
+    "GGML_TYPE_Q1_0",
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
     "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
     "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu
new file mode 100644
index 00000000000..f0686b0d0d8
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q1_0);
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 40b2b41e7e8..d1741cc8d7b 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -106,6 +106,9 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
+#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism
+#define VDR_Q1_0_Q8_1_MMQ  4 // Q1_0 has 128 bits (4 ints) per block
+
 #define VDR_Q4_0_Q8_1_MMVQ 2
 #define VDR_Q4_0_Q8_1_MMQ  4
@@ -669,6 +672,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     return d6 * sumf_d;
 }
 
+static __device__ __forceinline__ float vec_dot_q1_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx;
+
+    // Q1_0: 128 elements with ONE scale
+    // Q8_1: 32 elements per block with individual scales
+    // iqs selects which of the 4 chunks of 32 elements to process (0-3)
+
+    const float d1 = bq1_0->d;
+
+    // Process only the chunk specified by iqs
+    const block_q8_1 * bq8_1_chunk = bq8_1 + iqs;
+
+    // Load 32 bits (4 bytes) for this chunk from Q1_0
+    const int offset = iqs * 4;
+    const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) |
+                  (bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24);
+
+    // Unpack 32 bits into 32 signed values (-1 or +1)
+    int vi_bytes[8];
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        const int shift = j * 4;
+        const int bits4 = (v >> shift) & 0x0F;
+        const int b0 = (bits4 & 0x01) ? 1 : -1;
+        const int b1 = (bits4 & 0x02) ? 1 : -1;
+        const int b2 = (bits4 & 0x04) ? 1 : -1;
+        const int b3 = (bits4 & 0x08) ? 1 : -1;
+        vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
+    }
+
+    // Compute dot product for this 32-element chunk
+    int sumi = 0;
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        const int u = get_int_b4(bq8_1_chunk->qs, j);
+        sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi);
+    }
+
+    // Apply Q1_0's single scale and this chunk's Q8_1 scale
+    const float d8 = __low2float(bq8_1_chunk->ds);
+    return d1 * d8 * sumi;
+}
+
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {