@@ -760,28 +760,31 @@ void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, cons
760760
761761 for (int ib = 0 ; ib < nb ; ++ ib ) {
762762 const float d0 = GGML_CPU_FP16_TO_FP32 (x [ib ].d );
763+ const block_q8_0 * GGML_RESTRICT y_ptr = & y [ib * 4 ];
763764 __m256 acc_block = _mm256_setzero_ps ();
764-
765- for (int k = 0 ; k < 4 ; ++ k ) {
766- const block_q8_0 * GGML_RESTRICT yb = & y [ib * 4 + k ];
767- const float d1 = GGML_CPU_FP16_TO_FP32 (yb -> d );
768- const __m256i bit_mask = bytes_from_bits_32 (& x [ib ].qs [k * 4 ]);
769- const __m128i bit_mask_0 = _mm256_castsi256_si128 (bit_mask );
770- const __m128i bit_mask_1 = _mm256_extractf128_si256 (bit_mask , 1 );
771- const __m128i qy_0 = _mm_loadu_si128 ((const __m128i * ) & yb -> qs [0 ]);
772- const __m128i qy_1 = _mm_loadu_si128 ((const __m128i * ) & yb -> qs [16 ]);
773- const __m128i sign_mask_0 = _mm_cmpeq_epi8 (bit_mask_0 , zero );
774- const __m128i sign_mask_1 = _mm_cmpeq_epi8 (bit_mask_1 , zero );
775- const __m128i sy_0 = _mm_sub_epi8 (_mm_xor_si128 (qy_0 , sign_mask_0 ), sign_mask_0 );
776- const __m128i sy_1 = _mm_sub_epi8 (_mm_xor_si128 (qy_1 , sign_mask_1 ), sign_mask_1 );
777- const __m128i sum16_0 = _mm_maddubs_epi16 (ones_8 , sy_0 );
778- const __m128i sum16_1 = _mm_maddubs_epi16 (ones_8 , sy_1 );
779- const __m128i sum32_0 = _mm_madd_epi16 (sum16_0 , ones_16 );
780- const __m128i sum32_1 = _mm_madd_epi16 (sum16_1 , ones_16 );
781- const __m256 q = _mm256_cvtepi32_ps (MM256_SET_M128I (sum32_1 , sum32_0 ));
782-
783- acc_block = _mm256_add_ps (acc_block , _mm256_mul_ps (_mm256_set1_ps (d1 ), q ));
765+ #define Q1_AVX_BLOCK (K ) \
766+ { \
767+ const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); \
768+ const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); \
769+ const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); \
770+ const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); \
771+ const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); \
772+ const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
773+ const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
774+ const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
775+ const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
776+ const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); \
777+ const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); \
778+ const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); \
779+ const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); \
780+ const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); \
781+ acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); \
784782 }
783+ Q1_AVX_BLOCK (0 )
784+ Q1_AVX_BLOCK (1 )
785+ Q1_AVX_BLOCK (2 )
786+ Q1_AVX_BLOCK (3 )
787+ #undef Q1_AVX_BLOCK
785788
786789 acc = _mm256_add_ps (acc , _mm256_mul_ps (_mm256_set1_ps (d0 ), acc_block ));
787790 }
@@ -801,62 +804,28 @@ void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, cons
801804
802805 for (int ib = 0 ; ib < nb ; ++ ib ) {
803806 const __m128 d0 = _mm_set1_ps (GGML_CPU_FP16_TO_FP32 (x [ib ].d ));
804- const block_q8_0 * GGML_RESTRICT yb_0 = & y [ib * 4 + 0 ];
805- const block_q8_0 * GGML_RESTRICT yb_1 = & y [ib * 4 + 1 ];
806- const block_q8_0 * GGML_RESTRICT yb_2 = & y [ib * 4 + 2 ];
807- const block_q8_0 * GGML_RESTRICT yb_3 = & y [ib * 4 + 3 ];
808-
809- const __m128i bit_mask_0_0 = bytes_from_bits_16 (& x [ib ].qs [0 ]);
810- const __m128i bit_mask_0_1 = bytes_from_bits_16 (& x [ib ].qs [2 ]);
811- const __m128i qy_0_0 = _mm_loadu_si128 ((const __m128i * ) & yb_0 -> qs [0 ]);
812- const __m128i qy_0_1 = _mm_loadu_si128 ((const __m128i * ) & yb_0 -> qs [16 ]);
813- const __m128i sign_mask_0_0 = _mm_cmpeq_epi8 (bit_mask_0_0 , zero );
814- const __m128i sign_mask_0_1 = _mm_cmpeq_epi8 (bit_mask_0_1 , zero );
815- const __m128i sy_0_0 = _mm_sub_epi8 (_mm_xor_si128 (qy_0_0 , sign_mask_0_0 ), sign_mask_0_0 );
816- const __m128i sy_0_1 = _mm_sub_epi8 (_mm_xor_si128 (qy_0_1 , sign_mask_0_1 ), sign_mask_0_1 );
817- const __m128i sum_0_0 = _mm_madd_epi16 (_mm_maddubs_epi16 (ones_8 , sy_0_0 ), ones_16 );
818- const __m128i sum_0_1 = _mm_madd_epi16 (_mm_maddubs_epi16 (ones_8 , sy_0_1 ), ones_16 );
819- const __m128 q_0 = _mm_cvtepi32_ps (_mm_add_epi32 (sum_0_0 , sum_0_1 ));
820- acc_0 = _mm_add_ps (acc_0 , _mm_mul_ps (_mm_mul_ps (d0 , _mm_set1_ps (GGML_CPU_FP16_TO_FP32 (yb_0 -> d ))), q_0 ));
821-
822- const __m128i bit_mask_1_0 = bytes_from_bits_16 (& x [ib ].qs [4 ]);
823- const __m128i bit_mask_1_1 = bytes_from_bits_16 (& x [ib ].qs [6 ]);
824- const __m128i qy_1_0 = _mm_loadu_si128 ((const __m128i * ) & yb_1 -> qs [0 ]);
825- const __m128i qy_1_1 = _mm_loadu_si128 ((const __m128i * ) & yb_1 -> qs [16 ]);
826- const __m128i sign_mask_1_0 = _mm_cmpeq_epi8 (bit_mask_1_0 , zero );
827- const __m128i sign_mask_1_1 = _mm_cmpeq_epi8 (bit_mask_1_1 , zero );
828- const __m128i sy_1_0 = _mm_sub_epi8 (_mm_xor_si128 (qy_1_0 , sign_mask_1_0 ), sign_mask_1_0 );
829- const __m128i sy_1_1 = _mm_sub_epi8 (_mm_xor_si128 (qy_1_1 , sign_mask_1_1 ), sign_mask_1_1 );
830- const __m128i sum_1_0 = _mm_madd_epi16 (_mm_maddubs_epi16 (ones_8 , sy_1_0 ), ones_16 );
831- const __m128i sum_1_1 = _mm_madd_epi16 (_mm_maddubs_epi16 (ones_8 , sy_1_1 ), ones_16 );
832- const __m128 q_1 = _mm_cvtepi32_ps (_mm_add_epi32 (sum_1_0 , sum_1_1 ));
833- acc_1 = _mm_add_ps (acc_1 , _mm_mul_ps (_mm_mul_ps (d0 , _mm_set1_ps (GGML_CPU_FP16_TO_FP32 (yb_1 -> d ))), q_1 ));
834-
835- const __m128i bit_mask_2_0 = bytes_from_bits_16 (& x [ib ].qs [8 ]);
836- const __m128i bit_mask_2_1 = bytes_from_bits_16 (& x [ib ].qs [10 ]);
837- const __m128i qy_2_0 = _mm_loadu_si128 ((const __m128i * ) & yb_2 -> qs [0 ]);
838- const __m128i qy_2_1 = _mm_loadu_si128 ((const __m128i * ) & yb_2 -> qs [16 ]);
839- const __m128i sign_mask_2_0 = _mm_cmpeq_epi8 (bit_mask_2_0 , zero );
840- const __m128i sign_mask_2_1 = _mm_cmpeq_epi8 (bit_mask_2_1 , zero );
841- const __m128i sy_2_0 = _mm_sub_epi8 (_mm_xor_si128 (qy_2_0 , sign_mask_2_0 ), sign_mask_2_0 );
842- const __m128i sy_2_1 = _mm_sub_epi8 (_mm_xor_si128 (qy_2_1 , sign_mask_2_1 ), sign_mask_2_1 );
843- const __m128i sum_2_0 = _mm_madd_epi16 (_mm_maddubs_epi16 (ones_8 , sy_2_0 ), ones_16 );
844- const __m128i sum_2_1 = _mm_madd_epi16 (_mm_maddubs_epi16 (ones_8 , sy_2_1 ), ones_16 );
845- const __m128 q_2 = _mm_cvtepi32_ps (_mm_add_epi32 (sum_2_0 , sum_2_1 ));
846- acc_2 = _mm_add_ps (acc_2 , _mm_mul_ps (_mm_mul_ps (d0 , _mm_set1_ps (GGML_CPU_FP16_TO_FP32 (yb_2 -> d ))), q_2 ));
847-
848- const __m128i bit_mask_3_0 = bytes_from_bits_16 (& x [ib ].qs [12 ]);
849- const __m128i bit_mask_3_1 = bytes_from_bits_16 (& x [ib ].qs [14 ]);
850- const __m128i qy_3_0 = _mm_loadu_si128 ((const __m128i * ) & yb_3 -> qs [0 ]);
851- const __m128i qy_3_1 = _mm_loadu_si128 ((const __m128i * ) & yb_3 -> qs [16 ]);
852- const __m128i sign_mask_3_0 = _mm_cmpeq_epi8 (bit_mask_3_0 , zero );
853- const __m128i sign_mask_3_1 = _mm_cmpeq_epi8 (bit_mask_3_1 , zero );
854- const __m128i sy_3_0 = _mm_sub_epi8 (_mm_xor_si128 (qy_3_0 , sign_mask_3_0 ), sign_mask_3_0 );
855- const __m128i sy_3_1 = _mm_sub_epi8 (_mm_xor_si128 (qy_3_1 , sign_mask_3_1 ), sign_mask_3_1 );
856- const __m128i sum_3_0 = _mm_madd_epi16 (_mm_maddubs_epi16 (ones_8 , sy_3_0 ), ones_16 );
857- const __m128i sum_3_1 = _mm_madd_epi16 (_mm_maddubs_epi16 (ones_8 , sy_3_1 ), ones_16 );
858- const __m128 q_3 = _mm_cvtepi32_ps (_mm_add_epi32 (sum_3_0 , sum_3_1 ));
859- acc_3 = _mm_add_ps (acc_3 , _mm_mul_ps (_mm_mul_ps (d0 , _mm_set1_ps (GGML_CPU_FP16_TO_FP32 (yb_3 -> d ))), q_3 ));
807+ const block_q8_0 * GGML_RESTRICT y_ptr = & y [ib * 4 ];
808+
809+ #define Q1_SSSE3_BLOCK (QS_OFF , Y_IDX , ACC ) \
810+ { \
811+ const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \
812+ const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \
813+ const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \
814+ const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \
815+ const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
816+ const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
817+ const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
818+ const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
819+ const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \
820+ const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \
821+ const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \
822+ (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \
823+ }
824+ Q1_SSSE3_BLOCK (0 , 0 , acc_0 )
825+ Q1_SSSE3_BLOCK (4 , 1 , acc_1 )
826+ Q1_SSSE3_BLOCK (8 , 2 , acc_2 )
827+ Q1_SSSE3_BLOCK (12 , 3 , acc_3 )
828+ #undef Q1_SSSE3_BLOCK
860829 }
861830
862831 * s = hsum_float_4x4 (acc_0 , acc_1 , acc_2 , acc_3 );
0 commit comments