diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index c589a213e9d..595ded09f03 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -83,7 +83,6 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 74d699f633d..ece6b56c574 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -274,6 +274,24 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const } #endif #elif defined(__SSSE3__) +static inline float hsum_float_4(const __m128 x) { + __m128 res = _mm_hadd_ps(x, x); + res = _mm_hadd_ps(res, res); + return _mm_cvtss_f32(res); +} + +static inline __m128i bytes_from_bits_16(const uint8_t * x) { + uint16_t x16; + memcpy(&x16, x, sizeof(uint16_t)); + + const __m128i shuf_mask = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + __m128i bytes = _mm_shuffle_epi8(_mm_set1_epi16((short) x16), shuf_mask); + const __m128i bit_mask = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe); + bytes = _mm_or_si128(bytes, bit_mask); + + return _mm_cmpeq_epi8(bytes, _mm_set1_epi64x(-1)); +} + // horizontally add 4x4 floats static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { __m128 res_0 =_mm_hadd_ps(a, b); @@ -540,6 +558,1152 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + assert((nrc == 2) || (nrc == 1)); + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + + if (nrc == 2) { + const block_q1_0 * GGML_RESTRICT x0 = vx; + const block_q1_0 * GGML_RESTRICT x1 = (const block_q1_0 *) ((const uint8_t *) vx + bx); + const block_q8_0 * GGML_RESTRICT y0 = vy; + const block_q8_0 * GGML_RESTRICT y1 = (const block_q8_0 *) ((const uint8_t *) vy + by); + + __m256 acc_00 = _mm256_setzero_ps(); + __m256 acc_01 = _mm256_setzero_ps(); + __m256 acc_10 = _mm256_setzero_ps(); + __m256 acc_11 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d00 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d10 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const uint32_t * GGML_RESTRICT qs0 = (const uint32_t *) x0[ib].qs; + const uint32_t * GGML_RESTRICT qs1 = (const uint32_t *) x1[ib].qs; + const block_q8_0 * GGML_RESTRICT y0_ptr = &y0[ib * 4]; + const block_q8_0 * GGML_RESTRICT y1_ptr = &y1[ib * 4]; + __m256 acc_block_00 = 
_mm256_setzero_ps(); + __m256 acc_block_01 = _mm256_setzero_ps(); + __m256 acc_block_10 = _mm256_setzero_ps(); + __m256 acc_block_11 = _mm256_setzero_ps(); + +#define Q1_AVX2_BLOCK_PAIR(K) \ + { \ + const __m256i sm0 = _mm256_cmpeq_epi8( \ + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs0[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm1 = _mm256_cmpeq_epi8( \ + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs1[K]), byte_shuf), bit_masks), zero); \ + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) y0_ptr[K].qs); \ + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) y1_ptr[K].qs); \ + const __m256i sy00 = _mm256_sub_epi8(_mm256_xor_si256(qy0, sm0), sm0); \ + const __m256i sy01 = _mm256_sub_epi8(_mm256_xor_si256(qy1, sm0), sm0); \ + const __m256i sy10 = _mm256_sub_epi8(_mm256_xor_si256(qy0, sm1), sm1); \ + const __m256i sy11 = _mm256_sub_epi8(_mm256_xor_si256(qy1, sm1), sm1); \ + const __m256i s32_00 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy00), ones_16); \ + const __m256i s32_01 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy01), ones_16); \ + const __m256i s32_10 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy10), ones_16); \ + const __m256i s32_11 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy11), ones_16); \ + acc_block_00 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[K].d)), _mm256_cvtepi32_ps(s32_00), acc_block_00); \ + acc_block_01 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[K].d)), _mm256_cvtepi32_ps(s32_01), acc_block_01); \ + acc_block_10 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[K].d)), _mm256_cvtepi32_ps(s32_10), acc_block_10); \ + acc_block_11 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[K].d)), _mm256_cvtepi32_ps(s32_11), acc_block_11); \ + } + Q1_AVX2_BLOCK_PAIR(0) + Q1_AVX2_BLOCK_PAIR(1) + Q1_AVX2_BLOCK_PAIR(2) + Q1_AVX2_BLOCK_PAIR(3) +#undef Q1_AVX2_BLOCK_PAIR + + acc_00 = _mm256_fmadd_ps(_mm256_set1_ps(d00), acc_block_00, acc_00); + acc_01 = _mm256_fmadd_ps(_mm256_set1_ps(d00), acc_block_01, acc_01); + acc_10 = _mm256_fmadd_ps(_mm256_set1_ps(d10), acc_block_10, acc_10); + acc_11 = _mm256_fmadd_ps(_mm256_set1_ps(d10), acc_block_11, acc_11); + } + + s[0] = hsum_float_8(acc_00); + s[1] = hsum_float_8(acc_10); + s[bs] = hsum_float_8(acc_01); + s[bs + 1] = hsum_float_8(acc_11); + return; + } + + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs; + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + + __m256 acc_block; + { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32)); + } +#define Q1_AVX2_BLOCK(K) \ + { \ + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); \ + const __m256i sm = _mm256_cmpeq_epi8( \ + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); \ + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); \ + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), 
ones_16); \ + acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); \ + } + Q1_AVX2_BLOCK(1) + Q1_AVX2_BLOCK(2) + Q1_AVX2_BLOCK(3) +#undef Q1_AVX2_BLOCK + acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + assert((nrc == 2) || (nrc == 1)); + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + + if (nrc == 2) { + const block_q1_0 * GGML_RESTRICT x0 = vx; + const block_q1_0 * GGML_RESTRICT x1 = (const block_q1_0 *) ((const uint8_t *) vx + bx); + const block_q8_0 * GGML_RESTRICT y0 = vy; + const block_q8_0 * GGML_RESTRICT y1 = (const block_q8_0 *) ((const uint8_t *) vy + by); + + __m256 acc_00 = _mm256_setzero_ps(); + __m256 acc_01 = _mm256_setzero_ps(); + __m256 acc_10 = _mm256_setzero_ps(); + __m256 acc_11 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d00 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d10 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const block_q8_0 * GGML_RESTRICT y0_ptr = &y0[ib * 4]; + const block_q8_0 * GGML_RESTRICT y1_ptr = &y1[ib * 4]; + __m256 acc_block_00 = _mm256_setzero_ps(); + __m256 acc_block_01 = _mm256_setzero_ps(); + __m256 acc_block_10 = _mm256_setzero_ps(); + __m256 acc_block_11 = _mm256_setzero_ps(); + +#define Q1_AVX_BLOCK_PAIR(K) \ + { \ + const __m256i bit_mask_0 = bytes_from_bits_32(&x0[ib].qs[(K) * 4]); \ + const __m256i bit_mask_1 = bytes_from_bits_32(&x1[ib].qs[(K) * 4]); \ + const __m128i bit_mask_00 = _mm256_castsi256_si128(bit_mask_0); \ + const __m128i bit_mask_01 = _mm256_extractf128_si256(bit_mask_0, 1); \ + const __m128i bit_mask_10 = _mm256_castsi256_si128(bit_mask_1); \ + const __m128i bit_mask_11 = _mm256_extractf128_si256(bit_mask_1, 1); \ + const __m128i qy0_0 = _mm_loadu_si128((const __m128i *) &y0_ptr[(K)].qs[0]); \ + const __m128i qy0_1 = _mm_loadu_si128((const __m128i *) &y0_ptr[(K)].qs[16]); \ + const __m128i qy1_0 = _mm_loadu_si128((const __m128i *) &y1_ptr[(K)].qs[0]); \ + const __m128i qy1_1 = _mm_loadu_si128((const __m128i *) &y1_ptr[(K)].qs[16]); \ + const __m128i sign_mask_00 = _mm_cmpeq_epi8(bit_mask_00, zero); \ + const __m128i sign_mask_01 = _mm_cmpeq_epi8(bit_mask_01, zero); \ + const __m128i sign_mask_10 = _mm_cmpeq_epi8(bit_mask_10, zero); \ + const __m128i sign_mask_11 = _mm_cmpeq_epi8(bit_mask_11, zero); \ + const __m128i sy00_0 = _mm_sub_epi8(_mm_xor_si128(qy0_0, sign_mask_00), sign_mask_00); \ + const __m128i sy00_1 = _mm_sub_epi8(_mm_xor_si128(qy0_1, sign_mask_01), sign_mask_01); \ + const __m128i sy01_0 = _mm_sub_epi8(_mm_xor_si128(qy1_0, sign_mask_00), sign_mask_00); \ + const __m128i sy01_1 = _mm_sub_epi8(_mm_xor_si128(qy1_1, sign_mask_01), sign_mask_01); \ + const __m128i sy10_0 = _mm_sub_epi8(_mm_xor_si128(qy0_0, sign_mask_10), sign_mask_10); \ + const __m128i sy10_1 = _mm_sub_epi8(_mm_xor_si128(qy0_1, sign_mask_11), sign_mask_11); \ + const __m128i sy11_0 = _mm_sub_epi8(_mm_xor_si128(qy1_0, sign_mask_10), sign_mask_10); \ + const __m128i sy11_1 = _mm_sub_epi8(_mm_xor_si128(qy1_1, sign_mask_11), sign_mask_11); \ + const __m128i sum16_00_0 = _mm_maddubs_epi16(ones_8, sy00_0); \ + const __m128i sum16_00_1 = _mm_maddubs_epi16(ones_8, sy00_1); \ + const __m128i sum16_01_0 = _mm_maddubs_epi16(ones_8, sy01_0); \ + const __m128i sum16_01_1 = _mm_maddubs_epi16(ones_8, sy01_1); \ + const __m128i sum16_10_0 = _mm_maddubs_epi16(ones_8, sy10_0); \ + const __m128i sum16_10_1 = 
_mm_maddubs_epi16(ones_8, sy10_1); \ + const __m128i sum16_11_0 = _mm_maddubs_epi16(ones_8, sy11_0); \ + const __m128i sum16_11_1 = _mm_maddubs_epi16(ones_8, sy11_1); \ + const __m128i sum32_00_0 = _mm_madd_epi16(sum16_00_0, ones_16); \ + const __m128i sum32_00_1 = _mm_madd_epi16(sum16_00_1, ones_16); \ + const __m128i sum32_01_0 = _mm_madd_epi16(sum16_01_0, ones_16); \ + const __m128i sum32_01_1 = _mm_madd_epi16(sum16_01_1, ones_16); \ + const __m128i sum32_10_0 = _mm_madd_epi16(sum16_10_0, ones_16); \ + const __m128i sum32_10_1 = _mm_madd_epi16(sum16_10_1, ones_16); \ + const __m128i sum32_11_0 = _mm_madd_epi16(sum16_11_0, ones_16); \ + const __m128i sum32_11_1 = _mm_madd_epi16(sum16_11_1, ones_16); \ + const __m256 q00 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_00_1, sum32_00_0)); \ + const __m256 q01 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_01_1, sum32_01_0)); \ + const __m256 q10 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_10_1, sum32_10_0)); \ + const __m256 q11 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_11_1, sum32_11_0)); \ + acc_block_00 = _mm256_add_ps(acc_block_00, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[(K)].d)), q00)); \ + acc_block_01 = _mm256_add_ps(acc_block_01, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[(K)].d)), q01)); \ + acc_block_10 = _mm256_add_ps(acc_block_10, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[(K)].d)), q10)); \ + acc_block_11 = _mm256_add_ps(acc_block_11, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[(K)].d)), q11)); \ + } + Q1_AVX_BLOCK_PAIR(0) + Q1_AVX_BLOCK_PAIR(1) + Q1_AVX_BLOCK_PAIR(2) + Q1_AVX_BLOCK_PAIR(3) +#undef Q1_AVX_BLOCK_PAIR + + acc_00 = _mm256_add_ps(acc_00, _mm256_mul_ps(_mm256_set1_ps(d00), acc_block_00)); + acc_01 = _mm256_add_ps(acc_01, _mm256_mul_ps(_mm256_set1_ps(d00), acc_block_01)); + acc_10 = _mm256_add_ps(acc_10, _mm256_mul_ps(_mm256_set1_ps(d10), acc_block_10)); + acc_11 = _mm256_add_ps(acc_11, _mm256_mul_ps(_mm256_set1_ps(d10), acc_block_11)); + } + + s[0] = hsum_float_8(acc_00); + s[1] = hsum_float_8(acc_10); + s[bs] = hsum_float_8(acc_01); + s[bs + 1] = hsum_float_8(acc_11); + return; + } + + assert(nrc == 1); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + __m256 acc_block = _mm256_setzero_ps(); +#define Q1_AVX_BLOCK(K) \ + { \ + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); \ + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); \ + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); \ + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); \ + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); \ + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \ + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \ + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \ + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \ + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); \ + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); \ + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); \ + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); \ + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); \ + acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); \ + } + Q1_AVX_BLOCK(0) + 
Q1_AVX_BLOCK(1) + Q1_AVX_BLOCK(2) + Q1_AVX_BLOCK(3) +#undef Q1_AVX_BLOCK + + acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block)); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + assert((nrc == 2) || (nrc == 1)); + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + + if (nrc == 2) { + const block_q1_0 * GGML_RESTRICT x0 = vx; + const block_q1_0 * GGML_RESTRICT x1 = (const block_q1_0 *) ((const uint8_t *) vx + bx); + const block_q8_0 * GGML_RESTRICT y0 = vy; + const block_q8_0 * GGML_RESTRICT y1 = (const block_q8_0 *) ((const uint8_t *) vy + by); + + __m128 acc_00 = _mm_setzero_ps(); + __m128 acc_01 = _mm_setzero_ps(); + __m128 acc_10 = _mm_setzero_ps(); + __m128 acc_11 = _mm_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const __m128 d00 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0[ib].d)); + const __m128 d10 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1[ib].d)); + const block_q8_0 * GGML_RESTRICT y0_ptr = &y0[ib * 4]; + const block_q8_0 * GGML_RESTRICT y1_ptr = &y1[ib * 4]; + +#define Q1_SSSE3_BLOCK_PAIR(QS_OFF, Y_IDX) \ + { \ + const __m128i bit_mask_00 = bytes_from_bits_16(&x0[ib].qs[(QS_OFF) + 0]); \ + const __m128i bit_mask_01 = bytes_from_bits_16(&x0[ib].qs[(QS_OFF) + 2]); \ + const __m128i bit_mask_10 = bytes_from_bits_16(&x1[ib].qs[(QS_OFF) + 0]); \ + const __m128i bit_mask_11 = bytes_from_bits_16(&x1[ib].qs[(QS_OFF) + 2]); \ + const __m128i qy0_0 = _mm_loadu_si128((const __m128i *) &y0_ptr[(Y_IDX)].qs[0]); \ + const __m128i qy0_1 = _mm_loadu_si128((const __m128i *) &y0_ptr[(Y_IDX)].qs[16]); \ + const __m128i qy1_0 = _mm_loadu_si128((const __m128i *) &y1_ptr[(Y_IDX)].qs[0]); \ + const __m128i qy1_1 = _mm_loadu_si128((const __m128i *) &y1_ptr[(Y_IDX)].qs[16]); \ + const __m128i sign_mask_00 = _mm_cmpeq_epi8(bit_mask_00, zero); \ + const __m128i sign_mask_01 = _mm_cmpeq_epi8(bit_mask_01, zero); \ + const __m128i sign_mask_10 = _mm_cmpeq_epi8(bit_mask_10, zero); \ + const __m128i sign_mask_11 = _mm_cmpeq_epi8(bit_mask_11, zero); \ + const __m128i sy00_0 = _mm_sub_epi8(_mm_xor_si128(qy0_0, sign_mask_00), sign_mask_00); \ + const __m128i sy00_1 = _mm_sub_epi8(_mm_xor_si128(qy0_1, sign_mask_01), sign_mask_01); \ + const __m128i sy01_0 = _mm_sub_epi8(_mm_xor_si128(qy1_0, sign_mask_00), sign_mask_00); \ + const __m128i sy01_1 = _mm_sub_epi8(_mm_xor_si128(qy1_1, sign_mask_01), sign_mask_01); \ + const __m128i sy10_0 = _mm_sub_epi8(_mm_xor_si128(qy0_0, sign_mask_10), sign_mask_10); \ + const __m128i sy10_1 = _mm_sub_epi8(_mm_xor_si128(qy0_1, sign_mask_11), sign_mask_11); \ + const __m128i sy11_0 = _mm_sub_epi8(_mm_xor_si128(qy1_0, sign_mask_10), sign_mask_10); \ + const __m128i sy11_1 = _mm_sub_epi8(_mm_xor_si128(qy1_1, sign_mask_11), sign_mask_11); \ + const __m128i sum_00_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy00_0), ones_16); \ + const __m128i sum_00_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy00_1), ones_16); \ + const __m128i sum_01_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy01_0), ones_16); \ + const __m128i sum_01_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy01_1), ones_16); \ + const __m128i sum_10_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy10_0), ones_16); \ + const __m128i sum_10_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy10_1), ones_16); \ + const __m128i sum_11_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy11_0), ones_16); \ + const __m128i sum_11_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy11_1), ones_16); \ + const __m128 q00 = 
_mm_cvtepi32_ps(_mm_add_epi32(sum_00_0, sum_00_1)); \ + const __m128 q01 = _mm_cvtepi32_ps(_mm_add_epi32(sum_01_0, sum_01_1)); \ + const __m128 q10 = _mm_cvtepi32_ps(_mm_add_epi32(sum_10_0, sum_10_1)); \ + const __m128 q11 = _mm_cvtepi32_ps(_mm_add_epi32(sum_11_0, sum_11_1)); \ + acc_00 = _mm_add_ps(acc_00, _mm_mul_ps(_mm_mul_ps(d00, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[(Y_IDX)].d))), q00)); \ + acc_01 = _mm_add_ps(acc_01, _mm_mul_ps(_mm_mul_ps(d00, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[(Y_IDX)].d))), q01)); \ + acc_10 = _mm_add_ps(acc_10, _mm_mul_ps(_mm_mul_ps(d10, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[(Y_IDX)].d))), q10)); \ + acc_11 = _mm_add_ps(acc_11, _mm_mul_ps(_mm_mul_ps(d10, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[(Y_IDX)].d))), q11)); \ + } + Q1_SSSE3_BLOCK_PAIR(0, 0) + Q1_SSSE3_BLOCK_PAIR(4, 1) + Q1_SSSE3_BLOCK_PAIR(8, 2) + Q1_SSSE3_BLOCK_PAIR(12, 3) +#undef Q1_SSSE3_BLOCK_PAIR + } + + s[0] = hsum_float_4(acc_00); + s[1] = hsum_float_4(acc_10); + s[bs] = hsum_float_4(acc_01); + s[bs + 1] = hsum_float_4(acc_11); + return; + } + + assert(nrc == 1); + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d)); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + +#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \ + { \ + const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \ + const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \ + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \ + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \ + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \ + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \ + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \ + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \ + const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \ + const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \ + const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \ + (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \ + } + Q1_SSSE3_BLOCK(0, 0, acc_0) + Q1_SSSE3_BLOCK(4, 1, acc_1) + Q1_SSSE3_BLOCK(8, 2, acc_2) + Q1_SSSE3_BLOCK(12, 3, acc_3) +#undef Q1_SSSE3_BLOCK + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#else + UNUSED(nrc); + assert(nrc == 1); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + float sumf = 0.0f; + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + float sumi = 0.0f; + + for (int k = 0; k < 4; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + int sumi_block = 0; + + const uint8_t * GGML_RESTRICT bits = &x[ib].qs[k * 4]; + const int8_t * GGML_RESTRICT qy = yb->qs; + + for (int b = 0; b < 4; ++b, qy += 8) { + const unsigned mask = bits[b]; + sumi_block += ((mask & 0x01) ? qy[0] : -qy[0]) + + ((mask & 0x02) ? qy[1] : -qy[1]) + + ((mask & 0x04) ? qy[2] : -qy[2]) + + ((mask & 0x08) ? qy[3] : -qy[3]) + + ((mask & 0x10) ? qy[4] : -qy[4]) + + ((mask & 0x20) ? qy[5] : -qy[5]) + + ((mask & 0x40) ? qy[6] : -qy[6]) + + ((mask & 0x80) ? 
qy[7] : -qy[7]); + } + + sumi += d1 * sumi_block; + } + + sumf += d0 * sumi; + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_q1_0_q8_0_4x1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x0 = vx; + const block_q1_0 * GGML_RESTRICT x1 = (const block_q1_0 *) ((const uint8_t *) vx + bx); + const block_q1_0 * GGML_RESTRICT x2 = (const block_q1_0 *) ((const uint8_t *) vx + 2 * bx); + const block_q1_0 * GGML_RESTRICT x3 = (const block_q1_0 *) ((const uint8_t *) vx + 3 * bx); + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d1 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const float d2 = GGML_CPU_FP16_TO_FP32(x2[ib].d); + const float d3 = GGML_CPU_FP16_TO_FP32(x3[ib].d); + const uint32_t * GGML_RESTRICT qs0 = (const uint32_t *) x0[ib].qs; + const uint32_t * GGML_RESTRICT qs1 = (const uint32_t *) x1[ib].qs; + const uint32_t * GGML_RESTRICT qs2 = (const uint32_t *) x2[ib].qs; + const uint32_t * GGML_RESTRICT qs3 = (const uint32_t *) x3[ib].qs; + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + + __m256 acc_block0 = _mm256_setzero_ps(); + __m256 acc_block1 = _mm256_setzero_ps(); + __m256 acc_block2 = _mm256_setzero_ps(); + __m256 acc_block3 = _mm256_setzero_ps(); + +#define Q1_AVX2_BLOCK_4X1(K) \ + { \ + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); \ + const __m256i sm0 = _mm256_cmpeq_epi8( \ + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs0[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm1 = _mm256_cmpeq_epi8( \ + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs1[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm2 = _mm256_cmpeq_epi8( \ + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs2[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm3 = _mm256_cmpeq_epi8( \ + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs3[K]), byte_shuf), bit_masks), zero); \ + const __m256i sy0 = _mm256_sub_epi8(_mm256_xor_si256(qy, sm0), sm0); \ + const __m256i sy1 = _mm256_sub_epi8(_mm256_xor_si256(qy, sm1), sm1); \ + const __m256i sy2 = _mm256_sub_epi8(_mm256_xor_si256(qy, sm2), sm2); \ + const __m256i sy3 = _mm256_sub_epi8(_mm256_xor_si256(qy, sm3), sm3); \ + const __m256i s320 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy0), ones_16); \ + const __m256i s321 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy1), ones_16); \ + const __m256i s322 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy2), ones_16); \ + const __m256i s323 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy3), ones_16); \ + const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)); \ + acc_block0 = 
_mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s320), acc_block0); \ + acc_block1 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s321), acc_block1); \ + acc_block2 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s322), acc_block2); \ + acc_block3 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s323), acc_block3); \ + } + Q1_AVX2_BLOCK_4X1(0) + Q1_AVX2_BLOCK_4X1(1) + Q1_AVX2_BLOCK_4X1(2) + Q1_AVX2_BLOCK_4X1(3) +#undef Q1_AVX2_BLOCK_4X1 + + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block0, acc0); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(d1), acc_block1, acc1); + acc2 = _mm256_fmadd_ps(_mm256_set1_ps(d2), acc_block2, acc2); + acc3 = _mm256_fmadd_ps(_mm256_set1_ps(d3), acc_block3, acc3); + } + + s[0] = hsum_float_8(acc0); + s[1] = hsum_float_8(acc1); + s[2] = hsum_float_8(acc2); + s[3] = hsum_float_8(acc3); + return; +#elif defined(__AVX__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d1 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const float d2 = GGML_CPU_FP16_TO_FP32(x2[ib].d); + const float d3 = GGML_CPU_FP16_TO_FP32(x3[ib].d); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + __m256 acc_block0 = _mm256_setzero_ps(); + __m256 acc_block1 = _mm256_setzero_ps(); + __m256 acc_block2 = _mm256_setzero_ps(); + __m256 acc_block3 = _mm256_setzero_ps(); + +#define Q1_AVX_BLOCK_4X1(K) \ + { \ + const __m256i bit_mask_0 = bytes_from_bits_32(&x0[ib].qs[(K) * 4]); \ + const __m256i bit_mask_1 = bytes_from_bits_32(&x1[ib].qs[(K) * 4]); \ + const __m256i bit_mask_2 = bytes_from_bits_32(&x2[ib].qs[(K) * 4]); \ + const __m256i bit_mask_3 = bytes_from_bits_32(&x3[ib].qs[(K) * 4]); \ + const __m128i bit_mask_00 = _mm256_castsi256_si128(bit_mask_0); \ + const __m128i bit_mask_01 = _mm256_extractf128_si256(bit_mask_0, 1); \ + const __m128i bit_mask_10 = _mm256_castsi256_si128(bit_mask_1); \ + const __m128i bit_mask_11 = _mm256_extractf128_si256(bit_mask_1, 1); \ + const __m128i bit_mask_20 = _mm256_castsi256_si128(bit_mask_2); \ + const __m128i bit_mask_21 = _mm256_extractf128_si256(bit_mask_2, 1); \ + const __m128i bit_mask_30 = _mm256_castsi256_si128(bit_mask_3); \ + const __m128i bit_mask_31 = _mm256_extractf128_si256(bit_mask_3, 1); \ + const __m128i qy0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); \ + const __m128i qy1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); \ + const __m128i sign_mask_00 = _mm_cmpeq_epi8(bit_mask_00, zero); \ + const __m128i sign_mask_01 = _mm_cmpeq_epi8(bit_mask_01, zero); \ + const __m128i sign_mask_10 = _mm_cmpeq_epi8(bit_mask_10, zero); \ + const __m128i sign_mask_11 = _mm_cmpeq_epi8(bit_mask_11, zero); \ + const __m128i sign_mask_20 = _mm_cmpeq_epi8(bit_mask_20, zero); \ + const __m128i sign_mask_21 = _mm_cmpeq_epi8(bit_mask_21, zero); \ + const __m128i sign_mask_30 = _mm_cmpeq_epi8(bit_mask_30, zero); \ + const __m128i sign_mask_31 = _mm_cmpeq_epi8(bit_mask_31, zero); \ + const __m128i sy00 = _mm_sub_epi8(_mm_xor_si128(qy0, sign_mask_00), sign_mask_00); \ + const __m128i sy01 = _mm_sub_epi8(_mm_xor_si128(qy1, sign_mask_01), sign_mask_01); \ + const __m128i sy10 = _mm_sub_epi8(_mm_xor_si128(qy0, sign_mask_10), sign_mask_10); \ + const __m128i sy11 = _mm_sub_epi8(_mm_xor_si128(qy1, sign_mask_11), sign_mask_11); \ + const __m128i sy20 = 
_mm_sub_epi8(_mm_xor_si128(qy0, sign_mask_20), sign_mask_20); \ + const __m128i sy21 = _mm_sub_epi8(_mm_xor_si128(qy1, sign_mask_21), sign_mask_21); \ + const __m128i sy30 = _mm_sub_epi8(_mm_xor_si128(qy0, sign_mask_30), sign_mask_30); \ + const __m128i sy31 = _mm_sub_epi8(_mm_xor_si128(qy1, sign_mask_31), sign_mask_31); \ + const __m128i sum32_00 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy00), ones_16); \ + const __m128i sum32_01 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy01), ones_16); \ + const __m128i sum32_10 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy10), ones_16); \ + const __m128i sum32_11 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy11), ones_16); \ + const __m128i sum32_20 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy20), ones_16); \ + const __m128i sum32_21 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy21), ones_16); \ + const __m128i sum32_30 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy30), ones_16); \ + const __m128i sum32_31 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy31), ones_16); \ + const __m256 q0 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_01, sum32_00)); \ + const __m256 q1 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_11, sum32_10)); \ + const __m256 q2 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_21, sum32_20)); \ + const __m256 q3 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_31, sum32_30)); \ + const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)); \ + acc_block0 = _mm256_add_ps(acc_block0, _mm256_mul_ps(dy, q0)); \ + acc_block1 = _mm256_add_ps(acc_block1, _mm256_mul_ps(dy, q1)); \ + acc_block2 = _mm256_add_ps(acc_block2, _mm256_mul_ps(dy, q2)); \ + acc_block3 = _mm256_add_ps(acc_block3, _mm256_mul_ps(dy, q3)); \ + } + Q1_AVX_BLOCK_4X1(0) + Q1_AVX_BLOCK_4X1(1) + Q1_AVX_BLOCK_4X1(2) + Q1_AVX_BLOCK_4X1(3) +#undef Q1_AVX_BLOCK_4X1 + + acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block0)); + acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_set1_ps(d1), acc_block1)); + acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_set1_ps(d2), acc_block2)); + acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_set1_ps(d3), acc_block3)); + } + + s[0] = hsum_float_8(acc0); + s[1] = hsum_float_8(acc1); + s[2] = hsum_float_8(acc2); + s[3] = hsum_float_8(acc3); + return; +#endif + + for (int i = 0; i < 4; ++i) { + ggml_vec_dot_q1_0_q8_0(n, s + i, 0, (const uint8_t *) vx + i * bx, 0, vy, 0, 1); + } +} + +void ggml_vec_dot_q1_0_q8_0_2x1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x0 = vx; + const block_q1_0 * GGML_RESTRICT x1 = (const block_q1_0 *) ((const uint8_t *) vx + bx); + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d1 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const uint32_t * GGML_RESTRICT qs0 = (const 
uint32_t *) x0[ib].qs; + const uint32_t * GGML_RESTRICT qs1 = (const uint32_t *) x1[ib].qs; + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + + __m256 acc_block0 = _mm256_setzero_ps(); + __m256 acc_block1 = _mm256_setzero_ps(); + +#define Q1_AVX2_BLOCK_2X1(K) \ + { \ + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); \ + const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs0[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs1[K]), byte_shuf), bit_masks), zero); \ + const __m256i sy0 = _mm256_sub_epi8(_mm256_xor_si256(qy, sm0), sm0); \ + const __m256i sy1 = _mm256_sub_epi8(_mm256_xor_si256(qy, sm1), sm1); \ + const __m256i s320 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy0), ones_16); \ + const __m256i s321 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy1), ones_16); \ + const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)); \ + acc_block0 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s320), acc_block0); \ + acc_block1 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s321), acc_block1); \ + } + Q1_AVX2_BLOCK_2X1(0) + Q1_AVX2_BLOCK_2X1(1) + Q1_AVX2_BLOCK_2X1(2) + Q1_AVX2_BLOCK_2X1(3) +#undef Q1_AVX2_BLOCK_2X1 + + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block0, acc0); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(d1), acc_block1, acc1); + } + + s[0] = hsum_float_8(acc0); + s[1] = hsum_float_8(acc1); + return; +#elif defined(__AVX__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d1 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + __m256 acc_block0 = _mm256_setzero_ps(); + __m256 acc_block1 = _mm256_setzero_ps(); + +#define Q1_AVX_BLOCK_2X1(K) \ + { \ + const __m256i bit_mask_0 = bytes_from_bits_32(&x0[ib].qs[(K) * 4]); \ + const __m256i bit_mask_1 = bytes_from_bits_32(&x1[ib].qs[(K) * 4]); \ + const __m128i bit_mask_00 = _mm256_castsi256_si128(bit_mask_0); \ + const __m128i bit_mask_01 = _mm256_extractf128_si256(bit_mask_0, 1); \ + const __m128i bit_mask_10 = _mm256_castsi256_si128(bit_mask_1); \ + const __m128i bit_mask_11 = _mm256_extractf128_si256(bit_mask_1, 1); \ + const __m128i qy0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); \ + const __m128i qy1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); \ + const __m128i sign_mask_00 = _mm_cmpeq_epi8(bit_mask_00, zero); \ + const __m128i sign_mask_01 = _mm_cmpeq_epi8(bit_mask_01, zero); \ + const __m128i sign_mask_10 = _mm_cmpeq_epi8(bit_mask_10, zero); \ + const __m128i sign_mask_11 = _mm_cmpeq_epi8(bit_mask_11, zero); \ + const __m128i sy00 = _mm_sub_epi8(_mm_xor_si128(qy0, sign_mask_00), sign_mask_00); \ + const __m128i sy01 = _mm_sub_epi8(_mm_xor_si128(qy1, sign_mask_01), sign_mask_01); \ + const __m128i sy10 = _mm_sub_epi8(_mm_xor_si128(qy0, sign_mask_10), sign_mask_10); \ + const __m128i sy11 = _mm_sub_epi8(_mm_xor_si128(qy1, sign_mask_11), sign_mask_11); \ + const __m128i sum32_00 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy00), ones_16); \ + const __m128i sum32_01 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy01), ones_16); \ + const __m128i sum32_10 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy10), ones_16); \ + const __m128i sum32_11 = 
_mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy11), ones_16); \ + const __m256 q0 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_01, sum32_00)); \ + const __m256 q1 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_11, sum32_10)); \ + const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)); \ + acc_block0 = _mm256_add_ps(acc_block0, _mm256_mul_ps(dy, q0)); \ + acc_block1 = _mm256_add_ps(acc_block1, _mm256_mul_ps(dy, q1)); \ + } + Q1_AVX_BLOCK_2X1(0) + Q1_AVX_BLOCK_2X1(1) + Q1_AVX_BLOCK_2X1(2) + Q1_AVX_BLOCK_2X1(3) +#undef Q1_AVX_BLOCK_2X1 + + acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block0)); + acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_set1_ps(d1), acc_block1)); + } + + s[0] = hsum_float_8(acc0); + s[1] = hsum_float_8(acc1); + return; +#endif + + ggml_vec_dot_q1_0_q8_0(n, s, 0, vx, 0, vy, 0, 1); + ggml_vec_dot_q1_0_q8_0(n, s + 1, 0, (const uint8_t *) vx + bx, 0, vy, 0, 1); +} + +void ggml_vec_dot_q1_0_q8_0_4x2(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x0 = vx; + const block_q1_0 * GGML_RESTRICT x1 = (const block_q1_0 *) ((const uint8_t *) vx + bx); + const block_q1_0 * GGML_RESTRICT x2 = (const block_q1_0 *) ((const uint8_t *) vx + 2 * bx); + const block_q1_0 * GGML_RESTRICT x3 = (const block_q1_0 *) ((const uint8_t *) vx + 3 * bx); + const block_q8_0 * GGML_RESTRICT y0 = vy; + const block_q8_0 * GGML_RESTRICT y1 = (const block_q8_0 *) ((const uint8_t *) vy + by); + +#if defined(__AVX2__) + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + + __m256 acc00 = _mm256_setzero_ps(); + __m256 acc01 = _mm256_setzero_ps(); + __m256 acc10 = _mm256_setzero_ps(); + __m256 acc11 = _mm256_setzero_ps(); + __m256 acc20 = _mm256_setzero_ps(); + __m256 acc21 = _mm256_setzero_ps(); + __m256 acc30 = _mm256_setzero_ps(); + __m256 acc31 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d1 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const float d2 = GGML_CPU_FP16_TO_FP32(x2[ib].d); + const float d3 = GGML_CPU_FP16_TO_FP32(x3[ib].d); + const uint32_t * GGML_RESTRICT qs0 = (const uint32_t *) x0[ib].qs; + const uint32_t * GGML_RESTRICT qs1 = (const uint32_t *) x1[ib].qs; + const uint32_t * GGML_RESTRICT qs2 = (const uint32_t *) x2[ib].qs; + const uint32_t * GGML_RESTRICT qs3 = (const uint32_t *) x3[ib].qs; + const block_q8_0 * GGML_RESTRICT y0_ptr = &y0[ib * 4]; + const block_q8_0 * GGML_RESTRICT y1_ptr = &y1[ib * 4]; + + __m256 accb00, accb01, accb10, accb11, accb20, accb21, accb30, accb31; + + { + const __m256i qy0v = _mm256_loadu_si256((const __m256i *) y0_ptr[0].qs); + const __m256i qy1v = _mm256_loadu_si256((const __m256i *) y1_ptr[0].qs); + const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs0[0]), byte_shuf), bit_masks), zero); + const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs1[0]), 
byte_shuf), bit_masks), zero); + const __m256i sm2 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs2[0]), byte_shuf), bit_masks), zero); + const __m256i sm3 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs3[0]), byte_shuf), bit_masks), zero); + const __m256i sy00 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm0), sm0); + const __m256i sy01 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm0), sm0); + const __m256i sy10 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm1), sm1); + const __m256i sy11 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm1), sm1); + const __m256i sy20 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm2), sm2); + const __m256i sy21 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm2), sm2); + const __m256i sy30 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm3), sm3); + const __m256i sy31 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm3), sm3); + const __m256 dy0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[0].d)); + const __m256 dy1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[0].d)); + accb00 = _mm256_mul_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy00), ones_16))); + accb01 = _mm256_mul_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy01), ones_16))); + accb10 = _mm256_mul_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy10), ones_16))); + accb11 = _mm256_mul_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy11), ones_16))); + accb20 = _mm256_mul_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy20), ones_16))); + accb21 = _mm256_mul_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy21), ones_16))); + accb30 = _mm256_mul_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy30), ones_16))); + accb31 = _mm256_mul_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy31), ones_16))); + } + +#define Q1_AVX2_BLOCK_4X2(K) \ + { \ + const __m256i qy0v = _mm256_loadu_si256((const __m256i *) y0_ptr[K].qs); \ + const __m256i qy1v = _mm256_loadu_si256((const __m256i *) y1_ptr[K].qs); \ + const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs0[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs1[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm2 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs2[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm3 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs3[K]), byte_shuf), bit_masks), zero); \ + const __m256i sy00 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm0), sm0); \ + const __m256i sy01 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm0), sm0); \ + const __m256i sy10 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm1), sm1); \ + const __m256i sy11 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm1), sm1); \ + const __m256i sy20 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm2), sm2); \ + const __m256i sy21 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm2), sm2); \ + const __m256i sy30 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm3), sm3); \ + const __m256i sy31 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm3), sm3); \ + const __m256 dy0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[K].d)); \ + const __m256 dy1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[K].d)); \ + accb00 = _mm256_fmadd_ps(dy0, 
_mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy00), ones_16)), accb00); \ + accb01 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy01), ones_16)), accb01); \ + accb10 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy10), ones_16)), accb10); \ + accb11 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy11), ones_16)), accb11); \ + accb20 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy20), ones_16)), accb20); \ + accb21 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy21), ones_16)), accb21); \ + accb30 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy30), ones_16)), accb30); \ + accb31 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy31), ones_16)), accb31); \ + } + Q1_AVX2_BLOCK_4X2(1) + Q1_AVX2_BLOCK_4X2(2) + Q1_AVX2_BLOCK_4X2(3) +#undef Q1_AVX2_BLOCK_4X2 + + acc00 = _mm256_fmadd_ps(_mm256_set1_ps(d0), accb00, acc00); + acc01 = _mm256_fmadd_ps(_mm256_set1_ps(d0), accb01, acc01); + acc10 = _mm256_fmadd_ps(_mm256_set1_ps(d1), accb10, acc10); + acc11 = _mm256_fmadd_ps(_mm256_set1_ps(d1), accb11, acc11); + acc20 = _mm256_fmadd_ps(_mm256_set1_ps(d2), accb20, acc20); + acc21 = _mm256_fmadd_ps(_mm256_set1_ps(d2), accb21, acc21); + acc30 = _mm256_fmadd_ps(_mm256_set1_ps(d3), accb30, acc30); + acc31 = _mm256_fmadd_ps(_mm256_set1_ps(d3), accb31, acc31); + } + + s[0] = hsum_float_8(acc00); + s[1] = hsum_float_8(acc10); + s[2] = hsum_float_8(acc20); + s[3] = hsum_float_8(acc30); + s[bs + 0] = hsum_float_8(acc01); + s[bs + 1] = hsum_float_8(acc11); + s[bs + 2] = hsum_float_8(acc21); + s[bs + 3] = hsum_float_8(acc31); + return; +#elif defined(__AVX__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + + __m256 acc00 = _mm256_setzero_ps(); + __m256 acc01 = _mm256_setzero_ps(); + __m256 acc10 = _mm256_setzero_ps(); + __m256 acc11 = _mm256_setzero_ps(); + __m256 acc20 = _mm256_setzero_ps(); + __m256 acc21 = _mm256_setzero_ps(); + __m256 acc30 = _mm256_setzero_ps(); + __m256 acc31 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d1 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const float d2 = GGML_CPU_FP16_TO_FP32(x2[ib].d); + const float d3 = GGML_CPU_FP16_TO_FP32(x3[ib].d); + const block_q8_0 * GGML_RESTRICT y0_ptr = &y0[ib * 4]; + const block_q8_0 * GGML_RESTRICT y1_ptr = &y1[ib * 4]; + __m256 acc_block_00 = _mm256_setzero_ps(); + __m256 acc_block_01 = _mm256_setzero_ps(); + __m256 acc_block_10 = _mm256_setzero_ps(); + __m256 acc_block_11 = _mm256_setzero_ps(); + __m256 acc_block_20 = _mm256_setzero_ps(); + __m256 acc_block_21 = _mm256_setzero_ps(); + __m256 acc_block_30 = _mm256_setzero_ps(); + __m256 acc_block_31 = _mm256_setzero_ps(); + +#define Q1_AVX_BLOCK_4X2(K) \ + { \ + const __m256i bit_mask_0 = bytes_from_bits_32(&x0[ib].qs[(K) * 4]); \ + const __m256i bit_mask_1 = bytes_from_bits_32(&x1[ib].qs[(K) * 4]); \ + const __m256i bit_mask_2 = bytes_from_bits_32(&x2[ib].qs[(K) * 4]); \ + const __m256i bit_mask_3 = bytes_from_bits_32(&x3[ib].qs[(K) * 4]); \ + const __m128i bit_mask_00 = _mm256_castsi256_si128(bit_mask_0); \ + const __m128i bit_mask_01 = _mm256_extractf128_si256(bit_mask_0, 1); \ + const __m128i bit_mask_10 = 
_mm256_castsi256_si128(bit_mask_1); \ + const __m128i bit_mask_11 = _mm256_extractf128_si256(bit_mask_1, 1); \ + const __m128i bit_mask_20 = _mm256_castsi256_si128(bit_mask_2); \ + const __m128i bit_mask_21 = _mm256_extractf128_si256(bit_mask_2, 1); \ + const __m128i bit_mask_30 = _mm256_castsi256_si128(bit_mask_3); \ + const __m128i bit_mask_31 = _mm256_extractf128_si256(bit_mask_3, 1); \ + const __m128i qy0_0 = _mm_loadu_si128((const __m128i *) &y0_ptr[(K)].qs[0]); \ + const __m128i qy0_1 = _mm_loadu_si128((const __m128i *) &y0_ptr[(K)].qs[16]); \ + const __m128i qy1_0 = _mm_loadu_si128((const __m128i *) &y1_ptr[(K)].qs[0]); \ + const __m128i qy1_1 = _mm_loadu_si128((const __m128i *) &y1_ptr[(K)].qs[16]); \ + const __m128i sign_mask_00 = _mm_cmpeq_epi8(bit_mask_00, zero); \ + const __m128i sign_mask_01 = _mm_cmpeq_epi8(bit_mask_01, zero); \ + const __m128i sign_mask_10 = _mm_cmpeq_epi8(bit_mask_10, zero); \ + const __m128i sign_mask_11 = _mm_cmpeq_epi8(bit_mask_11, zero); \ + const __m128i sign_mask_20 = _mm_cmpeq_epi8(bit_mask_20, zero); \ + const __m128i sign_mask_21 = _mm_cmpeq_epi8(bit_mask_21, zero); \ + const __m128i sign_mask_30 = _mm_cmpeq_epi8(bit_mask_30, zero); \ + const __m128i sign_mask_31 = _mm_cmpeq_epi8(bit_mask_31, zero); \ + const __m128i sy00_0 = _mm_sub_epi8(_mm_xor_si128(qy0_0, sign_mask_00), sign_mask_00); \ + const __m128i sy00_1 = _mm_sub_epi8(_mm_xor_si128(qy0_1, sign_mask_01), sign_mask_01); \ + const __m128i sy01_0 = _mm_sub_epi8(_mm_xor_si128(qy1_0, sign_mask_00), sign_mask_00); \ + const __m128i sy01_1 = _mm_sub_epi8(_mm_xor_si128(qy1_1, sign_mask_01), sign_mask_01); \ + const __m128i sy10_0 = _mm_sub_epi8(_mm_xor_si128(qy0_0, sign_mask_10), sign_mask_10); \ + const __m128i sy10_1 = _mm_sub_epi8(_mm_xor_si128(qy0_1, sign_mask_11), sign_mask_11); \ + const __m128i sy11_0 = _mm_sub_epi8(_mm_xor_si128(qy1_0, sign_mask_10), sign_mask_10); \ + const __m128i sy11_1 = _mm_sub_epi8(_mm_xor_si128(qy1_1, sign_mask_11), sign_mask_11); \ + const __m128i sy20_0 = _mm_sub_epi8(_mm_xor_si128(qy0_0, sign_mask_20), sign_mask_20); \ + const __m128i sy20_1 = _mm_sub_epi8(_mm_xor_si128(qy0_1, sign_mask_21), sign_mask_21); \ + const __m128i sy21_0 = _mm_sub_epi8(_mm_xor_si128(qy1_0, sign_mask_20), sign_mask_20); \ + const __m128i sy21_1 = _mm_sub_epi8(_mm_xor_si128(qy1_1, sign_mask_21), sign_mask_21); \ + const __m128i sy30_0 = _mm_sub_epi8(_mm_xor_si128(qy0_0, sign_mask_30), sign_mask_30); \ + const __m128i sy30_1 = _mm_sub_epi8(_mm_xor_si128(qy0_1, sign_mask_31), sign_mask_31); \ + const __m128i sy31_0 = _mm_sub_epi8(_mm_xor_si128(qy1_0, sign_mask_30), sign_mask_30); \ + const __m128i sy31_1 = _mm_sub_epi8(_mm_xor_si128(qy1_1, sign_mask_31), sign_mask_31); \ + const __m128i sum32_00_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy00_0), ones_16); \ + const __m128i sum32_00_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy00_1), ones_16); \ + const __m128i sum32_01_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy01_0), ones_16); \ + const __m128i sum32_01_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy01_1), ones_16); \ + const __m128i sum32_10_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy10_0), ones_16); \ + const __m128i sum32_10_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy10_1), ones_16); \ + const __m128i sum32_11_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy11_0), ones_16); \ + const __m128i sum32_11_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy11_1), ones_16); \ + const __m128i sum32_20_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy20_0), ones_16); \ + const __m128i 
sum32_20_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy20_1), ones_16); \ + const __m128i sum32_21_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy21_0), ones_16); \ + const __m128i sum32_21_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy21_1), ones_16); \ + const __m128i sum32_30_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy30_0), ones_16); \ + const __m128i sum32_30_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy30_1), ones_16); \ + const __m128i sum32_31_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy31_0), ones_16); \ + const __m128i sum32_31_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy31_1), ones_16); \ + const __m256 q00 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_00_1, sum32_00_0)); \ + const __m256 q01 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_01_1, sum32_01_0)); \ + const __m256 q10 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_10_1, sum32_10_0)); \ + const __m256 q11 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_11_1, sum32_11_0)); \ + const __m256 q20 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_20_1, sum32_20_0)); \ + const __m256 q21 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_21_1, sum32_21_0)); \ + const __m256 q30 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_30_1, sum32_30_0)); \ + const __m256 q31 = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_31_1, sum32_31_0)); \ + acc_block_00 = _mm256_add_ps(acc_block_00, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[(K)].d)), q00)); \ + acc_block_01 = _mm256_add_ps(acc_block_01, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[(K)].d)), q01)); \ + acc_block_10 = _mm256_add_ps(acc_block_10, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[(K)].d)), q10)); \ + acc_block_11 = _mm256_add_ps(acc_block_11, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[(K)].d)), q11)); \ + acc_block_20 = _mm256_add_ps(acc_block_20, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[(K)].d)), q20)); \ + acc_block_21 = _mm256_add_ps(acc_block_21, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[(K)].d)), q21)); \ + acc_block_30 = _mm256_add_ps(acc_block_30, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[(K)].d)), q30)); \ + acc_block_31 = _mm256_add_ps(acc_block_31, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[(K)].d)), q31)); \ + } + Q1_AVX_BLOCK_4X2(0) + Q1_AVX_BLOCK_4X2(1) + Q1_AVX_BLOCK_4X2(2) + Q1_AVX_BLOCK_4X2(3) +#undef Q1_AVX_BLOCK_4X2 + + acc00 = _mm256_add_ps(acc00, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block_00)); + acc01 = _mm256_add_ps(acc01, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block_01)); + acc10 = _mm256_add_ps(acc10, _mm256_mul_ps(_mm256_set1_ps(d1), acc_block_10)); + acc11 = _mm256_add_ps(acc11, _mm256_mul_ps(_mm256_set1_ps(d1), acc_block_11)); + acc20 = _mm256_add_ps(acc20, _mm256_mul_ps(_mm256_set1_ps(d2), acc_block_20)); + acc21 = _mm256_add_ps(acc21, _mm256_mul_ps(_mm256_set1_ps(d2), acc_block_21)); + acc30 = _mm256_add_ps(acc30, _mm256_mul_ps(_mm256_set1_ps(d3), acc_block_30)); + acc31 = _mm256_add_ps(acc31, _mm256_mul_ps(_mm256_set1_ps(d3), acc_block_31)); + } + + s[0] = hsum_float_8(acc00); + s[1] = hsum_float_8(acc10); + s[2] = hsum_float_8(acc20); + s[3] = hsum_float_8(acc30); + s[bs + 0] = hsum_float_8(acc01); + s[bs + 1] = hsum_float_8(acc11); + s[bs + 2] = hsum_float_8(acc21); + s[bs + 3] = hsum_float_8(acc31); + return; +#endif + + ggml_vec_dot_q1_0_q8_0_4x1(n, s, vx, bx, vy); + ggml_vec_dot_q1_0_q8_0_4x1(n, s + bs, vx, bx, (const uint8_t *) vy + by); +} + +void ggml_vec_dot_q1_0_q8_0_4x4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT 
vx, size_t bx, const void * GGML_RESTRICT vy, size_t by) { +#if defined(__AVX512F__) && defined(__AVX2__) + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x0 = vx; + const block_q1_0 * GGML_RESTRICT x1 = (const block_q1_0 *) ((const uint8_t *) vx + bx); + const block_q1_0 * GGML_RESTRICT x2 = (const block_q1_0 *) ((const uint8_t *) vx + 2 * bx); + const block_q1_0 * GGML_RESTRICT x3 = (const block_q1_0 *) ((const uint8_t *) vx + 3 * bx); + const block_q8_0 * GGML_RESTRICT y0 = vy; + const block_q8_0 * GGML_RESTRICT y1 = (const block_q8_0 *) ((const uint8_t *) vy + by); + const block_q8_0 * GGML_RESTRICT y2 = (const block_q8_0 *) ((const uint8_t *) vy + 2 * by); + const block_q8_0 * GGML_RESTRICT y3 = (const block_q8_0 *) ((const uint8_t *) vy + 3 * by); + + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + + __m256 acc00 = _mm256_setzero_ps(); + __m256 acc01 = _mm256_setzero_ps(); + __m256 acc02 = _mm256_setzero_ps(); + __m256 acc03 = _mm256_setzero_ps(); + __m256 acc10 = _mm256_setzero_ps(); + __m256 acc11 = _mm256_setzero_ps(); + __m256 acc12 = _mm256_setzero_ps(); + __m256 acc13 = _mm256_setzero_ps(); + __m256 acc20 = _mm256_setzero_ps(); + __m256 acc21 = _mm256_setzero_ps(); + __m256 acc22 = _mm256_setzero_ps(); + __m256 acc23 = _mm256_setzero_ps(); + __m256 acc30 = _mm256_setzero_ps(); + __m256 acc31 = _mm256_setzero_ps(); + __m256 acc32 = _mm256_setzero_ps(); + __m256 acc33 = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x0[ib].d); + const float d1 = GGML_CPU_FP16_TO_FP32(x1[ib].d); + const float d2 = GGML_CPU_FP16_TO_FP32(x2[ib].d); + const float d3 = GGML_CPU_FP16_TO_FP32(x3[ib].d); + const uint32_t * GGML_RESTRICT qs0 = (const uint32_t *) x0[ib].qs; + const uint32_t * GGML_RESTRICT qs1 = (const uint32_t *) x1[ib].qs; + const uint32_t * GGML_RESTRICT qs2 = (const uint32_t *) x2[ib].qs; + const uint32_t * GGML_RESTRICT qs3 = (const uint32_t *) x3[ib].qs; + const block_q8_0 * GGML_RESTRICT y0_ptr = &y0[ib * 4]; + const block_q8_0 * GGML_RESTRICT y1_ptr = &y1[ib * 4]; + const block_q8_0 * GGML_RESTRICT y2_ptr = &y2[ib * 4]; + const block_q8_0 * GGML_RESTRICT y3_ptr = &y3[ib * 4]; + + __m256 accb00, accb01, accb02, accb03; + __m256 accb10, accb11, accb12, accb13; + __m256 accb20, accb21, accb22, accb23; + __m256 accb30, accb31, accb32, accb33; + + { + const __m256i qy0v = _mm256_loadu_si256((const __m256i *) y0_ptr[0].qs); + const __m256i qy1v = _mm256_loadu_si256((const __m256i *) y1_ptr[0].qs); + const __m256i qy2v = _mm256_loadu_si256((const __m256i *) y2_ptr[0].qs); + const __m256i qy3v = _mm256_loadu_si256((const __m256i *) y3_ptr[0].qs); + const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs0[0]), byte_shuf), bit_masks), zero); + const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs1[0]), byte_shuf), bit_masks), zero); + const __m256i sm2 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs2[0]), byte_shuf), 
bit_masks), zero); + const __m256i sm3 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs3[0]), byte_shuf), bit_masks), zero); + const __m256 dy0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[0].d)); + const __m256 dy1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[0].d)); + const __m256 dy2 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y2_ptr[0].d)); + const __m256 dy3 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y3_ptr[0].d)); + +#define Q1_AVX2_INIT_4X4(ROW, SM, DROW) \ + const __m256i sy##ROW##0 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, SM), SM); \ + const __m256i sy##ROW##1 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, SM), SM); \ + const __m256i sy##ROW##2 = _mm256_sub_epi8(_mm256_xor_si256(qy2v, SM), SM); \ + const __m256i sy##ROW##3 = _mm256_sub_epi8(_mm256_xor_si256(qy3v, SM), SM); \ + accb##ROW##0 = _mm256_mul_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy##ROW##0), ones_16))); \ + accb##ROW##1 = _mm256_mul_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy##ROW##1), ones_16))); \ + accb##ROW##2 = _mm256_mul_ps(dy2, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy##ROW##2), ones_16))); \ + accb##ROW##3 = _mm256_mul_ps(dy3, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy##ROW##3), ones_16))); + Q1_AVX2_INIT_4X4(0, sm0, d0) + Q1_AVX2_INIT_4X4(1, sm1, d1) + Q1_AVX2_INIT_4X4(2, sm2, d2) + Q1_AVX2_INIT_4X4(3, sm3, d3) +#undef Q1_AVX2_INIT_4X4 + } + +#define Q1_AVX2_BLOCK_4X4(K) \ + { \ + const __m256i qy0v = _mm256_loadu_si256((const __m256i *) y0_ptr[K].qs); \ + const __m256i qy1v = _mm256_loadu_si256((const __m256i *) y1_ptr[K].qs); \ + const __m256i qy2v = _mm256_loadu_si256((const __m256i *) y2_ptr[K].qs); \ + const __m256i qy3v = _mm256_loadu_si256((const __m256i *) y3_ptr[K].qs); \ + const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs0[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs1[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm2 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs2[K]), byte_shuf), bit_masks), zero); \ + const __m256i sm3 = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs3[K]), byte_shuf), bit_masks), zero); \ + const __m256i sy00 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm0), sm0); \ + const __m256i sy01 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm0), sm0); \ + const __m256i sy02 = _mm256_sub_epi8(_mm256_xor_si256(qy2v, sm0), sm0); \ + const __m256i sy03 = _mm256_sub_epi8(_mm256_xor_si256(qy3v, sm0), sm0); \ + const __m256i sy10 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm1), sm1); \ + const __m256i sy11 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm1), sm1); \ + const __m256i sy12 = _mm256_sub_epi8(_mm256_xor_si256(qy2v, sm1), sm1); \ + const __m256i sy13 = _mm256_sub_epi8(_mm256_xor_si256(qy3v, sm1), sm1); \ + const __m256i sy20 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm2), sm2); \ + const __m256i sy21 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm2), sm2); \ + const __m256i sy22 = _mm256_sub_epi8(_mm256_xor_si256(qy2v, sm2), sm2); \ + const __m256i sy23 = _mm256_sub_epi8(_mm256_xor_si256(qy3v, sm2), sm2); \ + const __m256i sy30 = _mm256_sub_epi8(_mm256_xor_si256(qy0v, sm3), sm3); \ + const __m256i sy31 = _mm256_sub_epi8(_mm256_xor_si256(qy1v, sm3), sm3); \ + const __m256i sy32 = _mm256_sub_epi8(_mm256_xor_si256(qy2v, sm3), sm3); \ + const 
__m256i sy33 = _mm256_sub_epi8(_mm256_xor_si256(qy3v, sm3), sm3); \ + const __m256 dy0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[K].d)); \ + const __m256 dy1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[K].d)); \ + const __m256 dy2 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y2_ptr[K].d)); \ + const __m256 dy3 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y3_ptr[K].d)); \ + accb00 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy00), ones_16)), accb00); \ + accb01 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy01), ones_16)), accb01); \ + accb02 = _mm256_fmadd_ps(dy2, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy02), ones_16)), accb02); \ + accb03 = _mm256_fmadd_ps(dy3, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy03), ones_16)), accb03); \ + accb10 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy10), ones_16)), accb10); \ + accb11 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy11), ones_16)), accb11); \ + accb12 = _mm256_fmadd_ps(dy2, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy12), ones_16)), accb12); \ + accb13 = _mm256_fmadd_ps(dy3, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy13), ones_16)), accb13); \ + accb20 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy20), ones_16)), accb20); \ + accb21 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy21), ones_16)), accb21); \ + accb22 = _mm256_fmadd_ps(dy2, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy22), ones_16)), accb22); \ + accb23 = _mm256_fmadd_ps(dy3, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy23), ones_16)), accb23); \ + accb30 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy30), ones_16)), accb30); \ + accb31 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy31), ones_16)), accb31); \ + accb32 = _mm256_fmadd_ps(dy2, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy32), ones_16)), accb32); \ + accb33 = _mm256_fmadd_ps(dy3, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy33), ones_16)), accb33); \ + } + Q1_AVX2_BLOCK_4X4(1) + Q1_AVX2_BLOCK_4X4(2) + Q1_AVX2_BLOCK_4X4(3) +#undef Q1_AVX2_BLOCK_4X4 + + acc00 = _mm256_fmadd_ps(_mm256_set1_ps(d0), accb00, acc00); + acc01 = _mm256_fmadd_ps(_mm256_set1_ps(d0), accb01, acc01); + acc02 = _mm256_fmadd_ps(_mm256_set1_ps(d0), accb02, acc02); + acc03 = _mm256_fmadd_ps(_mm256_set1_ps(d0), accb03, acc03); + acc10 = _mm256_fmadd_ps(_mm256_set1_ps(d1), accb10, acc10); + acc11 = _mm256_fmadd_ps(_mm256_set1_ps(d1), accb11, acc11); + acc12 = _mm256_fmadd_ps(_mm256_set1_ps(d1), accb12, acc12); + acc13 = _mm256_fmadd_ps(_mm256_set1_ps(d1), accb13, acc13); + acc20 = _mm256_fmadd_ps(_mm256_set1_ps(d2), accb20, acc20); + acc21 = _mm256_fmadd_ps(_mm256_set1_ps(d2), accb21, acc21); + acc22 = _mm256_fmadd_ps(_mm256_set1_ps(d2), accb22, acc22); + acc23 = _mm256_fmadd_ps(_mm256_set1_ps(d2), accb23, acc23); + acc30 = _mm256_fmadd_ps(_mm256_set1_ps(d3), accb30, acc30); + acc31 = _mm256_fmadd_ps(_mm256_set1_ps(d3), accb31, acc31); + acc32 = _mm256_fmadd_ps(_mm256_set1_ps(d3), accb32, acc32); + acc33 = _mm256_fmadd_ps(_mm256_set1_ps(d3), accb33, acc33); + } + + s[0] = hsum_float_8(acc00); + s[1] = 
hsum_float_8(acc10); + s[2] = hsum_float_8(acc20); + s[3] = hsum_float_8(acc30); + s[bs + 0] = hsum_float_8(acc01); + s[bs + 1] = hsum_float_8(acc11); + s[bs + 2] = hsum_float_8(acc21); + s[bs + 3] = hsum_float_8(acc31); + s[2*bs + 0] = hsum_float_8(acc02); + s[2*bs + 1] = hsum_float_8(acc12); + s[2*bs + 2] = hsum_float_8(acc22); + s[2*bs + 3] = hsum_float_8(acc32); + s[3*bs + 0] = hsum_float_8(acc03); + s[3*bs + 1] = hsum_float_8(acc13); + s[3*bs + 2] = hsum_float_8(acc23); + s[3*bs + 3] = hsum_float_8(acc33); + return; +#endif + + ggml_vec_dot_q1_0_q8_0_4x2(n, s, bs, vx, bx, vy, by); + ggml_vec_dot_q1_0_q8_0_4x2(n, s + 2*bs, bs, vx, bx, (const uint8_t *) vy + 2*by, by); +} void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 88a9c9ec057..b219db926d5 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -29,6 +29,11 @@ struct ggml_compute_params { bool use_ref; }; +void ggml_vec_dot_q1_0_q8_0_4x1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy); +void ggml_vec_dot_q1_0_q8_0_2x1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy); +void ggml_vec_dot_q1_0_q8_0_4x2(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by); +void ggml_vec_dot_q1_0_q8_0_4x4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by); + #if defined(_MSC_VER) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2b3eb5b5ce6..4d2f1f01d33 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -221,7 +221,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .from_float = quantize_row_q1_0, .vec_dot = ggml_vec_dot_q1_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, +#if defined(__AVX2__) || defined(__AVX__) || defined(__SSSE3__) + .nrows = 2, +#else .nrows = 1, +#endif }, [GGML_TYPE_Q4_0] = { .from_float = quantize_row_q4_0, @@ -1186,14 +1190,17 @@ static void ggml_compute_forward_mul_mat_one_chunk( assert(ne13 % ne03 == 0); // block-tiling attempt - const int64_t blck_0 = 16; + // Q1_0 weights are much smaller than their Q8_0 dot operand, so a wider row tile + // improves src1_col reuse without increasing the simultaneously hot column set. + const int64_t blck_0 = type == GGML_TYPE_Q1_0 ? 32 : 16; const int64_t blck_1 = 16; const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; // attempt to reduce false-sharing (does not seem to make a difference) - // 16 * 2, accounting for mmla kernels - float tmp[32]; + // Q1_0 may use up to a 32x4 rectangular tile (blck_0 rows by 4 src1 columns) on the AVX-512-only 4x4 path. 
+ float tmp[128]; + GGML_ASSERT(blck_0 * num_rows_per_vec_dot <= (int64_t) (sizeof(tmp) / sizeof(tmp[0]))); for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { @@ -1226,12 +1233,68 @@ static void ggml_compute_forward_mul_mat_one_chunk( // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); //} +#if defined(__AVX2__) && (defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)) + if (type == GGML_TYPE_Q1_0 && num_rows_per_vec_dot == 2) { + const int64_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); + +#if defined(__AVX512F__) + if (i11 + 4 <= ne11 && ir1 + 2 < ir1_end && ir1 + 2 < iir1 + blck_1) { + for (int64_t ir0 = iir0; ir0 < ir0_block_end; ) { + if (ir0 + 4 <= ir0_block_end) { + ggml_vec_dot_q1_0_q8_0_4x4(ne00, &tmp[ir0 - iir0], blck_0, src0_row + ir0 * nb01, nb01, src1_col, src1_col_stride); + ir0 += 4; + continue; + } + + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0 * nb01, 0, src1_col, 0, 1); + vec_dot(ne00, &tmp[blck_0 + ir0 - iir0], 0, src0_row + ir0 * nb01, 0, src1_col + src1_col_stride, 0, 1); + vec_dot(ne00, &tmp[2*blck_0 + ir0 - iir0], 0, src0_row + ir0 * nb01, 0, src1_col + 2 * src1_col_stride, 0, 1); + vec_dot(ne00, &tmp[3*blck_0 + ir0 - iir0], 0, src0_row + ir0 * nb01, 0, src1_col + 3 * src1_col_stride, 0, 1); + ++ir0; + } + + for (int cn = 0; cn < 4; ++cn) { + memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * blck_0), (ir0_block_end - iir0) * sizeof(float)); + } + + ir1 += 2; + continue; + } +#endif + for (int64_t ir0 = iir0; ir0 < ir0_block_end; ) { + if (ir0 + 4 <= ir0_block_end) { + ggml_vec_dot_q1_0_q8_0_4x2(ne00, &tmp[ir0 - iir0], blck_0, src0_row + ir0 * nb01, nb01, src1_col, src1_col_stride); + ir0 += 4; + continue; + } + + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0 * nb01, 0, src1_col, 0, 1); + vec_dot(ne00, &tmp[blck_0 + ir0 - iir0], 0, src0_row + ir0 * nb01, 0, src1_col + src1_col_stride, 0, 1); + ++ir0; + } + } + else if (type == GGML_TYPE_Q1_0 && num_rows_per_vec_dot == 1) { + const int64_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); + + for (int64_t ir0 = iir0; ir0 < ir0_block_end; ) { + if (ir0 + 2 <= ir0_block_end) { + ggml_vec_dot_q1_0_q8_0_2x1(ne00, &tmp[ir0 - iir0], blck_0, src0_row + ir0 * nb01, nb01, src1_col); + ir0 += 2; + continue; + } + + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0 * nb01, 0, src1_col, 0, 1); + ++ir0; + } + } + else +#endif for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? blck_0 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? 
src1_col_stride : 0), num_rows_per_vec_dot); } for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { - memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); + memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * blck_0), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); } } } diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index f66127c2290..6ac647cf9fc 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -142,17 +142,23 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c float sumi = 0.0f; for (int k = 0; k < 4; k++) { - const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); - + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_FP16_TO_FP32(yb->d); int sumi_block = 0; - for (int j = 0; j < QK8_0; j++) { - const int bit_index = k * QK8_0 + j; - const int byte_index = bit_index / 8; - const int bit_offset = bit_index % 8; - - const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; - sumi_block += xi * y[i*4 + k].qs[j]; + const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4]; + const int8_t * GGML_RESTRICT qy = yb->qs; + + for (int b = 0; b < 4; ++b, qy += 8) { + const unsigned mask = bits[b]; + sumi_block += ((mask & 0x01) ? qy[0] : -qy[0]) + + ((mask & 0x02) ? qy[1] : -qy[1]) + + ((mask & 0x04) ? qy[2] : -qy[2]) + + ((mask & 0x08) ? qy[3] : -qy[3]) + + ((mask & 0x10) ? qy[4] : -qy[4]) + + ((mask & 0x20) ? qy[5] : -qy[5]) + + ((mask & 0x40) ? qy[6] : -qy[6]) + + ((mask & 0x80) ? qy[7] : -qy[7]); } sumi += d1 * sumi_block;
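Note on the sign handling shared by the AVX2/AVX-512 kernels and the rewritten generic kernel: each Q1_0 sign bit is expanded into a full-byte mask (0x00 = keep, 0xFF = negate), and the Q8_0 value is then conditioned branch-free via (q ^ m) - m before the maddubs/madd horizontal sums. The standalone sketch below models that logic in scalar C for one 32-value sub-block; it is illustrative only (expand_sign_mask and dot32_ref are hypothetical helper names, not part of the patch), and it assumes the little-endian bit-to-value mapping implied by the uint32_t loads of x[ib].qs and the fact that q8_0 values stay within [-127, 127].

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// Scalar model of the mask expansion done with _mm256_shuffle_epi8 (byte_shuf),
// _mm256_and_si256 (bit_masks) and _mm256_cmpeq_epi8 above: bit j of the 32-bit
// word selects the sign of q8_0 value j; the mask byte is 0xFF where the value
// must be negated and 0x00 where it is kept.
static void expand_sign_mask(uint32_t bits, uint8_t mask[32]) {
    for (int j = 0; j < 32; ++j) {
        mask[j] = ((bits >> j) & 1) ? 0x00 : 0xFF;
    }
}

// Branch-free sign application used by the kernels: (q ^ m) - m == q when
// m == 0x00 and == -q when m == 0xFF (-1), i.e. what _mm256_xor_si256 plus
// _mm256_sub_epi8 compute per byte. q8_0 values never reach -128, so the
// int8 negation cannot overflow here.
static int dot32_ref(uint32_t bits, const int8_t y[32]) {
    uint8_t m[32];
    int sum = 0;
    expand_sign_mask(bits, m);
    for (int j = 0; j < 32; ++j) {
        const int8_t mj = (int8_t) m[j];
        sum += (int8_t) ((y[j] ^ mj) - mj);
    }
    return sum;
}

int main(void) {
    int8_t y[32];
    for (int j = 0; j < 32; ++j) {
        y[j] = (int8_t) (j - 16);          // arbitrary q8_0-like values in [-16, 15]
    }
    const uint32_t bits = 0xA5C33C5Au;     // arbitrary sign pattern for one sub-block

    // Straightforward branchy definition, matching the form used in
    // ggml_vec_dot_q1_0_q8_0_generic: bit set means +y[j], bit clear means -y[j].
    int expected = 0;
    for (int j = 0; j < 32; ++j) {
        expected += ((bits >> j) & 1) ? y[j] : -y[j];
    }

    assert(dot32_ref(bits, y) == expected);
    printf("sub-block dot = %d\n", expected);
    return 0;
}

In the real kernels this per-sub-block integer sum is produced by _mm256_maddubs_epi16(ones_8, sy) followed by _mm256_madd_epi16(..., ones_16), then scaled by the per-block q8_0 and q1_0 deltas before the final horizontal float reduction.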