Skip to content

Commit 51b400b

Browse files
committed
ggml-cpu: add 128-bit impls for i-quants, ternary quants
1 parent 3d0b226 commit 51b400b

File tree

1 file changed

+0
-201
lines changed

1 file changed

+0
-201
lines changed

ggml/src/ggml-cpu/arch/riscv/quants.c

Lines changed: 0 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -3628,201 +3628,6 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT
36283628
*s = sumf;
36293629
}
36303630

3631-
// Dot product of an IQ4_XS row with a Q8_K row, specialized for VLEN = 512.
// n must be a multiple of QK_K; the result is written to *s.
static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * GGML_RESTRICT x = vx;
    const block_q8_K   * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    // Codebook: maps each 4-bit index to its iq4nl dequantized value.
    const vint8m4_t lut = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);
    float sumf = 0;

    // 64-bit-lane permutation interleaving the low-nibble and high-nibble
    // halves so each 32-weight sub-block becomes contiguous.
    const uint16_t perm[32] = {
        0,  1,  16, 17,
        2,  3,  18, 19,
        4,  5,  20, 21,
        6,  7,  22, 23,
        8,  9,  24, 25,
        10, 11, 26, 27,
        12, 13, 28, 29,
        14, 15, 30, 31,
    };
    const vuint16m1_t perm_vec = __riscv_vle16_v_u16m1(perm, 32);

    for (int ibl = 0; ibl < nb; ++ibl) {
        const int8_t  * q8  = y[ibl].qs;
        const uint8_t * iq4 = x[ibl].qs;
        uint16_t h = x[ibl].scales_h;

        int sumi = 0;

#pragma GCC unroll 1
        // One iteration covers 256 quantized weights (the whole super-block).
        for (int ib = 0; ib < QK_K / 256; ++ib) {
            // Load 128 packed bytes: two 4-bit weights per byte.
            const vuint8m2_t packed = __riscv_vle8_v_u8m2(iq4, 128);
            iq4 += 128;

            // Split into low/high nibbles and concatenate into one group.
            const vuint8m2_t lo      = __riscv_vand_vx_u8m2(packed, 0xf, 128);
            const vuint8m2_t hi      = __riscv_vsrl_vx_u8m2(packed, 4, 128);
            const vuint8m4_t nibbles = __riscv_vcreate_v_u8m2_u8m4(lo, hi);

            // Permute 8-byte groups back into sub-block order, then decode
            // the 4-bit indices through the codebook.
            const vuint64m4_t as_u64   = __riscv_vreinterpret_v_u8m4_u64m4(nibbles);
            const vuint64m4_t shuffled = __riscv_vrgatherei16_vv_u64m4(as_u64, perm_vec, 32);
            const vuint8m4_t  idx      = __riscv_vreinterpret_v_u64m4_u8m4(shuffled);
            const vint8m4_t   w        = __riscv_vrgather_vv_i8m4(lut, idx, 256);

            // Compiler barrier: keep the scheduler from merging the phases.
            __asm__ __volatile__("" ::: "memory");

            // Widening multiply against the q8 activations.
            const vint8m4_t  act  = __riscv_vle8_v_i8m4(q8, 256);
            const vint16m8_t prod = __riscv_vwmul_vv_i16m8(w, act, 256);
            q8 += 256;

            // Reduce each 32-element sub-block independently.
            const vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, 1);
            const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero, 32));
            const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero, 32));
            const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero, 32));
            const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero, 32));
            const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero, 32));
            const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero, 32));
            const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero, 32));
            const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero, 32));

            // Decode the 6-bit sub-block scales: 4 low bits from scales_l,
            // 2 high bits from scales_h, biased by 32.
            const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls1 = ((x[ibl].scales_l[0] >>  4) | ((h << 2) & 0x30)) - 32;
            const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32;
            const int ls3 = ((x[ibl].scales_l[1] >>  4) | ((h >> 2) & 0x30)) - 32;
            h >>= 8;
            const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls5 = ((x[ibl].scales_l[2] >>  4) | ((h << 2) & 0x30)) - 32;
            const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32;
            const int ls7 = ((x[ibl].scales_l[3] >>  4) | ((h >> 2) & 0x30)) - 32;

            sumi += acc0 * ls0 + acc1 * ls1 + acc2 * ls2 + acc3 * ls3;
            sumi += acc4 * ls4 + acc5 * ls5 + acc6 * ls6 + acc7 * ls7;

            __asm__ __volatile__("" ::: "memory");
        }

        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi);
    }

    *s = sumf;
}
3726-
3727-
// Dot product of an IQ4_XS row with a Q8_K row, specialized for VLEN = 1024.
// At this VLEN an m1 register holds 64 i16 elements, so each register packs
// two 32-element sub-blocks and the odd reductions use a tail mask.
// n must be a multiple of QK_K; the result is written to *s.
static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * GGML_RESTRICT x = vx;
    const block_q8_K   * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    // Codebook: maps each 4-bit index to its iq4nl dequantized value.
    const vint8m2_t lut = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16);
    float sumf = 0;

    // 64-bit-lane permutation interleaving the low-nibble and high-nibble
    // halves so each 32-weight sub-block becomes contiguous.
    const uint16_t perm[32] = {
        0,  1,  16, 17,
        2,  3,  18, 19,
        4,  5,  20, 21,
        6,  7,  22, 23,
        8,  9,  24, 25,
        10, 11, 26, 27,
        12, 13, 28, 29,
        14, 15, 30, 31,
    };
    const vuint16mf2_t perm_vec = __riscv_vle16_v_u16mf2(perm, 32);

    for (int ibl = 0; ibl < nb; ++ibl) {
        const int8_t  * q8  = y[ibl].qs;
        const uint8_t * iq4 = x[ibl].qs;
        uint16_t h = x[ibl].scales_h;

        int sumi = 0;

#pragma GCC unroll 1
        // One iteration covers 256 quantized weights (the whole super-block).
        for (int ib = 0; ib < QK_K / 256; ++ib) {
            // Load 128 packed bytes: two 4-bit weights per byte.
            const vuint8m1_t packed = __riscv_vle8_v_u8m1(iq4, 128);
            iq4 += 128;

            // Split into low/high nibbles and concatenate into one group.
            const vuint8m1_t lo      = __riscv_vand_vx_u8m1(packed, 0xf, 128);
            const vuint8m1_t hi      = __riscv_vsrl_vx_u8m1(packed, 4, 128);
            const vuint8m2_t nibbles = __riscv_vcreate_v_u8m1_u8m2(lo, hi);

            // Permute 8-byte groups back into sub-block order, then decode
            // the 4-bit indices through the codebook.
            const vuint64m2_t as_u64   = __riscv_vreinterpret_v_u8m2_u64m2(nibbles);
            const vuint64m2_t shuffled = __riscv_vrgatherei16_vv_u64m2(as_u64, perm_vec, 32);
            const vuint8m2_t  idx      = __riscv_vreinterpret_v_u64m2_u8m2(shuffled);
            const vint8m2_t   w        = __riscv_vrgather_vv_i8m2(lut, idx, 256);

            // Compiler barrier: keep the scheduler from merging the phases.
            __asm__ __volatile__("" ::: "memory");

            // Widening multiply against the q8 activations.
            const vint8m2_t  act  = __riscv_vle8_v_i8m2(q8, 256);
            const vint16m4_t prod = __riscv_vwmul_vv_i16m4(w, act, 256);
            q8 += 256;

            // Mask selecting lanes 32..63, i.e. the second sub-block held in
            // each 64-element i16 register.
            const vuint16m1_t lane_id    = __riscv_vid_v_u16m1(64);
            const vbool16_t   upper_half = __riscv_vmsgtu_vx_u16m1_b16(lane_id, 31, 64);

            // Reduce each 32-element sub-block independently: even accs sum
            // the first half of a register, odd accs the masked second half.
            const vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, 1);
            const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(            __riscv_vget_v_i16m4_i16m1(prod, 0), zero, 32));
            const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(upper_half, __riscv_vget_v_i16m4_i16m1(prod, 0), zero, 64));
            const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(            __riscv_vget_v_i16m4_i16m1(prod, 1), zero, 32));
            const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(upper_half, __riscv_vget_v_i16m4_i16m1(prod, 1), zero, 64));
            const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(            __riscv_vget_v_i16m4_i16m1(prod, 2), zero, 32));
            const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(upper_half, __riscv_vget_v_i16m4_i16m1(prod, 2), zero, 64));
            const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(            __riscv_vget_v_i16m4_i16m1(prod, 3), zero, 32));
            const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(upper_half, __riscv_vget_v_i16m4_i16m1(prod, 3), zero, 64));

            // Decode the 6-bit sub-block scales: 4 low bits from scales_l,
            // 2 high bits from scales_h, biased by 32.
            const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls1 = ((x[ibl].scales_l[0] >>  4) | ((h << 2) & 0x30)) - 32;
            const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32;
            const int ls3 = ((x[ibl].scales_l[1] >>  4) | ((h >> 2) & 0x30)) - 32;
            h >>= 8;
            const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls5 = ((x[ibl].scales_l[2] >>  4) | ((h << 2) & 0x30)) - 32;
            const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32;
            const int ls7 = ((x[ibl].scales_l[3] >>  4) | ((h >> 2) & 0x30)) - 32;

            sumi += acc0 * ls0 + acc1 * ls1 + acc2 * ls2 + acc3 * ls3;
            sumi += acc4 * ls4 + acc5 * ls5 + acc6 * ls6 + acc7 * ls7;

            __asm__ __volatile__("" ::: "memory");
        }

        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi);
    }

    *s = sumf;
}
3825-
38263631
void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
38273632
#if defined __riscv_v_intrinsic
38283633
switch (__riscv_vlenb() * 8) {
@@ -3832,12 +3637,6 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
38323637
case 256:
38333638
ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
38343639
break;
3835-
case 512:
3836-
ggml_vec_dot_iq4_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
3837-
break;
3838-
case 1024:
3839-
ggml_vec_dot_iq4_xs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
3840-
break;
38413640
default:
38423641
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
38433642
break;

0 commit comments

Comments
 (0)