Skip to content

Commit 3d2f7cf

Browse files
committed
ggml-cpu: refactor; add rvv repacking for q5_K
1 parent f12a065 commit 3d2f7cf

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

ggml/src/ggml-cpu/arch/riscv/repack.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,15 +1065,15 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
10651065

10661066
// Accumulation for 2 sub-blocks.
10671067
//
1068-
// This might overflow, so we accumulate in two steps.
1068+
// This might overflow, so we accumulate in 4 steps.
10691069
//
10701070
// Recheck.
1071-
for (int k = 0; k < 2; k++) {
1071+
for (int k = 0; k < 4; k++) {
10721072
// 4xM integer accumulators
10731073
vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
10741074
vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
10751075

1076-
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
1076+
for (int i = k * 8; i < (k + 1) * 8; i++) {
10771077
// Load `b_ptr`.
10781078
const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved);
10791079
const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved));
@@ -1099,15 +1099,15 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
10991099
}
11001100
// Accumulation for 2 sub-blocks.
11011101
//
1102-
// This might overflow, so we accumulate in two steps.
1102+
// This might overflow, so we accumulate in 4 steps.
11031103
//
11041104
// Recheck.
1105-
for (int k = 0; k < 2; k++) {
1105+
for (int k = 0; k < 4; k++) {
11061106
// 4xM integer accumulators
11071107
vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
11081108
vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
11091109

1110-
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
1110+
for (int i = k * 8; i < (k + 1) * 8; i++) {
11111111
// Load `b_ptr`.
11121112
const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved);
11131113
const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved));
@@ -1202,7 +1202,7 @@ void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
12021202
for (int l = 0; l < nb; l++) {
12031203
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved);
12041204

1205-
// We process 2 16-element sub-blocks at once.
1205+
// We process 4 16-element sub-blocks at once.
12061206
for (int j = 0; j < QK_K / 16; j += 4) {
12071207
// Load the scales.
12081208
//
@@ -2225,7 +2225,7 @@ static void ggml_gemm_q3_K_Mx1_q8_K(int n,
22252225
for (int group = 0; group < 4; ++group) {
22262226
// High scales are needed for all 4 sub-blocks (0.5 register)
22272227
vuint8mf2_t v_sc_h_quad = __riscv_vle8_v_u8mf2(rhs_sc_high_ptr, 16);
2228-
rhs_sc_high_ptr += ncols_interleaved;
2228+
rhs_sc_high_ptr += ncols_interleaved;
22292229

22302230
// --- Scope 1: Sub-blocks 1 & 2 (Pair 0) ---
22312231
// By scoping this, v_sc_l_pair0 dies before we load pair1
@@ -2926,10 +2926,10 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
29262926

29272927
// Accumulation for 2 sub-blocks.
29282928
//
2929-
// This might overflow, so we accumulate in two steps.
2929+
// This might overflow, so we accumulate in 4 steps.
29302930
//
29312931
// Recheck.
2932-
for (int k = 0; k < 2; k++) {
2932+
for (int k = 0; k < 4; k++) {
29332933
// 4xM integer accumulators
29342934
vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
29352935
vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
@@ -2940,7 +2940,7 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
29402940
vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
29412941
vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
29422942

2943-
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
2943+
for (int i = k * 8; i < (k + 1) * 8; i++) {
29442944
// Load `b_ptr`.
29452945
const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved);
29462946
const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved));
@@ -2994,7 +2994,7 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
29942994
// This might overflow, so we accumulate in two steps.
29952995
//
29962996
// Recheck.
2997-
for (int k = 0; k < 2; k++) {
2997+
for (int k = 0; k < 4; k++) {
29982998
// 4xM integer accumulators
29992999
vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
30003000
vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
@@ -3005,7 +3005,7 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
30053005
vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
30063006
vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved);
30073007

3008-
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
3008+
for (int i = k * 8; i < (k + 1) * 8; i++) {
30093009
// Load `b_ptr`.
30103010
const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved);
30113011
const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved));

0 commit comments

Comments
 (0)