@@ -3800,8 +3800,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
         case 256:
             ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
+        case 512:
+            ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
+            break;
         default:
-            ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
@@ -3946,11 +3949,7 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl128(int n, float * GGML_RESTRICT

 static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
+    (void)nrc; (void)bx; (void)by; (void)bs;

     const block_iq2_xs * GGML_RESTRICT x = vx;
     const block_q8_K * GGML_RESTRICT y = vy;
@@ -3969,61 +3968,74 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT

         int32_t sum_int = 0;

-        // Loop over 4 subblocks of 64 elements (QK_K = 256)
-        for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) {
-            // Load 8 uint16 indices (controls 64 values)
-            vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8);
-            qs += 8;
+        for (int ib128 = 0; ib128 < 2; ++ib128) {
+
+            vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16);
+            qs += 16;

-            // Extract indices for grid (low 9 bits) and signs (high 7 bits)
-            // Multiply by 8 (<< 3) for byte offsets into the uint64 tables
-            vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8);
-            vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8);
+            // Prepare offsets for grid and signs
+            vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16);
+            vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16);
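+            // Each 16-bit entry packs a 9-bit grid index in the low bits and a
+            // 7-bit sign-pattern index in the high bits; << 3 turns table
+            // indices into byte offsets for the 64-bit indexed gathers below.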

-            vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8);
-            vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8);
+            // Indexed load 128 weights (16 x 8-byte chunks)
+            vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16);
+            vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16);

-            vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64));
-            vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64));
+            vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64));
+            vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64));
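+            // The reinterprets are pure bit-casts: the 16 gathered 64-bit
+            // lanes are re-viewed as 128 signed bytes; no data is moved.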

-            // Apply signs
-            vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64);
+            // Apply signs to get dequantized IQ2 values
+            vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128);
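+            // Compiler-only barrier (emits no instruction); presumably keeps
+            // the compiler from rescheduling the wide vector ops around it.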
+            asm volatile("" ::: "memory");

-            // Load Q8 weights (64 elements)
-            vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64);
-            q8 += 64;
+            // Load corresponding Q8 weights
+            vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128);
+            q8 += 128;

-            // Multiply (Widening to int16, 64 elements -> LMUL=4)
-            vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64);
+            vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128);
+            asm volatile("" ::: "memory");
+
+            uint8_t sc0 = scales[0];
+            uint8_t sc1 = scales[1];
+            uint8_t sc2 = scales[2];
+            uint8_t sc3 = scales[3];
+            scales += 4;
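+            // Each scale byte carries two 4-bit fields; a nibble value sc
+            // encodes the integer sub-block scale 2*sc + 1, applied below.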

-            // Reduction
             vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);

-            int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
-                __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16));
-            int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
-                __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16));
-            int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
-                __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16));
-            int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
-                __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16));
+            // Reduce each 16-element chunk and apply the corresponding nibble scale

-            // Apply Scales
-            const uint8_t scale_byte_1 = scales[0];
-            const uint8_t scale_byte_2 = scales[1];
-            scales += 2;
+            int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16));
+            sum_int += s0 * ((sc0 & 0x0F) * 2 + 1);

-            sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1);
-            sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1);
-            sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1);
-            sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1);
+            int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16));
+            sum_int += s1 * ((sc0 >> 4) * 2 + 1);
+
+            int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16));
+            sum_int += s2 * ((sc1 & 0x0F) * 2 + 1);
+
+            int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16));
+            sum_int += s3 * ((sc1 >> 4) * 2 + 1);
+
+            int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16));
+            sum_int += s4 * ((sc2 & 0x0F) * 2 + 1);
+
+            int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16));
+            sum_int += s5 * ((sc2 >> 4) * 2 + 1);
+
+            int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16));
+            sum_int += s6 * ((sc3 & 0x0F) * 2 + 1);
+
+            int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16));
+            sum_int += s7 * ((sc3 >> 4) * 2 + 1);
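+            // The eight blocks above are the unrolled form of this sketch,
+            // with chunk_sum(), nibble(), and scales_in as hypothetical helpers:
+            //   for (int k = 0; k < 8; ++k)
+            //       sum_int += chunk_sum(prod, k) * (nibble(scales_in, k) * 2 + 1);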
         }

-        sumf += d * sum_int;
+        sumf += d * (float)sum_int;
     }
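+    // The integer chunk scales were (2*sc + 1); the true IQ2_XS sub-block
+    // factor is (0.5 + sc) * 0.25 = (2*sc + 1)/8, so the 1/8 (0.125f) is
+    // applied once at the end.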
     *s = 0.125f * sumf;
 }

+
 static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
@@ -4099,7 +4111,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
             ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
         default:
-            ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
@@ -4371,9 +4383,12 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
         case 128:
             ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
             break;
-        default:
+        case 256:
             ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
+        default:
+            ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
+            break;
     }
 #else
     ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
@@ -4665,13 +4680,13 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #if defined __riscv_v_intrinsic
     switch (__riscv_vlenb() * 8) {
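+    // __riscv_vlenb() returns VLEN in bytes, so this switches on VLEN in bits.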
         case 128:
-             ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
             break;
         case 256:
             ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
         default:
-            ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
@@ -5058,8 +5073,11 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
         case 256:
             ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
+        case 512:
+            ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
+            break;
         default:
-            ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
@@ -5939,7 +5957,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
         default:
-            ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else