|
| 1 | +/* |
| 2 | +* Simd Library (http://ermig1979.github.io/Simd). |
| 3 | +* |
| 4 | +* Copyright (c) 2011-2024 Yermalayeu Ihar. |
| 5 | +* |
| 6 | +* Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | +* of this software and associated documentation files (the "Software"), to deal |
| 8 | +* in the Software without restriction, including without limitation the rights |
| 9 | +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 10 | +* copies of the Software, and to permit persons to whom the Software is |
| 11 | +* furnished to do so, subject to the following conditions: |
| 12 | +* |
| 13 | +* The above copyright notice and this permission notice shall be included in |
| 14 | +* all copies or substantial portions of the Software. |
| 15 | +* |
| 16 | +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | +* SOFTWARE. |
| 23 | +*/ |
| 24 | +#include "Simd/SimdSynet.h" |
| 25 | +#include "Simd/SimdArray.h" |
| 26 | +#include "Simd/SimdMath.h" |
| 27 | +#include "Simd/SimdExtract.h" |
| 28 | +#include "Simd/SimdAvx2.h" |
| 29 | + |
| 30 | +namespace Simd |
| 31 | +{ |
| 32 | +#if defined(SIMD_AVX2_ENABLE) && defined(SIMD_SYNET_ENABLE) |
| 33 | + namespace Avx2 |
| 34 | + { |
| 35 | + void NormalizeNhwc16bV2(const uint16_t* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, float eps, float * buf, uint16_t* dst) |
| 36 | + { |
| 37 | + float k = 1.0f / float(channels); |
| 38 | + size_t channelsF = AlignLo(channels, F), c; |
| 39 | + Array32f _buf; |
| 40 | + if (buf == NULL) |
| 41 | + { |
| 42 | + _buf.Resize(channels); |
| 43 | + buf = _buf.data; |
| 44 | + } |
| 45 | + for (size_t b = 0; b < batch; ++b) |
| 46 | + { |
| 47 | + for (size_t s = 0; s < spatial; ++s) |
| 48 | + { |
| 49 | + BFloat16ToFloat32(src, channels, buf); |
| 50 | + |
| 51 | + __m256 _sum = _mm256_setzero_ps(); |
| 52 | + for (c = 0; c < channelsF; c += F) |
| 53 | + _sum = _mm256_add_ps(_mm256_loadu_ps(buf + c), _sum); |
| 54 | + float sum = ExtractSum(_sum); |
| 55 | + for (; c < channels; ++c) |
| 56 | + sum += buf[c]; |
| 57 | + __m256 mean = _mm256_set1_ps(sum * k); |
| 58 | + for (c = 0; c < channelsF; c += F) |
| 59 | + _mm256_storeu_ps(buf + c, _mm256_sub_ps(_mm256_loadu_ps(buf + c), mean)); |
| 60 | + for (; c < channels; ++c) |
| 61 | + _mm_store_ss(buf + c, _mm_sub_ss(_mm_load_ss(buf + c), _mm256_castps256_ps128(mean))); |
| 62 | + |
| 63 | + __m256 _sqsum = _mm256_setzero_ps(); |
| 64 | + for (c = 0; c < channelsF; c += F) |
| 65 | + { |
| 66 | + __m256 _buf = _mm256_loadu_ps(buf + c); |
| 67 | + _sqsum = _mm256_fmadd_ps(_buf, _buf, _sqsum); |
| 68 | + } |
| 69 | + float sqsum = ExtractSum(_sqsum); |
| 70 | + for (; c < channels; ++c) |
| 71 | + sqsum += Simd::Square(buf[c]); |
| 72 | + __m256 norm = _mm256_set1_ps(1.0f / ::sqrt(sqsum * k + eps)); |
| 73 | + for (c = 0; c < channelsF; c += F) |
| 74 | + _mm256_storeu_ps(buf + c, _mm256_fmadd_ps(_mm256_mul_ps(_mm256_loadu_ps(buf + c), norm), _mm256_loadu_ps(scale + c), _mm256_loadu_ps(shift + c))); |
| 75 | + for (; c < channels; ++c) |
| 76 | + _mm_store_ss(buf + c, _mm_fmadd_ss(_mm_mul_ss(_mm_load_ss(buf + c), _mm256_castps256_ps128(norm)), _mm_load_ss(scale + c), _mm_load_ss(shift + c))); |
| 77 | + |
| 78 | + Float32ToBFloat16(buf, channels, dst); |
| 79 | + |
| 80 | + dst += channels; |
| 81 | + src += channels; |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + void SynetNormalizeLayerForward16bV2(const uint16_t* src, size_t batch, size_t channels, size_t spatial, |
| 87 | + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, uint16_t* dst) |
| 88 | + { |
| 89 | + if (format == SimdTensorFormatNhwc) |
| 90 | + NormalizeNhwc16bV2(src, batch, channels, spatial, scale, shift, *eps, buf, dst); |
| 91 | + else |
| 92 | + assert(0); |
| 93 | + } |
| 94 | + } |
| 95 | +#endif |
| 96 | +} |
0 commit comments