+ #pragma once
+
  #include <cstdint>
  #include <cmath>
  #include <chess.hpp>
+ #include <cassert>
+ #include "simd.hpp"

- #if defined(__AVX2__)
- #include <immintrin.h>
-
- using vec_int8 = __m256i;
- using vec_uint8 = __m256i;
- using vec_int16 = __m256i;
- using vec_uint16 = __m256i;
- using vec_int32 = __m256i;
- using vec_uint32 = __m256i;
-
- inline vec_int8 setzero_epi8() {
-     return _mm256_setzero_si256();
- }
-
- inline vec_int16 setzero_epi16() {
-     return _mm256_setzero_si256();
- }
-
- inline vec_int32 setzero_epi32() {
-     return _mm256_setzero_si256();
- }
+ constexpr int NUM_AVX_REGISTERS = 8;
+ constexpr int INT32_PER_REG = sizeof(vec_int32) / sizeof(int32_t);
+ constexpr int INT16_PER_REG = sizeof(vec_int16) / sizeof(int16_t);
+ constexpr int INT8_PER_REG = sizeof(vec_int8) / sizeof(int8_t);
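+ // Assuming the vec_* types in simd.hpp are 256-bit __m256i, as they were when defined
+ // inline here, these work out to 8 int32, 16 int16, and 32 int8 lanes per register.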

- inline vec_int16 set1_epi16(int i) {
-     return _mm256_set1_epi16(i);
- }
+ namespace NNUE_UTILS {

- inline vec_int32 set1_epi32(int i) {
-     return _mm256_set1_epi32(i);
- }
+ #ifdef USE_AVX2 // these functions use the AVX2-specific instruction permute4x64_epi64
+ [[maybe_unused]]
+ inline void crelu32_to_8(int32_t* input, int8_t* output, int size) {

- inline vec_int8 load_epi8(int8_t* ptr) {
-     return _mm256_loadu_si256((const __m256i*)ptr);
- }
+     assert(size % INT8_PER_REG == 0);

- inline vec_int16 load_epi16(int16_t* ptr) {
-     return _mm256_loadu_si256((const __m256i*)ptr);
- }
+     const int num_regs = size / INT8_PER_REG;
+     const vec_int8 zero = setzero_epi8();

- inline vec_int32 load_epi32(int32_t* ptr) {
-     return _mm256_loadu_si256((const __m256i*)ptr);
- }
+     for (int i = 0; i < num_regs; i++) {
+         vec_int32 in_1 = load_epi32(&input[(4*i) * INT32_PER_REG]);
+         vec_int32 in_2 = load_epi32(&input[(4*i + 1) * INT32_PER_REG]);
+         vec_int32 in_3 = load_epi32(&input[(4*i + 2) * INT32_PER_REG]);
+         vec_int32 in_4 = load_epi32(&input[(4*i + 3) * INT32_PER_REG]);

- inline void store_epi8(int8_t* ptr, vec_int8 v) {
-     _mm256_storeu_si256((__m256i*)ptr, v);
- }
+         in_1 = permute4x64_epi64<0b10'00'11'01>(packs_epi32(in_1, in_2));
+         in_3 = permute4x64_epi64<0b10'00'11'01>(packs_epi32(in_3, in_4));

- inline void store_epi16(int16_t* ptr, vec_int16 v) {
-     _mm256_storeu_si256((__m256i*)ptr, v);
+         vec_int8 out = packs_epi16(in_1, in_3);
+         out = max_epi8(out, zero); // packs already saturates at 127, so only the lower clamp (max with zero) is needed
+         out = permute4x64_epi64<0b01'11'00'10>(out);
+         store_epi8(&output[i * INT8_PER_REG], out);
+     }
  }
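+ // Note on the permutes above: the AVX2 pack instructions (packs_epi32 / packs_epi16)
+ // pack each 128-bit lane independently, interleaving the two sources across lanes
+ // instead of concatenating them. The permute4x64_epi64 shuffles re-gather the 64-bit
+ // quadwords so the bytes stored to output stay in source order (in_1, in_2, in_3, in_4).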

- inline void store_epi32(int32_t* ptr, vec_int32 v) {
-     _mm256_storeu_si256((__m256i*)ptr, v);
- }
+ [[maybe_unused]]
+ inline void crelu16_to_8(int16_t* input, int8_t* output, int size) {

- inline vec_int8 packs_epi16(vec_int16 v1, vec_int16 v2) {
-     return _mm256_packs_epi16(v1, v2);
- }
+     assert(size % INT8_PER_REG == 0);

- inline vec_int16 packs_epi32(vec_int32 v1, vec_int32 v2) {
-     return _mm256_packs_epi32(v1, v2);
- }
+     const int num_regs = size / INT8_PER_REG;

- inline vec_int8 packus_epi16(vec_int16 v1, vec_int16 v2) {
-     return _mm256_packus_epi16(v1, v2);
- }
-
- inline vec_int16 packus_epi32(vec_int32 v1, vec_int32 v2) {
-     return _mm256_packus_epi32(v1, v2);
+     for (int i = 0; i < num_regs; i++) {
+         vec_int16 in_1 = load_epi16(&input[(2*i) * INT16_PER_REG]);
+         vec_int16 in_2 = load_epi16(&input[(2*i + 1) * INT16_PER_REG]);
+         // packus clamps negative values to 0 and saturates at 255, which applies the clipped relu
+         vec_int8 out = packus_epi16(in_1, in_2);
+         out = permute4x64_epi64<0b11'01'10'00>(out); // undo the packus lane interleaving
+         store_epi8(&output[i * INT8_PER_REG], out);
+     }
  }
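+ // Hypothetical usage (buffer names and sizes are illustrative, not from this codebase):
+ //     int16_t acc[512];       // e.g. accumulator output
+ //     int8_t  activated[512];
+ //     crelu16_to_8(acc, activated, 512);  // fine: 512 % INT8_PER_REG == 0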
+ #endif

- template <int mask>
- inline vec_int32 permute4x64_epi64(vec_int32 v) {
-     return _mm256_permute4x64_epi64(v, mask);
- }
+ [[maybe_unused]]
+ inline void crelu16_to_16(int16_t* input, int16_t* output, int size) {

- template <int mask>
- inline vec_int32 permute2x128_si256(vec_int32 v1, vec_int32 v2) {
-     return _mm256_permute2x128_si256(v1, v2, mask);
- }
+     assert(size % INT16_PER_REG == 0);

- inline vec_int8 max_epi8(vec_int8 v1, vec_int8 v2) {
-     return _mm256_max_epi8(v1, v2);
- }
+     const vec_int16 zero = setzero_epi16();
+     const vec_int16 qscale = set1_epi16(255);

- inline vec_int16 max_epi16(vec_int16 v1, vec_int16 v2) {
-     return _mm256_max_epi16(v1, v2);
+     for (int i = 0; i < size; i += INT16_PER_REG) {
+         vec_int16 in = load_epi16(&input[i]);
+         vec_int16 out = min_epi16(qscale, max_epi16(in, zero));
+         store_epi16(&output[i], out);
  }
+ }
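+ // Scalar equivalent of crelu16_to_16, for reference:
+ //     for (int i = 0; i < size; i++)
+ //         output[i] = std::clamp(input[i], int16_t(0), int16_t(255));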

- inline vec_int16 min_epi16(vec_int16 v1, vec_int16 v2) {
-     return _mm256_min_epi16(v1, v2);
- }
+ #ifdef USE_AVX2 // these functions use the AVX2-specific instructions hadd_epi32 and permute2x128_si256
+ inline int32_t reduce1_epi32(vec_int32 input) { // horizontally add one register of int32s
+     input = hadd_epi32(input, input);

- inline vec_int16 add_epi16(vec_int16 v1, vec_int16 v2) {
-     return _mm256_add_epi16(v1, v2);
- }
+     int32_t out_ptr[8];
+     store_epi32(out_ptr, input);

- inline vec_int32 add_epi32(vec_int32 v1, vec_int32 v2) {
-     return _mm256_add_epi32(v1, v2);
+     return out_ptr[0] + out_ptr[1] + out_ptr[4] + out_ptr[5];
  }
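+ // hadd_epi32 adds adjacent pairs within each 128-bit lane, so after hadd(input, input)
+ // the four pairwise sums land in elements 0 and 1 (low lane) and 4 and 5 (high lane);
+ // elements 2, 3, 6, 7 are duplicates, hence the four-term sum above.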

- inline vec_int16 sub_epi16(vec_int16 v1, vec_int16 v2) {
-     return _mm256_sub_epi16(v1, v2);
- }
+ [[maybe_unused]]
+ inline vec_int32 reduce8_epi32(vec_int32* inputs) { // horizontally add eight registers of int32s
+     inputs[0] = hadd_epi32(inputs[0], inputs[1]);
+     inputs[2] = hadd_epi32(inputs[2], inputs[3]);
+     inputs[4] = hadd_epi32(inputs[4], inputs[5]);
+     inputs[6] = hadd_epi32(inputs[6], inputs[7]);

- inline vec_int32 madd_epi16(vec_int16 v1, vec_int16 v2) {
-     return _mm256_madd_epi16(v1, v2);
- }
-
- inline vec_int32 hadd_epi32(vec_int32 v1, vec_int32 v2) {
-     return _mm256_hadd_epi32(v1, v2);
- }
+     inputs[0] = hadd_epi32(inputs[0], inputs[2]);
+     inputs[4] = hadd_epi32(inputs[4], inputs[6]);

- inline vec_int16 mullo_epi16(vec_int16 v1, vec_int16 v2) {
-     return _mm256_mullo_epi16(v1, v2);
+     return add_epi32(
+         // combine the 128-bit lanes of the two partial-sum registers
+         permute2x128_si256<0b00110001>(inputs[0], inputs[4]),
+         permute2x128_si256<0b00100000>(inputs[0], inputs[4])
+     );
  }
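+ // After the hadd tree, inputs[0] and inputs[4] hold the lane-local sums of registers
+ // 0-3 and 4-7, split across their two 128-bit lanes. The permute2x128_si256 pair builds
+ // [high(inputs[0]), high(inputs[4])] and [low(inputs[0]), low(inputs[4])]; adding them
+ // makes the k-th int32 of the result the full horizontal sum of inputs[k].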
-
- #else
- #error "bread requires the AVX2 instruction set to run."
  #endif

- constexpr int NUM_AVX_REGISTERS = 8;
- constexpr int INT32_PER_REG = sizeof(vec_int32) / sizeof(int32_t);
- constexpr int INT16_PER_REG = sizeof(vec_int16) / sizeof(int16_t);
- constexpr int INT8_PER_REG = sizeof(vec_int8) / sizeof(int8_t);
-
- namespace NNUE_UTILS {
-
- void crelu32_to_8(int32_t* input, int8_t* output, int size);
- void crelu16_to_8(int16_t* input, int8_t* output, int size);
-
- void crelu16_to_16(int16_t* input, int16_t* output, int size);
-
- int32_t reduce1_epi32(vec_int32& input); // horizontal add 1 int32 avx register.
- vec_int32 reduce8_epi32(vec_int32* inputs);
-
  }; // namespace NNUE_UTILS