Skip to content

Commit 1c569b9

Browse files
committed
ggml-cpu: add Q1_0 AVX2 path
1 parent 0988acc commit 1c569b9

1 file changed

Lines changed: 64 additions & 1 deletion

File tree

ggml/src/ggml-cpu/quants.c

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,42 @@
2222

2323
#define UNUSED GGML_UNUSED
2424

25+
#if defined(__AVX2__)
26+
#include <immintrin.h>
27+
28+
// horizontally add 8 floats
29+
static inline float hsum_float_8(const __m256 x) {
30+
__m128 res = _mm256_extractf128_ps(x, 1);
31+
res = _mm_add_ps(res, _mm256_castps256_ps128(x));
32+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
33+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
34+
return _mm_cvtss_f32(res);
35+
}
36+
37+
// 32 packed bits -> 32 bytes
38+
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
39+
uint32_t x32;
40+
memcpy(&x32, x, sizeof(uint32_t));
41+
const __m256i shuf_mask = _mm256_set_epi64x(
42+
0x0303030303030303, 0x0202020202020202,
43+
0x0101010101010101, 0x0000000000000000);
44+
__m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
45+
const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
46+
bytes = _mm256_or_si256(bytes, bit_mask);
47+
return _mm256_sub_epi8(_mm256_setzero_si256(), _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)));
48+
}
49+
50+
// multiply int8_t, add results pairwise twice and return as float vector
51+
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
52+
const __m256i ax = _mm256_sign_epi8(x, x);
53+
const __m256i sy = _mm256_sign_epi8(y, x);
54+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
55+
const __m256i ones = _mm256_set1_epi16(1);
56+
const __m256i summed_pairs = _mm256_madd_epi16(ones, dot);
57+
return _mm256_cvtepi32_ps(summed_pairs);
58+
}
59+
#endif
60+
2561
void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
2662
quantize_row_q1_0_ref(x, y, k);
2763
}
@@ -134,8 +170,34 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
134170
const block_q1_0 * GGML_RESTRICT x = vx;
135171
const block_q8_0 * GGML_RESTRICT y = vy;
136172

137-
float sumf = 0.0;
173+
float sumf = 0.0f;
174+
175+
#if defined(__AVX2__)
176+
// AVX2 path
177+
__m256 acc = _mm256_setzero_ps();
178+
const __m256i one = _mm256_set1_epi8(1);
179+
180+
for (int i = 0; i < nb; ++i) {
181+
const float d0 = GGML_FP16_TO_FP32(x[i].d);
182+
183+
for (int k = 0; k < 4; ++k) {
184+
const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d);
185+
const __m256i bits = bytes_from_bits_32(x[i].qs + k*4);
186+
187+
__m256i qx = bits;
188+
189+
qx = _mm256_sub_epi8(_mm256_slli_epi16(qx, 1), one);
190+
191+
const __m256i qy = _mm256_loadu_si256((const __m256i *) y[i*4 + k].qs);
192+
const __m256 q = mul_sum_i8_pairs_float(qx, qy);
193+
const __m256 d = _mm256_set1_ps(d0 * d1);
194+
195+
acc = _mm256_fmadd_ps(d, q, acc);
196+
}
197+
}
138198

199+
sumf = hsum_float_8(acc);
200+
#else
139201
for (int i = 0; i < nb; i++) {
140202
const float d0 = GGML_FP16_TO_FP32(x[i].d);
141203

@@ -160,6 +222,7 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
160222

161223
sumf += d0 * sumi;
162224
}
225+
#endif
163226

164227
*s = sumf;
165228
}

0 commit comments

Comments
 (0)