Skip to content

Commit bd69a20

Browse files
RehanQasim-dev authored and taimur-10x committed
ggml-cpu: improve iq2_xs impl for rvv 256b
1 parent 3326b3e commit bd69a20

File tree

2 files changed

+70
-52
lines changed

2 files changed

+70
-52
lines changed

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@
237237
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
238238
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
239239
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
240-
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
240+
// #define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
241241
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
242242
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
243243
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
@@ -285,7 +285,7 @@
285285
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
286286
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
287287
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
288-
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
288+
// #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
289289
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
290290
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
291291
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K

ggml/src/ggml-cpu/arch/riscv/quants.c

Lines changed: 68 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3800,8 +3800,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
38003800
case 256:
38013801
ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
38023802
break;
3803+
case 512:
3804+
ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
3805+
break;
38033806
default:
3804-
ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3807+
ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
38053808
break;
38063809
}
38073810
#else
@@ -3946,11 +3949,7 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl128(int n, float * GGML_RESTRICT
39463949

39473950
static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
39483951
assert(n % QK_K == 0);
3949-
assert(nrc == 1);
3950-
UNUSED(nrc);
3951-
UNUSED(bx);
3952-
UNUSED(by);
3953-
UNUSED(bs);
3952+
(void)nrc; (void)bx; (void)by; (void)bs;
39543953

39553954
const block_iq2_xs * GGML_RESTRICT x = vx;
39563955
const block_q8_K * GGML_RESTRICT y = vy;
@@ -3969,61 +3968,74 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT
39693968

39703969
int32_t sum_int = 0;
39713970

3972-
// Loop over 4 subblocks of 64 elements (QK_K = 256)
3973-
for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) {
3974-
// Load 8 uint16 indices (controls 64 values)
3975-
vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8);
3976-
qs += 8;
3971+
for (int ib128 = 0; ib128 < 2; ++ib128) {
3972+
3973+
vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16);
3974+
qs += 16;
39773975

3978-
// Extract indices for grid (low 9 bits) and signs (high 7 bits)
3979-
// Multiply by 8 (<< 3) for byte offsets into the uint64 tables
3980-
vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8);
3981-
vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8);
3976+
// Prepare offsets for grid and signs
3977+
vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16);
3978+
vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16);
39823979

3983-
vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8);
3984-
vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8);
3980+
// Indexed load 128 weights (16 x 8-byte chunks)
3981+
vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16);
3982+
vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16);
39853983

3986-
vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64));
3987-
vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64));
3984+
vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64));
3985+
vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64));
39883986

3989-
// Apply signs
3990-
vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64);
3987+
// Apply signs to get dequantized IQ2 values
3988+
vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128);
3989+
asm volatile("" ::: "memory");
39913990

3992-
// Load Q8 weights (64 elements)
3993-
vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64);
3994-
q8 += 64;
3991+
// Load corresponding Q8 weights
3992+
vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128);
3993+
q8 += 128;
39953994

3996-
// Multiply (Widening to int16, 64 elements -> LMUL=4)
3997-
vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64);
3995+
vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128);
3996+
asm volatile("" ::: "memory");
3997+
3998+
uint8_t sc0 = scales[0];
3999+
uint8_t sc1 = scales[1];
4000+
uint8_t sc2 = scales[2];
4001+
uint8_t sc3 = scales[3];
4002+
scales += 4;
39984003

3999-
// Reduction
40004004
vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
40014005

4002-
int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
4003-
__riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16));
4004-
int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
4005-
__riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16));
4006-
int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
4007-
__riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16));
4008-
int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
4009-
__riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16));
4006+
// 9. Reduce each 16-element chunk and apply corresponding nibble scale
40104007

4011-
// Apply Scales
4012-
const uint8_t scale_byte_1 = scales[0];
4013-
const uint8_t scale_byte_2 = scales[1];
4014-
scales += 2;
4008+
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16));
4009+
sum_int += s0 * ((sc0 & 0x0F) * 2 + 1);
40154010

4016-
sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1);
4017-
sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1);
4018-
sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1);
4019-
sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1);
4011+
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16));
4012+
sum_int += s1 * ((sc0 >> 4) * 2 + 1);
4013+
4014+
int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16));
4015+
sum_int += s2 * ((sc1 & 0x0F) * 2 + 1);
4016+
4017+
int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16));
4018+
sum_int += s3 * ((sc1 >> 4) * 2 + 1);
4019+
4020+
int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16));
4021+
sum_int += s4 * ((sc2 & 0x0F) * 2 + 1);
4022+
4023+
int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16));
4024+
sum_int += s5 * ((sc2 >> 4) * 2 + 1);
4025+
4026+
int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16));
4027+
sum_int += s6 * ((sc3 & 0x0F) * 2 + 1);
4028+
4029+
int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16));
4030+
sum_int += s7 * ((sc3 >> 4) * 2 + 1);
40204031
}
40214032

4022-
sumf += d * sum_int;
4033+
sumf += d * (float)sum_int;
40234034
}
40244035
*s = 0.125f * sumf;
40254036
}
40264037

4038+
40274039
static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
40284040
assert(n % QK_K == 0);
40294041
assert(nrc == 1);
@@ -4099,7 +4111,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
40994111
ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
41004112
break;
41014113
default:
4102-
ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4114+
ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
41034115
break;
41044116
}
41054117
#else
@@ -4371,9 +4383,12 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
43714383
case 128:
43724384
ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
43734385
break;
4374-
default:
4386+
case 256:
43754387
ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
43764388
break;
4389+
default:
4390+
ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
4391+
break;
43774392
}
43784393
#else
43794394
ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc);
@@ -4665,13 +4680,13 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
46654680
#if defined __riscv_v_intrinsic
46664681
switch (__riscv_vlenb() * 8) {
46674682
case 128:
4668-
ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
4683+
ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
46694684
break;
46704685
case 256:
46714686
ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
46724687
break;
46734688
default:
4674-
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4689+
ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
46754690
break;
46764691
}
46774692
#else
@@ -5058,8 +5073,11 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
50585073
case 256:
50595074
ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
50605075
break;
5076+
case 512:
5077+
ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
5078+
break;
50615079
default:
5062-
ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
5080+
ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
50635081
break;
50645082
}
50655083
#else
@@ -5939,7 +5957,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
59395957
ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
59405958
break;
59415959
default:
5942-
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
5960+
ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
59435961
break;
59445962
}
59455963
#else

0 commit comments

Comments (0)