-
Notifications
You must be signed in to change notification settings - Fork 31
x86: implement AVX2 kernel for ggml_vec_dot_q1_0_g128_q8_0 #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -545,7 +545,80 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi | |||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { | ||||||||||||||||||
| #if defined(__AVX2__) | ||||||||||||||||||
| const int nb = n / QK1_0_g128; | ||||||||||||||||||
| GGML_ASSERT(n % QK1_0_g128 == 0); | ||||||||||||||||||
| GGML_ASSERT(nrc == 1); | ||||||||||||||||||
| UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); | ||||||||||||||||||
|
|
||||||||||||||||||
| const block_q1_0_g128 * GGML_RESTRICT x = vx; | ||||||||||||||||||
| const block_q8_0 * GGML_RESTRICT y = vy; | ||||||||||||||||||
|
|
||||||||||||||||||
| // Bit position mask: LSB first, repeated 4x for 32 lanes | ||||||||||||||||||
| static const int8_t bitmask_data[32] = { | ||||||||||||||||||
| 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, | ||||||||||||||||||
| 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, | ||||||||||||||||||
| 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, | ||||||||||||||||||
| 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, | ||||||||||||||||||
| }; | ||||||||||||||||||
| // vpshufb masks: expand one byte to 8 consecutive lanes | ||||||||||||||||||
| static const int8_t shuf_lo_data[16] = { 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1 }; | ||||||||||||||||||
| static const int8_t shuf_hi_data[16] = { 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3 }; | ||||||||||||||||||
|
|
||||||||||||||||||
| const __m256i bit_mask = _mm256_loadu_si256((const __m256i *)bitmask_data); | ||||||||||||||||||
| const __m128i shuf_lo = _mm_loadu_si128 ((const __m128i *)shuf_lo_data); | ||||||||||||||||||
| const __m128i shuf_hi = _mm_loadu_si128 ((const __m128i *)shuf_hi_data); | ||||||||||||||||||
| const __m256i zero_vec = _mm256_setzero_si256(); | ||||||||||||||||||
| const __m256i one8 = _mm256_set1_epi8(1); | ||||||||||||||||||
| const __m256i ones16 = _mm256_set1_epi16(1); | ||||||||||||||||||
|
|
||||||||||||||||||
| // Four independent float accumulators — eliminates FMA dependency chain | ||||||||||||||||||
| // (FMA latency 5 cycles on Skylake; 4 independent regs hide this completely) | ||||||||||||||||||
| __m256 sumf0 = _mm256_setzero_ps(); | ||||||||||||||||||
| __m256 sumf1 = _mm256_setzero_ps(); | ||||||||||||||||||
| __m256 sumf2 = _mm256_setzero_ps(); | ||||||||||||||||||
| __m256 sumf3 = _mm256_setzero_ps(); | ||||||||||||||||||
|
|
||||||||||||||||||
| for (int i = 0; i < nb; i++) { | ||||||||||||||||||
| const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); | ||||||||||||||||||
|
|
||||||||||||||||||
| // Load 16 bytes of bits covering all 4 sub-blocks | ||||||||||||||||||
| uint64_t bits_lo, bits_hi; | ||||||||||||||||||
| memcpy(&bits_lo, &x[i].qs[0], 8); | ||||||||||||||||||
| memcpy(&bits_hi, &x[i].qs[8], 8); | ||||||||||||||||||
|
|
||||||||||||||||||
| // Expand each sub-block's 4 bytes via vpshufb | ||||||||||||||||||
| const __m128i raw0 = _mm_cvtsi32_si128((int)(uint32_t)bits_lo); | ||||||||||||||||||
| const __m128i raw1 = _mm_cvtsi32_si128((int)(uint32_t)(bits_lo >> 32)); | ||||||||||||||||||
| const __m128i raw2 = _mm_cvtsi32_si128((int)(uint32_t)bits_hi); | ||||||||||||||||||
| const __m128i raw3 = _mm_cvtsi32_si128((int)(uint32_t)(bits_hi >> 32)); | ||||||||||||||||||
|
|
||||||||||||||||||
| const __m256i bv0 = MM256_SET_M128I(_mm_shuffle_epi8(raw0, shuf_hi), _mm_shuffle_epi8(raw0, shuf_lo)); | ||||||||||||||||||
| const __m256i bv1 = MM256_SET_M128I(_mm_shuffle_epi8(raw1, shuf_hi), _mm_shuffle_epi8(raw1, shuf_lo)); | ||||||||||||||||||
| const __m256i bv2 = MM256_SET_M128I(_mm_shuffle_epi8(raw2, shuf_hi), _mm_shuffle_epi8(raw2, shuf_lo)); | ||||||||||||||||||
| const __m256i bv3 = MM256_SET_M128I(_mm_shuffle_epi8(raw3, shuf_hi), _mm_shuffle_epi8(raw3, shuf_lo)); | ||||||||||||||||||
|
|
||||||||||||||||||
| #define DOT_SUB(bv, yb, acc) do { const __m256i yv = _mm256_loadu_si256((const __m256i *)(yb)->qs); /* cmpeq(AND(bv,mask),0): 0xFF where bit=0; OR 0x01: 0xFF(-1) where bit=0, 0x01(+1) where bit=1 */ const __m256i sgn = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256((bv), bit_mask), zero_vec), one8); const __m256i p32 = _mm256_madd_epi16( _mm256_maddubs_epi16(_mm256_abs_epi8(yv), _mm256_sign_epi8(sgn, yv)), ones16); (acc) = _mm256_fmadd_ps(_mm256_cvtepi32_ps(p32), _mm256_set1_ps(d0 * GGML_CPU_FP16_TO_FP32((yb)->d)), (acc)); } while (0) | ||||||||||||||||||
|
Comment on lines
+600
to
+601
|
||||||||||||||||||
| #define DOT_SUB(bv, yb, acc) do { const __m256i yv = _mm256_loadu_si256((const __m256i *)(yb)->qs); /* cmpeq(AND(bv,mask),0): 0xFF where bit=0; OR 0x01: 0xFF(-1) where bit=0, 0x01(+1) where bit=1 */ const __m256i sgn = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256((bv), bit_mask), zero_vec), one8); const __m256i p32 = _mm256_madd_epi16( _mm256_maddubs_epi16(_mm256_abs_epi8(yv), _mm256_sign_epi8(sgn, yv)), ones16); (acc) = _mm256_fmadd_ps(_mm256_cvtepi32_ps(p32), _mm256_set1_ps(d0 * GGML_CPU_FP16_TO_FP32((yb)->d)), (acc)); } while (0) | |
| #if defined(__FMA__) | |
| #define DOT_SUB(bv, yb, acc) do { const __m256i yv = _mm256_loadu_si256((const __m256i *)(yb)->qs); /* cmpeq(AND(bv,mask),0): 0xFF where bit=0; OR 0x01: 0xFF(-1) where bit=0, 0x01(+1) where bit=1 */ const __m256i sgn = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256((bv), bit_mask), zero_vec), one8); const __m256i p32 = _mm256_madd_epi16( _mm256_maddubs_epi16(_mm256_abs_epi8(yv), _mm256_sign_epi8(sgn, yv)), ones16); (acc) = _mm256_fmadd_ps(_mm256_cvtepi32_ps(p32), _mm256_set1_ps(d0 * GGML_CPU_FP16_TO_FP32((yb)->d)), (acc)); } while (0) | |
| #else | |
| #define DOT_SUB(bv, yb, acc) do { const __m256i yv = _mm256_loadu_si256((const __m256i *)(yb)->qs); /* cmpeq(AND(bv,mask),0): 0xFF where bit=0; OR 0x01: 0xFF(-1) where bit=0, 0x01(+1) where bit=1 */ const __m256i sgn = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256((bv), bit_mask), zero_vec), one8); const __m256i p32 = _mm256_madd_epi16( _mm256_maddubs_epi16(_mm256_abs_epi8(yv), _mm256_sign_epi8(sgn, yv)), ones16); (acc) = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(p32), _mm256_set1_ps(d0 * GGML_CPU_FP16_TO_FP32((yb)->d))), (acc)); } while (0) | |
| #endif |
Copilot
AI
Apr 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DOT_SUB is defined as a very large single-line macro inside the loop body, which makes the kernel hard to read, debug, and maintain (and increases the risk of subtle macro issues if the arguments ever change). Consider replacing it with a small static inline helper function (or at least a multi-line macro defined outside the loop) to improve maintainability without affecting performance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function uses GGML_ASSERT for argument checks, but the other vec_dot kernels in this file consistently use assert() (e.g., ggml_vec_dot_q4_0_q8_0 just below). Unlike assert(), which is compiled out when NDEBUG is defined, GGML_ASSERT may remain active in release builds, so using it here can change release-build behavior. For consistency, consider switching these checks to assert() like the rest of the file, or document why GGML_ASSERT is required here.