Skip to content

Commit cf95828

Browse files
committed
ggml-cpu: refactor; add rvv checks
1 parent 4b12d40 commit cf95828

1 file changed

Lines changed: 31 additions & 8 deletions

File tree

ggml/src/ggml-cpu/arch/riscv/quants.c

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,10 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
121121

122122
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
123123
assert(k % QK_K == 0);
124-
block_q8_K * y_blocks = (block_q8_K *)y;
125124
size_t nb = k / QK_K;
126125

127-
#if defined(__riscv_v_intrinsic)
126+
#if defined __riscv_v_intrinsic
127+
block_q8_K * y_blocks = (block_q8_K *)y;
128128
const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8();
129129

130130
for (size_t i = 0; i < nb; i++) {
@@ -2058,6 +2058,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
20582058
#endif
20592059
}
20602060

2061+
#if defined __riscv_v_intrinsic
20612062
static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
20622063
assert(n % QK_K == 0);
20632064
assert(nrc == 1);
@@ -2265,6 +2266,7 @@ static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT
22652266

22662267
*s = sumf;
22672268
}
2269+
#endif
22682270

22692271
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
22702272
#if defined __riscv_v_intrinsic
@@ -2284,6 +2286,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
22842286
#endif
22852287
}
22862288

2289+
#if defined __riscv_v_intrinsic
22872290
static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
22882291
assert(n % QK_K == 0);
22892292
assert(nrc == 1);
@@ -2563,6 +2566,7 @@ static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT
25632566

25642567
*s = sumf;
25652568
}
2569+
#endif
25662570

25672571
void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
25682572
#if defined __riscv_v_intrinsic
@@ -2582,6 +2586,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
25822586
#endif
25832587
}
25842588

2589+
#if defined __riscv_v_intrinsic
25852590
static const uint8_t sign_gather_indices_arr[64] = {
25862591
0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,
25872592
4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7
@@ -2784,6 +2789,7 @@ static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT
27842789
}
27852790
*s = 0.125f * sumf;
27862791
}
2792+
#endif
27872793

27882794
void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
27892795
#if defined __riscv_v_intrinsic
@@ -2803,7 +2809,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
28032809
#endif
28042810
}
28052811

2806-
#if defined(__riscv_v)
2812+
#if defined __riscv_v_intrinsic
28072813
static const int8_t keven_signs_q2xs[1024] = {
28082814
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
28092815
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
@@ -2838,7 +2844,6 @@ static const int8_t keven_signs_q2xs[1024] = {
28382844
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
28392845
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
28402846
};
2841-
#endif
28422847

28432848
static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
28442849
assert(n % QK_K == 0);
@@ -2991,6 +2996,7 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT
29912996
}
29922997
*s = 0.125f * sumf;
29932998
}
2999+
#endif
29943000

29953001
void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
29963002
#if defined __riscv_v_intrinsic
@@ -3010,6 +3016,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
30103016
#endif
30113017
}
30123018

3019+
#if defined __riscv_v_intrinsic
30133020
static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
30143021
assert(n % QK_K == 0);
30153022
assert(nrc == 1);
@@ -3194,6 +3201,7 @@ static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRIC
31943201
}
31953202
*s = 0.125f * sumf;
31963203
}
3204+
#endif
31973205

31983206
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
31993207
#if defined __riscv_v_intrinsic
@@ -3206,10 +3214,11 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
32063214
break;
32073215
}
32083216
#else
3209-
ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc);
3217+
ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
32103218
#endif
32113219
}
32123220

3221+
#if defined __riscv_v_intrinsic
32133222
static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
32143223
assert(n % QK_K == 0);
32153224
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
@@ -3399,6 +3408,7 @@ static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT
33993408
}
34003409
*s = sumf;
34013410
}
3411+
#endif
34023412

34033413
void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
34043414
#if defined __riscv_v_intrinsic
@@ -3418,6 +3428,7 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
34183428
#endif
34193429
}
34203430

3431+
#if defined __riscv_v_intrinsic
34213432
static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
34223433
assert(n % QK_K == 0);
34233434
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
@@ -3603,6 +3614,7 @@ static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRIC
36033614
}
36043615
*s = 0.25f * sumf;
36053616
}
3617+
#endif
36063618

36073619
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
36083620
#if defined __riscv_v_intrinsic
@@ -3622,6 +3634,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
36223634
#endif
36233635
}
36243636

3637+
#if defined __riscv_v_intrinsic
36253638
static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
36263639
assert(nrc == 1);
36273640
UNUSED(nrc);
@@ -3733,6 +3746,7 @@ static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT
37333746

37343747
*s = sumf;
37353748
}
3749+
#endif
37363750

37373751
void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
37383752
#if defined __riscv_v_intrinsic
@@ -3749,6 +3763,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
37493763
#endif
37503764
}
37513765

3766+
#if defined __riscv_v_intrinsic
37523767
static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
37533768
assert(nrc == 1);
37543769
UNUSED(nrc);
@@ -3894,6 +3909,7 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT
38943909

38953910
*s = sumf;
38963911
}
3912+
#endif
38973913

38983914
void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
38993915
#if defined __riscv_v_intrinsic
@@ -3913,6 +3929,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
39133929
#endif
39143930
}
39153931

3932+
#if defined __riscv_v_intrinsic
39163933
static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
39173934
assert(nrc == 1);
39183935
UNUSED(nrc);
@@ -4106,14 +4123,16 @@ static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT
41064123
suml3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
41074124
}
41084125

4109-
vint16m2_t sumb = __riscv_vadd_vv_i16m2(suml1, __riscv_vlmul_ext_v_i16m1_i16m2(__riscv_vadd_vv_i16m1(suml2, suml3, 16)), 16);
4126+
vint16m1_t sumb = __riscv_vadd_vv_i16m1(__riscv_vget_v_i16m2_i16m1(suml1, 0), __riscv_vget_v_i16m2_i16m1(suml1, 1), 16);
4127+
sumb = __riscv_vadd_vv_i16m1(sumb, __riscv_vadd_vv_i16m1(suml2, suml3, 16), 16);
41104128

4111-
vint32m1_t sum = __riscv_vwredsum_vs_i16m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 32);
4129+
vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16);
41124130
sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
41134131
}
41144132

41154133
*s = sumf;
41164134
}
4135+
#endif
41174136

41184137
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
41194138
#if defined __riscv_v_intrinsic
@@ -4130,6 +4149,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
41304149
#endif
41314150
}
41324151

4152+
#if defined __riscv_v_intrinsic
41334153
static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
41344154
assert(n % QK_K == 0);
41354155
assert(nrc == 1);
@@ -4282,6 +4302,7 @@ static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT
42824302

42834303
*s = sumf;
42844304
}
4305+
#endif
42854306

42864307
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
42874308
#if defined __riscv_v_intrinsic
@@ -4301,6 +4322,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
43014322
#endif
43024323
}
43034324

4325+
#if defined __riscv_v_intrinsic
43044326
static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
43054327
assert(nrc == 1);
43064328
UNUSED(nrc);
@@ -4412,6 +4434,7 @@ static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT
44124434

44134435
*s = sumf;
44144436
}
4437+
#endif
44154438

44164439
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
44174440
#if defined __riscv_v_intrinsic
@@ -4424,6 +4447,6 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
44244447
break;
44254448
}
44264449
#else
4427-
return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
4450+
ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
44284451
#endif
44294452
}

0 commit comments

Comments
 (0)