Skip to content

Commit 39c6264

Browse files
committed
add AVX512 support
passed LTC:
Elo | 0.83 +- 2.14 (95%)
SPRT | 40.0+0.40s Threads=1 Hash=128MB
LLR | 2.96 (-2.94, 2.94) [-3.50, 0.50]
Games | N: 29464 W: 6983 L: 6913 D: 15568
Penta | [194, 3488, 7282, 3590, 178]
https://nonlinear.eu.pythonanywhere.com/test/1004/

passed STC:
Elo | 2.38 +- 3.16 (95%)
SPRT | 8.0+0.08s Threads=1 Hash=32MB
LLR | 2.97 (-2.94, 2.94) [-3.50, 0.50]
Games | N: 16208 W: 4174 L: 4063 D: 7971
Penta | [216, 1884, 3816, 1949, 239]
https://nonlinear.eu.pythonanywhere.com/test/1003/

bench: 2539701
1 parent f8682ee commit 39c6264

5 files changed

Lines changed: 309 additions & 193 deletions

File tree

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ set(bread_SOURCE
3131
${bread_SRC}/sorted_move_gen.cpp
3232
${bread_SRC}/nnue/nnue.cpp
3333
${bread_SRC}/nnue/nnue_board.cpp
34-
${bread_SRC}/nnue/nnue_misc.cpp
3534
${bread_SRC}/transposition_table.cpp
3635
${bread_SRC}/core.cpp
3736
${bread_SRC}/benchmark.cpp

src/nnue/nnue.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,10 @@ int32_t run_L1(Accumulators& accumulators, Color stm, int bucket){
245245
in = min_epi16(qscale, max_epi16(in, zero));
246246

247247
vec_int16 weight_chunk = load_epi16(&l1_weights[bucket * L1_WEIGHTS_SIZE + i]);
248-
vec_int32 prod = madd_epi16(in, mullo_epi16(in, weight_chunk));
249248

250249
// madd pairs to int32 to avoid overflows in int16
250+
vec_int32 prod = madd_epi16(in, mullo_epi16(in, weight_chunk));
251+
251252
result = add_epi32(result, prod);
252253
}
253254

@@ -256,9 +257,10 @@ int32_t run_L1(Accumulators& accumulators, Color stm, int bucket){
256257
in = min_epi16(qscale, max_epi16(in, zero));
257258

258259
vec_int16 weight_chunk = load_epi16(&l1_weights[bucket * L1_WEIGHTS_SIZE + ACC_SIZE + i]);
259-
vec_int32 prod = madd_epi16(in, mullo_epi16(in, weight_chunk));
260260

261261
// madd pairs to int32 to avoid overflows in int16
262+
vec_int32 prod = madd_epi16(in, mullo_epi16(in, weight_chunk));
263+
262264
result = add_epi32(result, prod);
263265
}
264266

src/nnue/nnue_misc.cpp

Lines changed: 0 additions & 83 deletions
This file was deleted.

src/nnue/nnue_misc.hpp

Lines changed: 68 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,140 +1,101 @@
1+
#pragma once
2+
13
#include <cstdint>
24
#include <cmath>
35
#include <chess.hpp>
6+
#include "simd.hpp"
47

5-
#if defined(__AVX2__)
6-
#include <immintrin.h>
7-
8-
using vec_int8 = __m256i;
9-
using vec_uint8 = __m256i;
10-
using vec_int16 = __m256i;
11-
using vec_uint16 = __m256i;
12-
using vec_int32 = __m256i;
13-
using vec_uint32 = __m256i;
14-
15-
inline vec_int8 setzero_epi8() {
16-
return _mm256_setzero_si256();
17-
}
18-
19-
inline vec_int16 setzero_epi16() {
20-
return _mm256_setzero_si256();
21-
}
22-
23-
inline vec_int32 setzero_epi32() {
24-
return _mm256_setzero_si256();
25-
}
8+
constexpr int NUM_AVX_REGISTERS = 8;
9+
constexpr int INT32_PER_REG = sizeof(vec_int32) / sizeof(int32_t);
10+
constexpr int INT16_PER_REG = sizeof(vec_int16) / sizeof(int16_t);
11+
constexpr int INT8_PER_REG = sizeof(vec_int8) / sizeof(int8_t);
2612

27-
inline vec_int16 set1_epi16(int i) {
28-
return _mm256_set1_epi16(i);
29-
}
13+
namespace NNUE_UTILS {
3014

31-
inline vec_int32 set1_epi32(int i) {
32-
return _mm256_set1_epi32(i);
33-
}
15+
#ifdef USE_AVX2 // these functions use the AVX2 specific instruction permute4x64_epi64
16+
[[maybe_unused]]
17+
inline void crelu32_to_8(int32_t *input, int8_t *output, int size){
3418

35-
inline vec_int8 load_epi8(int8_t* ptr) {
36-
return _mm256_loadu_si256((const __m256i*)ptr);
37-
}
19+
assert(size % INT8_PER_REG == 0);
3820

39-
inline vec_int16 load_epi16(int16_t* ptr) {
40-
return _mm256_loadu_si256((const __m256i*)ptr);
41-
}
21+
const int num_regs = size / INT8_PER_REG;
22+
const vec_int8 zero = setzero_epi8();
4223

43-
inline vec_int32 load_epi32(int32_t* ptr) {
44-
return _mm256_loadu_si256((const __m256i*)ptr);
45-
}
24+
for (int i = 0; i < num_regs; i++){
25+
vec_int32 in_1 = load_epi32(&input[(4*i)*INT32_PER_REG]);
26+
vec_int32 in_2 = load_epi32(&input[(4*i+1)*INT32_PER_REG]);
27+
vec_int32 in_3 = load_epi32(&input[(4*i+2)*INT32_PER_REG]);
28+
vec_int32 in_4 = load_epi32(&input[(4*i+3)*INT32_PER_REG]);
4629

47-
inline void store_epi8(int8_t* ptr, vec_int8 v) {
48-
_mm256_storeu_si256((__m256i*)ptr, v);
49-
}
30+
in_1 = permute4x64_epi64<0b10'00'11'01>(packs_epi32(in_1, in_2));
31+
in_3 = permute4x64_epi64<0b10'00'11'01>(packs_epi32(in_3, in_4));
5032

51-
inline void store_epi16(int16_t* ptr, vec_int16 v) {
52-
_mm256_storeu_si256((__m256i*)ptr, v);
33+
vec_int8 out = packs_epi16(in_1, in_3);
34+
out = max_epi8(out, zero); // packs saturates at 127, so only max is applied
35+
out = permute4x64_epi64<0b01'11'00'10>(out);
36+
store_epi8(&output[i*INT8_PER_REG], out);
37+
}
5338
}
5439

55-
inline void store_epi32(int32_t* ptr, vec_int32 v) {
56-
_mm256_storeu_si256((__m256i*)ptr, v);
57-
}
40+
[[maybe_unused]]
41+
inline void crelu16_to_8(int16_t *input, int8_t *output, int size){
5842

59-
inline vec_int8 packs_epi16(vec_int16 v1, vec_int16 v2) {
60-
return _mm256_packs_epi16(v1, v2);
61-
}
43+
assert(size % INT8_PER_REG == 0);
6244

63-
inline vec_int16 packs_epi32(vec_int32 v1, vec_int32 v2) {
64-
return _mm256_packs_epi32(v1, v2);
65-
}
45+
const int num_regs = size / INT8_PER_REG;
6646

67-
inline vec_int8 packus_epi16(vec_int16 v1, vec_int16 v2) {
68-
return _mm256_packus_epi16(v1, v2);
69-
}
70-
71-
inline vec_int16 packus_epi32(vec_int32 v1, vec_int32 v2) {
72-
return _mm256_packus_epi32(v1, v2);
47+
for (int i = 0; i < num_regs; i++){
48+
vec_int16 in_1 = load_epi16(&input[(2*i)*INT16_PER_REG]);
49+
vec_int16 in_2 = load_epi16(&input[(2*i+1)*INT16_PER_REG]);
50+
// packus clamps negative values to 0 and saturates at 255, which effectively applies a clipped relu
51+
vec_int8 out = packus_epi16(in_1, in_2);
52+
out = permute4x64_epi64<0b11'01'10'00>(out); // undo the packus shuffle
53+
store_epi8(&output[i*INT8_PER_REG], out);
54+
}
7355
}
56+
#endif
7457

75-
template<int mask>
76-
inline vec_int32 permute4x64_epi64(vec_int32 v) {
77-
return _mm256_permute4x64_epi64(v, mask);
78-
}
58+
[[maybe_unused]]
59+
inline void crelu16_to_16(int16_t *input, int16_t *output, int size){
7960

80-
template<int mask>
81-
inline vec_int32 permute2x128_si256(vec_int32 v1, vec_int32 v2) {
82-
return _mm256_permute2x128_si256(v1, v2, mask);
83-
}
61+
assert(size % INT16_PER_REG == 0);
8462

85-
inline vec_int8 max_epi8(vec_int8 v1, vec_int8 v2) {
86-
return _mm256_max_epi8(v1, v2);
87-
}
63+
const vec_int16 zero = setzero_epi16();
64+
const vec_int16 qscale = set1_epi16(255);
8865

89-
inline vec_int16 max_epi16(vec_int16 v1, vec_int16 v2) {
90-
return _mm256_max_epi16(v1, v2);
66+
for (int i = 0; i < size; i += INT16_PER_REG){
67+
vec_int16 in = load_epi16(&input[i]);
68+
vec_int16 out = min_epi16(qscale, max_epi16(in, zero));
69+
store_epi16(&output[i], out);
9170
}
71+
}
9272

93-
inline vec_int16 min_epi16(vec_int16 v1, vec_int16 v2) {
94-
return _mm256_min_epi16(v1, v2);
95-
}
73+
#ifdef USE_AVX2 // these functions use the AVX2 specific instructions hadd_epi32 and permute2x128_si256
74+
inline int32_t reduce1_epi32(vec_int32 input){ // horizontal add 1 int32 avx register.
75+
input = hadd_epi32(input, input);
9676

97-
inline vec_int16 add_epi16(vec_int16 v1, vec_int16 v2) {
98-
return _mm256_add_epi16(v1, v2);
99-
}
77+
int32_t out_ptr[8];
78+
store_epi32(out_ptr, input);
10079

101-
inline vec_int32 add_epi32(vec_int32 v1, vec_int32 v2) {
102-
return _mm256_add_epi32(v1, v2);
80+
return out_ptr[0] + out_ptr[1] + out_ptr[4] + out_ptr[5];
10381
}
10482

105-
inline vec_int16 sub_epi16(vec_int16 v1, vec_int16 v2) {
106-
return _mm256_sub_epi16(v1, v2);
107-
}
83+
[[maybe_unused]]
84+
inline vec_int32 reduce8_epi32(vec_int32* inputs){ // horizontal add 8 int32 avx registers.
85+
inputs[0] = hadd_epi32(inputs[0], inputs[1]);
86+
inputs[2] = hadd_epi32(inputs[2], inputs[3]);
87+
inputs[4] = hadd_epi32(inputs[4], inputs[5]);
88+
inputs[6] = hadd_epi32(inputs[6], inputs[7]);
10889

109-
inline vec_int32 madd_epi16(vec_int16 v1, vec_int16 v2) {
110-
return _mm256_madd_epi16(v1, v2);
111-
}
112-
113-
inline vec_int32 hadd_epi32(vec_int32 v1, vec_int32 v2) {
114-
return _mm256_hadd_epi32(v1, v2);
115-
}
90+
inputs[0] = hadd_epi32(inputs[0], inputs[2]);
91+
inputs[4] = hadd_epi32(inputs[4], inputs[6]);
11692

117-
inline vec_int16 mullo_epi16(vec_int16 v1, vec_int16 v2) {
118-
return _mm256_mullo_epi16(v1, v2);
93+
return add_epi32(
94+
// swap lanes of the two registers
95+
permute2x128_si256<0b00110001>(inputs[0], inputs[4]),
96+
permute2x128_si256<0b00100000>(inputs[0], inputs[4])
97+
);
11998
}
120-
121-
#else
122-
#error "bread requires the AVX2 instruction set to run."
12399
#endif
124100

125-
constexpr int NUM_AVX_REGISTERS = 8;
126-
constexpr int INT32_PER_REG = sizeof(vec_int32) / sizeof(int32_t);
127-
constexpr int INT16_PER_REG = sizeof(vec_int16) / sizeof(int16_t);
128-
constexpr int INT8_PER_REG = sizeof(vec_int8) / sizeof(int8_t);
129-
130-
namespace NNUE_UTILS {
131-
132-
void crelu32_to_8(int32_t *input, int8_t *output, int size);
133-
void crelu16_to_8(int16_t *input, int8_t *output, int size);
134-
135-
void crelu16_to_16(int16_t *input, int16_t *output, int size);
136-
137-
int32_t reduce1_epi32(vec_int32& input); // horizontal add 1 int32 avx register.
138-
vec_int32 reduce8_epi32(vec_int32* inputs);
139-
140101
}; // namespace NNUE_UTILS

0 commit comments

Comments
 (0)