Skip to content

Commit b793ed1

Browse files
committed
Refactored Q1_0_g128 code visual structure
1 parent 167652c commit b793ed1

1 file changed

Lines changed: 45 additions & 76 deletions

File tree

ggml/src/ggml-cpu/arch/x86/quants.c

Lines changed: 45 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -760,28 +760,31 @@ void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, cons
760760

761761
for (int ib = 0; ib < nb; ++ib) {
762762
const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
763+
const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4];
763764
__m256 acc_block = _mm256_setzero_ps();
764-
765-
for (int k = 0; k < 4; ++k) {
766-
const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k];
767-
const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);
768-
const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[k * 4]);
769-
const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask);
770-
const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1);
771-
const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &yb->qs[0]);
772-
const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &yb->qs[16]);
773-
const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero);
774-
const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero);
775-
const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0);
776-
const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1);
777-
const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0);
778-
const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1);
779-
const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16);
780-
const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16);
781-
const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0));
782-
783-
acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(d1), q));
765+
#define Q1_AVX_BLOCK(K) \
766+
{ \
767+
const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); \
768+
const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); \
769+
const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); \
770+
const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); \
771+
const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); \
772+
const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
773+
const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
774+
const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
775+
const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
776+
const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); \
777+
const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); \
778+
const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); \
779+
const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); \
780+
const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); \
781+
acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); \
784782
}
783+
Q1_AVX_BLOCK(0)
784+
Q1_AVX_BLOCK(1)
785+
Q1_AVX_BLOCK(2)
786+
Q1_AVX_BLOCK(3)
787+
#undef Q1_AVX_BLOCK
785788

786789
acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block));
787790
}
@@ -801,62 +804,28 @@ void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, cons
801804

802805
for (int ib = 0; ib < nb; ++ib) {
803806
const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));
804-
const block_q8_0 * GGML_RESTRICT yb_0 = &y[ib * 4 + 0];
805-
const block_q8_0 * GGML_RESTRICT yb_1 = &y[ib * 4 + 1];
806-
const block_q8_0 * GGML_RESTRICT yb_2 = &y[ib * 4 + 2];
807-
const block_q8_0 * GGML_RESTRICT yb_3 = &y[ib * 4 + 3];
808-
809-
const __m128i bit_mask_0_0 = bytes_from_bits_16(&x[ib].qs[0]);
810-
const __m128i bit_mask_0_1 = bytes_from_bits_16(&x[ib].qs[2]);
811-
const __m128i qy_0_0 = _mm_loadu_si128((const __m128i *) &yb_0->qs[0]);
812-
const __m128i qy_0_1 = _mm_loadu_si128((const __m128i *) &yb_0->qs[16]);
813-
const __m128i sign_mask_0_0 = _mm_cmpeq_epi8(bit_mask_0_0, zero);
814-
const __m128i sign_mask_0_1 = _mm_cmpeq_epi8(bit_mask_0_1, zero);
815-
const __m128i sy_0_0 = _mm_sub_epi8(_mm_xor_si128(qy_0_0, sign_mask_0_0), sign_mask_0_0);
816-
const __m128i sy_0_1 = _mm_sub_epi8(_mm_xor_si128(qy_0_1, sign_mask_0_1), sign_mask_0_1);
817-
const __m128i sum_0_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0_0), ones_16);
818-
const __m128i sum_0_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0_1), ones_16);
819-
const __m128 q_0 = _mm_cvtepi32_ps(_mm_add_epi32(sum_0_0, sum_0_1));
820-
acc_0 = _mm_add_ps(acc_0, _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(yb_0->d))), q_0));
821-
822-
const __m128i bit_mask_1_0 = bytes_from_bits_16(&x[ib].qs[4]);
823-
const __m128i bit_mask_1_1 = bytes_from_bits_16(&x[ib].qs[6]);
824-
const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *) &yb_1->qs[0]);
825-
const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *) &yb_1->qs[16]);
826-
const __m128i sign_mask_1_0 = _mm_cmpeq_epi8(bit_mask_1_0, zero);
827-
const __m128i sign_mask_1_1 = _mm_cmpeq_epi8(bit_mask_1_1, zero);
828-
const __m128i sy_1_0 = _mm_sub_epi8(_mm_xor_si128(qy_1_0, sign_mask_1_0), sign_mask_1_0);
829-
const __m128i sy_1_1 = _mm_sub_epi8(_mm_xor_si128(qy_1_1, sign_mask_1_1), sign_mask_1_1);
830-
const __m128i sum_1_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1_0), ones_16);
831-
const __m128i sum_1_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1_1), ones_16);
832-
const __m128 q_1 = _mm_cvtepi32_ps(_mm_add_epi32(sum_1_0, sum_1_1));
833-
acc_1 = _mm_add_ps(acc_1, _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(yb_1->d))), q_1));
834-
835-
const __m128i bit_mask_2_0 = bytes_from_bits_16(&x[ib].qs[8]);
836-
const __m128i bit_mask_2_1 = bytes_from_bits_16(&x[ib].qs[10]);
837-
const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *) &yb_2->qs[0]);
838-
const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *) &yb_2->qs[16]);
839-
const __m128i sign_mask_2_0 = _mm_cmpeq_epi8(bit_mask_2_0, zero);
840-
const __m128i sign_mask_2_1 = _mm_cmpeq_epi8(bit_mask_2_1, zero);
841-
const __m128i sy_2_0 = _mm_sub_epi8(_mm_xor_si128(qy_2_0, sign_mask_2_0), sign_mask_2_0);
842-
const __m128i sy_2_1 = _mm_sub_epi8(_mm_xor_si128(qy_2_1, sign_mask_2_1), sign_mask_2_1);
843-
const __m128i sum_2_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_2_0), ones_16);
844-
const __m128i sum_2_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_2_1), ones_16);
845-
const __m128 q_2 = _mm_cvtepi32_ps(_mm_add_epi32(sum_2_0, sum_2_1));
846-
acc_2 = _mm_add_ps(acc_2, _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(yb_2->d))), q_2));
847-
848-
const __m128i bit_mask_3_0 = bytes_from_bits_16(&x[ib].qs[12]);
849-
const __m128i bit_mask_3_1 = bytes_from_bits_16(&x[ib].qs[14]);
850-
const __m128i qy_3_0 = _mm_loadu_si128((const __m128i *) &yb_3->qs[0]);
851-
const __m128i qy_3_1 = _mm_loadu_si128((const __m128i *) &yb_3->qs[16]);
852-
const __m128i sign_mask_3_0 = _mm_cmpeq_epi8(bit_mask_3_0, zero);
853-
const __m128i sign_mask_3_1 = _mm_cmpeq_epi8(bit_mask_3_1, zero);
854-
const __m128i sy_3_0 = _mm_sub_epi8(_mm_xor_si128(qy_3_0, sign_mask_3_0), sign_mask_3_0);
855-
const __m128i sy_3_1 = _mm_sub_epi8(_mm_xor_si128(qy_3_1, sign_mask_3_1), sign_mask_3_1);
856-
const __m128i sum_3_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_3_0), ones_16);
857-
const __m128i sum_3_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_3_1), ones_16);
858-
const __m128 q_3 = _mm_cvtepi32_ps(_mm_add_epi32(sum_3_0, sum_3_1));
859-
acc_3 = _mm_add_ps(acc_3, _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(yb_3->d))), q_3));
807+
const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4];
808+
809+
#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \
810+
{ \
811+
const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \
812+
const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \
813+
const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \
814+
const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \
815+
const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
816+
const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
817+
const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
818+
const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
819+
const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \
820+
const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \
821+
const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \
822+
(ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \
823+
}
824+
Q1_SSSE3_BLOCK(0, 0, acc_0)
825+
Q1_SSSE3_BLOCK(4, 1, acc_1)
826+
Q1_SSSE3_BLOCK(8, 2, acc_2)
827+
Q1_SSSE3_BLOCK(12, 3, acc_3)
828+
#undef Q1_SSSE3_BLOCK
860829
}
861830

862831
*s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);

0 commit comments

Comments (0)