2222
2323#define UNUSED GGML_UNUSED
2424
#if defined(__AVX2__)
#include <immintrin.h>

// Horizontally add the 8 float lanes of x and return the scalar sum.
static inline float hsum_float_8(const __m256 x) {
    // fold the upper 128-bit half onto the lower half: 8 lanes -> 4
    __m128 res = _mm256_extractf128_ps(x, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    // add the high pair onto the low pair: 4 lanes -> 2
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    // add lane 1 onto lane 0: 2 lanes -> 1
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}

// Expand 32 packed bits into 32 bytes.
//
// Reads 4 bytes from x (no alignment requirement, the load goes through
// memcpy). Output byte 8*b + j is 1 when bit j (LSB first) of x[b] is set,
// and 0 otherwise.
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));

    // replicate source byte b into output bytes 8*b .. 8*b + 7
    const __m256i shuf_mask = _mm256_set_epi64x(
            0x0303030303030303, 0x0202020202020202,
            0x0101010101010101, 0x0000000000000000);
    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);

    // byte j of each 8-byte group has every bit set except bit j, so after
    // the OR a byte is 0xFF exactly when its own bit was set in the source
    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
    bytes = _mm256_or_si256(bytes, bit_mask);

    // cmpeq yields 0xFF (-1) for set bits; 0 - (-1) turns that into 1
    return _mm256_sub_epi8(_mm256_setzero_si256(), _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)));
}

// Multiply int8_t lanes of x and y, add the products pairwise twice and
// return the result as a float vector: output lane g holds the sum of
// x[i] * y[i] for i in 4*g .. 4*g + 3.
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
    // _mm256_maddubs_epi16 takes an unsigned first operand and a signed
    // second one, so use |x| and move the sign of x onto y
    const __m256i ax = _mm256_sign_epi8(x, x);
    const __m256i sy = _mm256_sign_epi8(y, x);
    // u8 * i8 products, adjacent pairs summed into int16 lanes
    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
    // multiply by 1 and sum adjacent int16 pairs into int32 lanes
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i summed_pairs = _mm256_madd_epi16(ones, dot);
    return _mm256_cvtepi32_ps(summed_pairs);
}
#endif
60+
2561void quantize_row_q1_0 (const float * GGML_RESTRICT x , void * GGML_RESTRICT y , int64_t k ) {
2662 quantize_row_q1_0_ref (x , y , k );
2763}
@@ -134,8 +170,34 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
134170 const block_q1_0 * GGML_RESTRICT x = vx ;
135171 const block_q8_0 * GGML_RESTRICT y = vy ;
136172
137- float sumf = 0.0 ;
173+ float sumf = 0.0f ;
174+
175+ #if defined(__AVX2__ )
176+ // AVX2 path
177+ __m256 acc = _mm256_setzero_ps ();
178+ const __m256i one = _mm256_set1_epi8 (1 );
179+
180+ for (int i = 0 ; i < nb ; ++ i ) {
181+ const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
182+
183+ for (int k = 0 ; k < 4 ; ++ k ) {
184+ const float d1 = GGML_FP16_TO_FP32 (y [i * 4 + k ].d );
185+ const __m256i bits = bytes_from_bits_32 (x [i ].qs + k * 4 );
186+
187+ __m256i qx = bits ;
188+
189+ qx = _mm256_sub_epi8 (_mm256_slli_epi16 (qx , 1 ), one );
190+
191+ const __m256i qy = _mm256_loadu_si256 ((const __m256i * ) y [i * 4 + k ].qs );
192+ const __m256 q = mul_sum_i8_pairs_float (qx , qy );
193+ const __m256 d = _mm256_set1_ps (d0 * d1 );
194+
195+ acc = _mm256_fmadd_ps (d , q , acc );
196+ }
197+ }
138198
199+ sumf = hsum_float_8 (acc );
200+ #else
139201 for (int i = 0 ; i < nb ; i ++ ) {
140202 const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
141203
@@ -160,6 +222,7 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
160222
161223 sumf += d0 * sumi ;
162224 }
225+ #endif
163226
164227 * s = sumf ;
165228}
0 commit comments