From bcb95b01124422ad818c70a7e3901b3e284f84ec Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 14:06:33 +0200 Subject: [PATCH 01/14] Add basic ARM64 architecture support for NEON, SVE, and SVE2. Just a skeleton, no implementations yet. --- CMakeLists.txt | 10 +- external/simde | 2 +- include/pernix/arm64/neon/compression.h | 59 +++++++ include/pernix/arm64/neon/decompression.h | 59 +++++++ include/pernix/arm64/neon/packing.h | 11 ++ include/pernix/arm64/neon/unpacking.h | 20 +++ include/pernix/arm64/sve/compression.h | 59 +++++++ include/pernix/arm64/sve/decompression.h | 59 +++++++ include/pernix/arm64/sve/packing.h | 11 ++ include/pernix/arm64/sve/unpacking.h | 11 ++ include/pernix/arm64/sve2/compression.h | 59 +++++++ include/pernix/arm64/sve2/decompression.h | 59 +++++++ include/pernix/arm64/sve2/packing.h | 11 ++ include/pernix/arm64/sve2/unpacking.h | 11 ++ include/pernix/detection.h | 19 +- include/pernix/pernix.h | 167 +++++++++++++++++- include/pernix/simd_compat.h | 8 + src/CMakeLists.txt | 48 ++++- src/arm64/neon/compression.cpp | 21 +++ src/arm64/neon/decompression.cpp | 21 +++ src/arm64/sve/compression.cpp | 21 +++ src/arm64/sve/decompression.cpp | 21 +++ src/arm64/sve2/compression.cpp | 21 +++ src/arm64/sve2/decompression.cpp | 21 +++ src/pernix.cpp | 112 +++++++++++- tests/arm64/neon/.gitkeep | 0 tests/arm64/sve/.gitkeep | 0 tests/arm64/sve2/.gitkeep | 0 .../compression_tests.cpp} | 2 +- .../decompression_tests.cpp} | 2 +- .../edge_tests.cpp} | 0 tests/include/testset.h | 2 +- .../avx2/compression_tests.cpp} | 2 +- .../avx2/decompression_tests.cpp} | 2 +- .../avx512vbmi/compression_tests.cpp} | 0 .../avx512vbmi/decompression_tests.cpp} | 2 +- .../bmi2/compression_tests.cpp} | 2 +- .../bmi2/decompression_tests.cpp} | 2 +- 38 files changed, 916 insertions(+), 21 deletions(-) create mode 100644 include/pernix/arm64/neon/compression.h create mode 100644 include/pernix/arm64/neon/decompression.h create mode 100644 include/pernix/arm64/neon/packing.h create mode 100644 include/pernix/arm64/neon/unpacking.h create mode 100644 include/pernix/arm64/sve/compression.h create mode 100644 include/pernix/arm64/sve/decompression.h create mode 100644 include/pernix/arm64/sve/packing.h create mode 100644 include/pernix/arm64/sve/unpacking.h create mode 100644 include/pernix/arm64/sve2/compression.h create mode 100644 include/pernix/arm64/sve2/decompression.h create mode 100644 include/pernix/arm64/sve2/packing.h create mode 100644 include/pernix/arm64/sve2/unpacking.h create mode 100644 src/arm64/neon/compression.cpp create mode 100644 src/arm64/neon/decompression.cpp create mode 100644 src/arm64/sve/compression.cpp create mode 100644 src/arm64/sve/decompression.cpp create mode 100644 src/arm64/sve2/compression.cpp create mode 100644 src/arm64/sve2/decompression.cpp create mode 100644 tests/arm64/neon/.gitkeep create mode 100644 tests/arm64/sve/.gitkeep create mode 100644 tests/arm64/sve2/.gitkeep rename tests/{compression/fallback_compression_tests.cpp => fallback/compression_tests.cpp} (97%) rename tests/{decompression/fallback_decompression_tests.cpp => fallback/decompression_tests.cpp} (97%) rename tests/{fallback_edge_tests.cpp => fallback/edge_tests.cpp} (100%) rename tests/{compression/avx2_compression_tests.cpp => x86/avx2/compression_tests.cpp} (97%) rename tests/{decompression/avx2_decompression_tests.cpp => x86/avx2/decompression_tests.cpp} (97%) rename tests/{compression/avx512vbmi_compression_tests.cpp => x86/avx512vbmi/compression_tests.cpp} (100%) rename tests/{decompression/avx512vbmi_decompression_tests.cpp => x86/avx512vbmi/decompression_tests.cpp} (97%) rename tests/{compression/bmi2_compression_tests.cpp => x86/bmi2/compression_tests.cpp} (97%) rename tests/{decompression/bmi2_decompression_tests.cpp => x86/bmi2/decompression_tests.cpp} (97%) diff --git a/CMakeLists.txt b/CMakeLists.txt index edf6c78..3fa49be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,9 +13,17 @@ option(PERNIX_DISABLE_AVX2 "Disable AVX2 optimizations" off) option(PERNIX_DISABLE_AVX512 "Disable AVX512 optimizations" off) option(PERNIX_USE_SIMDE "Use SIMDe library for portable SIMD support" off) +set(PERNIX_ARCH_BACKEND "AUTO" CACHE STRING "Pernix architecture backend (AUTO, FALLBACK, X86, ARM64_NEON, ARM64_SVE, ARM64_SVE2)") +set_property(CACHE PERNIX_ARCH_BACKEND PROPERTY STRINGS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) option(PERNIX_ENABLE_FORTRAN_BINDINGS "Build Fortran bindings for pernix" off) +string(TOUPPER "${PERNIX_ARCH_BACKEND}" PERNIX_ARCH_BACKEND) +set(PERNIX_VALID_ARCH_BACKENDS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) +if (NOT PERNIX_ARCH_BACKEND IN_LIST PERNIX_VALID_ARCH_BACKENDS) + message(FATAL_ERROR "Unsupported PERNIX_ARCH_BACKEND='${PERNIX_ARCH_BACKEND}'. Expected one of: ${PERNIX_VALID_ARCH_BACKENDS}") +endif () + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") if (PERNIX_USE_SIMDE) @@ -97,4 +105,4 @@ endif () if (PERNIX_ENABLE_TESTS) enable_testing() add_subdirectory(tests) -endif () \ No newline at end of file +endif () diff --git a/external/simde b/external/simde index 1747b24..1a1ca5e 160000 --- a/external/simde +++ b/external/simde @@ -1 +1 @@ -Subproject commit 1747b2482589fe894d49989159421da08c2a8bcd +Subproject commit 1a1ca5ee71518d8a115234dad1e2d871421953b7 diff --git a/include/pernix/arm64/neon/compression.h b/include/pernix/arm64/neon/compression.h new file mode 100644 index 0000000..65ea786 --- /dev/null +++ b/include/pernix/arm64/neon/compression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_NEON_COMPRESSION_H +#define PERNIX_ARM64_NEON_COMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool neon_compression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_block(const float_t*, float_t, uint8_t*) { + static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_block(const double_t*, double_t, uint8_t*) { + static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { + static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { + static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int neon_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); +int neon_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); +int neon_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, + uint32_t blocks); +int neon_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_NEON_COMPRESSION_H diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h new file mode 100644 index 0000000..0f8d79e --- /dev/null +++ b/include/pernix/arm64/neon/decompression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_NEON_DECOMPRESSION_H +#define PERNIX_ARM64_NEON_DECOMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool neon_decompression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_block(const uint8_t*, float_t, float_t*) { + static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_block(const uint8_t*, double_t, double_t*) { + static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { + static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { + static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int neon_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); +int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); +int neon_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); +int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/neon/packing.h b/include/pernix/arm64/neon/packing.h new file mode 100644 index 0000000..c1c8119 --- /dev/null +++ b/include/pernix/arm64/neon/packing.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_NEON_PACKING_H +#define PERNIX_ARM64_NEON_PACKING_H + +#include + +namespace pernix::arm64::neon::internal { +template +inline constexpr bool packing_unimplemented_v = false; +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_PACKING_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h new file mode 100644 index 0000000..32dc5de --- /dev/null +++ b/include/pernix/arm64/neon/unpacking.h @@ -0,0 +1,20 @@ +#ifndef PERNIX_ARM64_NEON_UNPACKING_H +#define PERNIX_ARM64_NEON_UNPACKING_H + +#include + +namespace pernix::arm64::neon::internal { + +template +inline constexpr bool unpacking_unimplemented_v = false; + +namespace b64 { + +} // namespace b64 + + +} // namespace pernix::arm64::neon::internal + + + +#endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve/compression.h b/include/pernix/arm64/sve/compression.h new file mode 100644 index 0000000..f8abfed --- /dev/null +++ b/include/pernix/arm64/sve/compression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE_COMPRESSION_H +#define PERNIX_ARM64_SVE_COMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve_compression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_block(const float_t*, float_t, uint8_t*) { + static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_block(const double_t*, double_t, uint8_t*) { + static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { + static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { + static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); +int sve_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); +int sve_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, + uint32_t blocks); +int sve_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE_COMPRESSION_H diff --git a/include/pernix/arm64/sve/decompression.h b/include/pernix/arm64/sve/decompression.h new file mode 100644 index 0000000..784c0c8 --- /dev/null +++ b/include/pernix/arm64/sve/decompression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE_DECOMPRESSION_H +#define PERNIX_ARM64_SVE_DECOMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve_decompression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_block(const uint8_t*, float_t, float_t*) { + static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_block(const uint8_t*, double_t, double_t*) { + static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { + static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { + static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); +int sve_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); +int sve_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); +int sve_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve/packing.h b/include/pernix/arm64/sve/packing.h new file mode 100644 index 0000000..fce21ca --- /dev/null +++ b/include/pernix/arm64/sve/packing.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE_PACKING_H +#define PERNIX_ARM64_SVE_PACKING_H + +#include + +namespace pernix::arm64::sve::internal { +template +inline constexpr bool packing_unimplemented_v = false; +} // namespace pernix::arm64::sve::internal + +#endif // PERNIX_ARM64_SVE_PACKING_H diff --git a/include/pernix/arm64/sve/unpacking.h b/include/pernix/arm64/sve/unpacking.h new file mode 100644 index 0000000..3de49ca --- /dev/null +++ b/include/pernix/arm64/sve/unpacking.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE_UNPACKING_H +#define PERNIX_ARM64_SVE_UNPACKING_H + +#include + +namespace pernix::arm64::sve::internal { +template +inline constexpr bool unpacking_unimplemented_v = false; +} // namespace pernix::arm64::sve::internal + +#endif // PERNIX_ARM64_SVE_UNPACKING_H diff --git a/include/pernix/arm64/sve2/compression.h b/include/pernix/arm64/sve2/compression.h new file mode 100644 index 0000000..4e4627d --- /dev/null +++ b/include/pernix/arm64/sve2/compression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE2_COMPRESSION_H +#define PERNIX_ARM64_SVE2_COMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve2_compression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_block(const float_t*, float_t, uint8_t*) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_block(const double_t*, double_t, uint8_t*) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve2_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); +int sve2_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); +int sve2_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, + uint32_t blocks); +int sve2_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE2_COMPRESSION_H diff --git a/include/pernix/arm64/sve2/decompression.h b/include/pernix/arm64/sve2/decompression.h new file mode 100644 index 0000000..c27f08a --- /dev/null +++ b/include/pernix/arm64/sve2/decompression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE2_DECOMPRESSION_H +#define PERNIX_ARM64_SVE2_DECOMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve2_decompression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_block(const uint8_t*, float_t, float_t*) { + static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_block(const uint8_t*, double_t, double_t*) { + static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { + static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { + static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve2_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); +int sve2_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); +int sve2_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); +int sve2_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE2_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve2/packing.h b/include/pernix/arm64/sve2/packing.h new file mode 100644 index 0000000..789b4d7 --- /dev/null +++ b/include/pernix/arm64/sve2/packing.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE2_PACKING_H +#define PERNIX_ARM64_SVE2_PACKING_H + +#include + +namespace pernix::arm64::sve2::internal { +template +inline constexpr bool packing_unimplemented_v = false; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_PACKING_H diff --git a/include/pernix/arm64/sve2/unpacking.h b/include/pernix/arm64/sve2/unpacking.h new file mode 100644 index 0000000..d654b5e --- /dev/null +++ b/include/pernix/arm64/sve2/unpacking.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE2_UNPACKING_H +#define PERNIX_ARM64_SVE2_UNPACKING_H + +#include + +namespace pernix::arm64::sve2::internal { +template +inline constexpr bool unpacking_unimplemented_v = false; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/include/pernix/detection.h b/include/pernix/detection.h index edecb6c..fa9cd44 100644 --- a/include/pernix/detection.h +++ b/include/pernix/detection.h @@ -10,6 +10,19 @@ #define PERNIX_MACHINE_ID_V4 3 #define PERNIX_MACHINE_ID_V4_VBMI 4 +#if defined(PERNIX_BACKEND_ARM64_NEON) +#define PERNIX_ARM64_NEON_ENABLED +#endif + +#if defined(PERNIX_BACKEND_ARM64_SVE) +#define PERNIX_ARM64_SVE_ENABLED +#endif + +#if defined(PERNIX_BACKEND_ARM64_SVE2) +#define PERNIX_ARM64_SVE2_ENABLED +#endif + +#if defined(PERNIX_BACKEND_X86) // Map the compiler's enabled ISA set to the highest supported Pernix target level. #if (__SSE3__ && __SSE4_1__ && __SSE4_2__) #if (__AVX__ && __AVX2__ && __FMA__ && __BMI__ && __BMI2__) @@ -32,6 +45,10 @@ #define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC #endif +#else +#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC +#endif + // Feature-selection macros consumed by the public headers. #if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V2) #define PERNIX_SSE_ENABLED @@ -47,7 +64,7 @@ #define PERNIX_AVX512_VBMI_ENABLED #endif -#ifdef PERNIX_USE_SIMDE +#if defined(PERNIX_USE_SIMDE) && defined(PERNIX_BACKEND_X86) #define PERNIX_SSE_ENABLED #define PERNIX_AVX2_ENABLED #define PERNIX_BMI2_ENABLED diff --git a/include/pernix/pernix.h b/include/pernix/pernix.h index 55133bb..4998a60 100644 --- a/include/pernix/pernix.h +++ b/include/pernix/pernix.h @@ -5,7 +5,7 @@ // Include architecture-specific headers based on detected capabilities // AVX2 -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #include #include @@ -21,7 +21,22 @@ #include #endif // PERNIX_AVX512_VBMI_ENABLED -#endif // PERNIX_AVX2_ENABLED +#endif // PERNIX_BACKEND_X86 && PERNIX_AVX2_ENABLED + +#ifdef PERNIX_BACKEND_ARM64_NEON +#include +#include +#endif + +#ifdef PERNIX_BACKEND_ARM64_SVE +#include +#include +#endif + +#ifdef PERNIX_BACKEND_ARM64_SVE2 +#include +#include +#endif // Fallback (non-SIMD) implementations #include @@ -167,7 +182,7 @@ template int decompress_blocks(const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, uint32_t blocks); // Use the best available implementation based on detected CPU features at compile time. -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #ifdef PERNIX_AVX512_VBMI_ENABLED template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) @@ -265,6 +280,150 @@ int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, d return mm256_decompress_blocks_avx2(input, scale, output, blocks); } #endif +#elif defined(PERNIX_BACKEND_ARM64_NEON) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return neon_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return neon_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return neon_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return neon_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return neon_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return neon_decompress_blocks(input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return sve_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return sve_decompress_blocks(input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE2) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve2_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve2_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve2_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve2_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return sve2_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return sve2_decompress_blocks(input, scale, output, blocks); +} #else template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) @@ -420,4 +579,4 @@ int decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, } // namespace pernix #endif -#endif // PERNIX_H \ No newline at end of file +#endif // PERNIX_H diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index c95c1ee..8a66ee7 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -8,9 +8,15 @@ #define SIMDE_ENABLE_NATIVE_ALIASES #undef SIMDE_X86_AVX512FP16_NATIVE // #define SIMDE_NO_NATIVE +#if defined(PERNIX_BACKEND_X86) #include #include #include +#elif defined(PERNIX_BACKEND_ARM64_NEON) +#include +#elif defined(PERNIX_BACKEND_ARM64_SVE) || defined(PERNIX_BACKEND_ARM64_SVE2) +#include +#endif // #ifndef __mmask8 // typedef uint8_t __mmask8; @@ -27,6 +33,8 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #include +#elif defined(__aarch64__) +#include #endif #ifndef __always_inline diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b7e8f23..f0c7e3e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,18 +9,46 @@ file(GLOB_RECURSE set(PERNIX_SOURCES ${PERNIX_COMMON_SOURCES}) -set(PERNIX_TARGET_IS_X86 OFF) -if (PERNIX_USE_SIMDE OR CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") - set(PERNIX_TARGET_IS_X86 ON) +set(PERNIX_SELECTED_ARCH_BACKEND "${PERNIX_ARCH_BACKEND}") +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "AUTO") + if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") + set(PERNIX_SELECTED_ARCH_BACKEND "X86") + elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(PERNIX_SELECTED_ARCH_BACKEND "ARM64_NEON") + else () + set(PERNIX_SELECTED_ARCH_BACKEND "FALLBACK") + endif () endif () +message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") -if (PERNIX_TARGET_IS_X86) +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") file(GLOB_RECURSE PERNIX_X86_SOURCES ./x86/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/x86/*.h ) list(APPEND PERNIX_SOURCES ${PERNIX_X86_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") + file(GLOB_RECURSE + PERNIX_ARM64_NEON_SOURCES + ./arm64/neon/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/neon/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_NEON_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") + file(GLOB_RECURSE + PERNIX_ARM64_SVE_SOURCES + ./arm64/sve/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") + file(GLOB_RECURSE + PERNIX_ARM64_SVE2_SOURCES + ./arm64/sve2/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve2/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE2_SOURCES}) endif () add_library(pernix SHARED ${PERNIX_SOURCES}) @@ -37,6 +65,18 @@ if (PERNIX_USE_SIMDE) target_compile_definitions(pernix PUBLIC PERNIX_USE_SIMDE=1) endif () +target_compile_definitions(pernix PUBLIC "PERNIX_BACKEND_${PERNIX_SELECTED_ARCH_BACKEND}=1") + +if (PERNIX_DISABLE_BMI2) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_BMI2=1) +endif () +if (PERNIX_DISABLE_AVX2) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_AVX2=1) +endif () +if (PERNIX_DISABLE_AVX512) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_AVX512=1) +endif () + set_target_properties(pernix PROPERTIES LINKER_LANGUAGE CXX) configure_file( diff --git a/src/arm64/neon/compression.cpp b/src/arm64/neon/compression.cpp new file mode 100644 index 0000000..5968f79 --- /dev/null +++ b/src/arm64/neon/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int neon_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int neon_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int neon_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int neon_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/neon/decompression.cpp b/src/arm64/neon/decompression.cpp new file mode 100644 index 0000000..3cb7fc7 --- /dev/null +++ b/src/arm64/neon/decompression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int neon_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { + return -1; +} + +int neon_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { + return -1; +} + +int neon_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { + return -1; +} + +int neon_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve/compression.cpp b/src/arm64/sve/compression.cpp new file mode 100644 index 0000000..e973183 --- /dev/null +++ b/src/arm64/sve/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int sve_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int sve_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int sve_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve/decompression.cpp b/src/arm64/sve/decompression.cpp new file mode 100644 index 0000000..c6d84be --- /dev/null +++ b/src/arm64/sve/decompression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { + return -1; +} + +int sve_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { + return -1; +} + +int sve_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { + return -1; +} + +int sve_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve2/compression.cpp b/src/arm64/sve2/compression.cpp new file mode 100644 index 0000000..0a55f16 --- /dev/null +++ b/src/arm64/sve2/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve2_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int sve2_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int sve2_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int sve2_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp new file mode 100644 index 0000000..8d170f6 --- /dev/null +++ b/src/arm64/sve2/decompression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve2_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { + return -1; +} + +int sve2_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { + return -1; +} + +int sve2_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { + return -1; +} + +int sve2_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/pernix.cpp b/src/pernix.cpp index 87ccf9d..94d9d14 100644 --- a/src/pernix.cpp +++ b/src/pernix.cpp @@ -6,7 +6,7 @@ extern "C" { #endif // Use the best available implementation based on detected CPU features at compile time -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #ifdef PERNIX_AVX512_VBMI_ENABLED int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { return mm512_compress_block_avx512vbmi(bit_width, input, scale, output); @@ -80,6 +80,114 @@ int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ i return mm256_decompress_blocks_f64_avx2(bit_width, input, scale, output, blocks); } #endif +#elif defined(PERNIX_BACKEND_ARM64_NEON) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return neon_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return neon_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return neon_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return neon_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return neon_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return neon_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return neon_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return sve_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return sve_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE2) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve2_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve2_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve2_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve2_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return sve2_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return sve2_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} #else int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { return compress_block_fallback(bit_width, input, scale, output); @@ -121,4 +229,4 @@ int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ i #ifdef __cplusplus } } // namespace pernix -#endif // __cplusplus \ No newline at end of file +#endif // __cplusplus diff --git a/tests/arm64/neon/.gitkeep b/tests/arm64/neon/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/arm64/sve/.gitkeep b/tests/arm64/sve/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/arm64/sve2/.gitkeep b/tests/arm64/sve2/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/compression/fallback_compression_tests.cpp b/tests/fallback/compression_tests.cpp similarity index 97% rename from tests/compression/fallback_compression_tests.cpp rename to tests/fallback/compression_tests.cpp index 9b50109..78249d7 100644 --- a/tests/compression/fallback_compression_tests.cpp +++ b/tests/fallback/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include TYPED_TEST(CompressionTest, FallbackCompressBlock) { diff --git a/tests/decompression/fallback_decompression_tests.cpp b/tests/fallback/decompression_tests.cpp similarity index 97% rename from tests/decompression/fallback_decompression_tests.cpp rename to tests/fallback/decompression_tests.cpp index 3c5dc1f..08c26c4 100644 --- a/tests/decompression/fallback_decompression_tests.cpp +++ b/tests/fallback/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { diff --git a/tests/fallback_edge_tests.cpp b/tests/fallback/edge_tests.cpp similarity index 100% rename from tests/fallback_edge_tests.cpp rename to tests/fallback/edge_tests.cpp diff --git a/tests/include/testset.h b/tests/include/testset.h index 6535957..5ef4fa1 100644 --- a/tests/include/testset.h +++ b/tests/include/testset.h @@ -1,7 +1,7 @@ #ifndef PERNIX_TESTSET_H #define PERNIX_TESTSET_H -#include <../../include/pernix/pernix.h> +#include #include #include diff --git a/tests/compression/avx2_compression_tests.cpp b/tests/x86/avx2/compression_tests.cpp similarity index 97% rename from tests/compression/avx2_compression_tests.cpp rename to tests/x86/avx2/compression_tests.cpp index 1c2892b..bd7f683 100644 --- a/tests/compression/avx2_compression_tests.cpp +++ b/tests/x86/avx2/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX2_ENABLED diff --git a/tests/decompression/avx2_decompression_tests.cpp b/tests/x86/avx2/decompression_tests.cpp similarity index 97% rename from tests/decompression/avx2_decompression_tests.cpp rename to tests/x86/avx2/decompression_tests.cpp index e0f039f..a6fc2c5 100644 --- a/tests/decompression/avx2_decompression_tests.cpp +++ b/tests/x86/avx2/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX2_ENABLED diff --git a/tests/compression/avx512vbmi_compression_tests.cpp b/tests/x86/avx512vbmi/compression_tests.cpp similarity index 100% rename from tests/compression/avx512vbmi_compression_tests.cpp rename to tests/x86/avx512vbmi/compression_tests.cpp diff --git a/tests/decompression/avx512vbmi_decompression_tests.cpp b/tests/x86/avx512vbmi/decompression_tests.cpp similarity index 97% rename from tests/decompression/avx512vbmi_decompression_tests.cpp rename to tests/x86/avx512vbmi/decompression_tests.cpp index 446443a..f44dd8d 100644 --- a/tests/decompression/avx512vbmi_decompression_tests.cpp +++ b/tests/x86/avx512vbmi/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX512_VBMI_ENABLED diff --git a/tests/compression/bmi2_compression_tests.cpp b/tests/x86/bmi2/compression_tests.cpp similarity index 97% rename from tests/compression/bmi2_compression_tests.cpp rename to tests/x86/bmi2/compression_tests.cpp index 85d3cac..b7fc2fd 100644 --- a/tests/compression/bmi2_compression_tests.cpp +++ b/tests/x86/bmi2/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_BMI2_ENABLED diff --git a/tests/decompression/bmi2_decompression_tests.cpp b/tests/x86/bmi2/decompression_tests.cpp similarity index 97% rename from tests/decompression/bmi2_decompression_tests.cpp rename to tests/x86/bmi2/decompression_tests.cpp index 11a8efb..dd7efc1 100644 --- a/tests/decompression/bmi2_decompression_tests.cpp +++ b/tests/x86/bmi2/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_BMI2_ENABLED From 22f156223ebf5cbaee7d974e3518cc576bdaf958 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 17:09:03 +0200 Subject: [PATCH 02/14] Refactor AVX512 compression and decompression headers: improve include organization, address minor formatting inconsistencies, and move utility functions to a shared `utils` file for better maintainability. --- include/pernix/simd_compat.h | 23 +-------- include/pernix/x86/avx512vbmi/compression.h | 51 ++++++++++--------- include/pernix/x86/avx512vbmi/decompression.h | 49 +++++++++--------- include/pernix/x86/utils.h | 16 ++++++ 4 files changed, 70 insertions(+), 69 deletions(-) create mode 100644 include/pernix/x86/utils.h diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index 8a66ee7..e96ec30 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -10,8 +10,8 @@ // #define SIMDE_NO_NATIVE #if defined(PERNIX_BACKEND_X86) #include -#include #include +#include #elif defined(PERNIX_BACKEND_ARM64_NEON) #include #elif defined(PERNIX_BACKEND_ARM64_SVE) || defined(PERNIX_BACKEND_ARM64_SVE2) @@ -47,25 +47,4 @@ #endif #endif -template - requires(std::is_integral_v && sizeof(T) <= 8) -static constexpr T tail_mask(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - if (tail_bytes == 0u) { - return static_cast(0); - } - if (tail_bytes >= 64u) { - return static_cast(~uint64_t{0}); - } - const uint64_t mask = (uint64_t{1} << tail_bytes) - 1u; - return static_cast(mask); -} - -static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - return tail_bytes; -} - #endif // PERNIX_SIMD_COMPAT_H diff --git a/include/pernix/x86/avx512vbmi/compression.h b/include/pernix/x86/avx512vbmi/compression.h index bc5a375..e1cb4f0 100644 --- a/include/pernix/x86/avx512vbmi/compression.h +++ b/include/pernix/x86/avx512vbmi/compression.h @@ -1,13 +1,16 @@ #ifndef PERNIX_AVX512VBMI_COMPRESSION_H #define PERNIX_AVX512VBMI_COMPRESSION_H +#include #include -#include #include -#include +#include +#include #include +using namespace pernix::x86::internal; + namespace pernix { namespace internal { template @@ -132,7 +135,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - input += 64; + input += 64; output += 8 * BIT_WIDTH; } } @@ -150,7 +153,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri const __m256i packed = m256::mm256_pack_epi8_avx512vbmi_1to8(make_m256i_from_2x128(converted1, converted2)); mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } @@ -162,7 +165,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(converted); mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -207,7 +210,7 @@ template const __m512i packed = m512::mm512_pack_epi16_avx512vbmi_9to16(make_m512i_from_2x256(converted1, converted2)); mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } } @@ -220,7 +223,7 @@ template const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16(converted); mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -232,7 +235,7 @@ template const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -269,7 +272,7 @@ template const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(packed_input); mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } } @@ -281,14 +284,14 @@ template const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } if constexpr (remaining_elements > 0) { const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); const __m256i packed_input = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); } @@ -344,7 +347,7 @@ template mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - input += 64; + input += 64; output += 8 * BIT_WIDTH; } } @@ -370,7 +373,7 @@ template mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } @@ -387,7 +390,7 @@ template mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -447,7 +450,7 @@ template mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } } @@ -465,7 +468,7 @@ template mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -477,7 +480,7 @@ template const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -517,7 +520,7 @@ template const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(make_m512i_from_2x256(quantized1, quantized2)); mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } } @@ -529,7 +532,7 @@ template const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(quantized); mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -543,7 +546,7 @@ template return 0; } -} // namespace internal +} // namespace internal /** * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. @@ -618,7 +621,7 @@ int mm512_compress_blocks_avx512vbmi(const float_t* __restrict__ input, const fl for (uint32_t block = 0; block < blocks; ++block) { mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; block_output += BLOCK_SIZE; } @@ -644,13 +647,13 @@ int mm512_compress_blocks_avx512vbmi(const double_t* __restrict__ input, const d for (uint32_t block = 0; block < blocks; ++block) { mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; block_output += BLOCK_SIZE; } return 0; } -} // namespace pernix +} // namespace pernix #ifdef __cplusplus namespace pernix { @@ -716,7 +719,7 @@ int mm512_compress_blocks_f64_avx512vbmi(uint8_t bit_width, const double_t* __re #ifdef __cplusplus } -} // namespace pernix +} // namespace pernix #endif #endif // PERNIX_AVX512VBMI_COMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/decompression.h b/include/pernix/x86/avx512vbmi/decompression.h index 61280c9..6320240 100644 --- a/include/pernix/x86/avx512vbmi/decompression.h +++ b/include/pernix/x86/avx512vbmi/decompression.h @@ -1,13 +1,16 @@ #ifndef PERNIX_AVX512VBMI_DECOMPRESSION_H #define PERNIX_AVX512VBMI_DECOMPRESSION_H +#include #include -#include #include -#include +#include +#include #include +using namespace pernix::x86::internal; + namespace pernix { namespace internal { /** @@ -58,7 +61,7 @@ template _mm512_storeu_ps(output + 48, dequantized4); output += 64; - input += 8 * BIT_WIDTH; + input += 8 * BIT_WIDTH; } } @@ -76,7 +79,7 @@ template _mm512_storeu_ps(output + 16, dequantized2); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } if constexpr (iterations_16 > 0) { @@ -90,7 +93,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -159,7 +162,7 @@ template _mm512_storeu_pd(output + 56, dequantized8); output += 64; - input += 8 * BIT_WIDTH; + input += 8 * BIT_WIDTH; } if constexpr (iterations_32 > 0) { @@ -185,7 +188,7 @@ template _mm512_storeu_pd(output + 24, dequantized4); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } if constexpr (iterations_16 > 0) { @@ -202,7 +205,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -255,7 +258,7 @@ template _mm512_storeu_ps(output + 16, dequantized2); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } } @@ -269,7 +272,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (iterations_8 > 0) { @@ -282,7 +285,7 @@ template _mm256_storeu_ps(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -333,7 +336,7 @@ template _mm512_storeu_pd(output + 24, dequantized4); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } } @@ -351,7 +354,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (iterations_8 > 0) { @@ -365,7 +368,7 @@ template _mm512_storeu_pd(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -406,7 +409,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } } @@ -419,7 +422,7 @@ template _mm256_storeu_ps(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -462,7 +465,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } } @@ -477,7 +480,7 @@ template _mm512_storeu_pd(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -493,7 +496,7 @@ template return 0; } -} // namespace internal +} // namespace internal /** * @brief Decompress a single 512\-bit block using AVX-512 and AVX-512-VBMI instructions. @@ -569,7 +572,7 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const for (uint32_t block = 0; block < blocks; ++block) { mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } @@ -596,12 +599,12 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const for (uint32_t block = 0; block < blocks; ++block) { mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; } -} // namespace pernix +} // namespace pernix #ifdef __cplusplus namespace pernix { @@ -666,7 +669,7 @@ int mm512_decompress_blocks_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __r #ifdef __cplusplus } -} // namespace pernix +} // namespace pernix #endif #endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/include/pernix/x86/utils.h b/include/pernix/x86/utils.h new file mode 100644 index 0000000..185e0a2 --- /dev/null +++ b/include/pernix/x86/utils.h @@ -0,0 +1,16 @@ +#ifndef PERNIX_X86_UTILS_H +#define PERNIX_X86_UTILS_H + +#include + +namespace pernix::x86::internal { + +static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { + const uint32_t tail_bits = remaining_elements * bit_width; + const uint32_t tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; +} + +} // namespace pernix::x86::internal + +#endif // PERNIX_X86_UTILS_H From 37b3a725576985aa6db47107cddb7c539fc67c13 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 17:26:44 +0200 Subject: [PATCH 03/14] Replace SIMDe submodule with FetchContent for streamlined dependency management. --- .gitmodules | 3 --- CMakeLists.txt | 16 +++++++++++++++- external/simde | 1 - 3 files changed, 15 insertions(+), 5 deletions(-) delete mode 100644 .gitmodules delete mode 160000 external/simde diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index ae538ec..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "external/simde"] - path = external/simde - url = https://github.com/simd-everywhere/simde diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fa49be..8394abf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,21 @@ endif () list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") if (PERNIX_USE_SIMDE) - add_subdirectory(external/simde EXCLUDE_FROM_ALL) + include(FetchContent) + FetchContent_Declare( + simde + GIT_REPOSITORY https://github.com/simd-everywhere/simde.git + GIT_TAG master + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(simde) + + if (NOT TARGET simde::simde AND DEFINED simde_SOURCE_DIR AND EXISTS "${simde_SOURCE_DIR}/simde") + add_library(simde::simde INTERFACE IMPORTED GLOBAL) + target_include_directories(simde::simde INTERFACE "${simde_SOURCE_DIR}") + endif () endif () include(CTest) diff --git a/external/simde b/external/simde deleted file mode 160000 index 1a1ca5e..0000000 --- a/external/simde +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1a1ca5ee71518d8a115234dad1e2d871421953b7 From 11779ea1ff10cd012eada302fb5f70041bf17039 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 19:07:56 +0200 Subject: [PATCH 04/14] Integrate `CONFIGURE_DEPENDS` in source file globbing, improve CMake configuration with target aliases, install rules, and LTO support, enhance SIMDe handling with flexible provider selection and bundling, and refine compiler flag management for better compatibility. --- CMakeLists.txt | 90 ++++++++++++++++++++++++------------ include/pernix/simd_compat.h | 3 ++ src/CMakeLists.txt | 50 ++++++++++++++++++-- 3 files changed, 109 insertions(+), 34 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8394abf..b8c8208 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,8 @@ option(PERNIX_DISABLE_AVX2 "Disable AVX2 optimizations" off) option(PERNIX_DISABLE_AVX512 "Disable AVX512 optimizations" off) option(PERNIX_USE_SIMDE "Use SIMDe library for portable SIMD support" off) +set(PERNIX_SIMDE_PROVIDER "AUTO" CACHE STRING "SIMDe provider when PERNIX_USE_SIMDE is enabled (AUTO, PACKAGE, FETCH)") +set_property(CACHE PERNIX_SIMDE_PROVIDER PROPERTY STRINGS AUTO PACKAGE FETCH) set(PERNIX_ARCH_BACKEND "AUTO" CACHE STRING "Pernix architecture backend (AUTO, FALLBACK, X86, ARM64_NEON, ARM64_SVE, ARM64_SVE2)") set_property(CACHE PERNIX_ARCH_BACKEND PROPERTY STRINGS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) @@ -24,24 +26,42 @@ if (NOT PERNIX_ARCH_BACKEND IN_LIST PERNIX_VALID_ARCH_BACKENDS) message(FATAL_ERROR "Unsupported PERNIX_ARCH_BACKEND='${PERNIX_ARCH_BACKEND}'. Expected one of: ${PERNIX_VALID_ARCH_BACKENDS}") endif () +string(TOUPPER "${PERNIX_SIMDE_PROVIDER}" PERNIX_SIMDE_PROVIDER) +set(PERNIX_VALID_SIMDE_PROVIDERS AUTO PACKAGE FETCH) +if (NOT PERNIX_SIMDE_PROVIDER IN_LIST PERNIX_VALID_SIMDE_PROVIDERS) + message(FATAL_ERROR "Unsupported PERNIX_SIMDE_PROVIDER='${PERNIX_SIMDE_PROVIDER}'. Expected one of: ${PERNIX_VALID_SIMDE_PROVIDERS}") +endif () + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") +set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL OFF) if (PERNIX_USE_SIMDE) - include(FetchContent) - FetchContent_Declare( - simde - GIT_REPOSITORY https://github.com/simd-everywhere/simde.git - GIT_TAG master - GIT_SHALLOW TRUE - GIT_PROGRESS TRUE - EXCLUDE_FROM_ALL - ) - FetchContent_MakeAvailable(simde) + if (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "PACKAGE") + find_package(simde CONFIG QUIET) + endif () + + if (NOT TARGET simde::simde AND (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "FETCH")) + include(FetchContent) + set(SIMDE_TEST_CMAKE_PACKAGING OFF CACHE BOOL "Test SIMDe CMake packaging" FORCE) + FetchContent_Declare( + simde + GIT_REPOSITORY https://github.com/simd-everywhere/simde.git + GIT_TAG f3e8262173b7089db9a9d57a9ecef8dd07ad9c97 + GIT_PROGRESS TRUE + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(simde) + set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL ON) + endif () if (NOT TARGET simde::simde AND DEFINED simde_SOURCE_DIR AND EXISTS "${simde_SOURCE_DIR}/simde") add_library(simde::simde INTERFACE IMPORTED GLOBAL) target_include_directories(simde::simde INTERFACE "${simde_SOURCE_DIR}") endif () + + if (NOT TARGET simde::simde) + message(FATAL_ERROR "PERNIX_USE_SIMDE is enabled, but simde::simde was not found. Set PERNIX_SIMDE_PROVIDER=FETCH or install SIMDe's CMake package.") + endif () endif () include(CTest) @@ -62,28 +82,42 @@ else () endif () message(STATUS "Pernix version: ${VERSION}, normalized to ${NORMALIZED_VERSION}") -set(BENCHMARK_CXX_STANDARD 20) - -set(CMAKE_CXX_STANDARD ${BENCHMARK_CXX_STANDARD}) -set(CMAKE_CXX_STANDARD_REQUIRED YES) -set(CMAKE_CXX_EXTENSIONS OFF) - -include(AddCXXCompilerFlag) if (MSVC) message(FATAL_ERROR "MSVC compiler is not supported") else () - add_cxx_compiler_flag(-Wall) - add_cxx_compiler_flag(-Wextra) - add_cxx_compiler_flag(-Wshadow) - add_cxx_compiler_flag(-Wfloat-equal) - add_cxx_compiler_flag(-Wold-style-cast) - add_cxx_compiler_flag(-Wconversion) - add_cxx_compiler_flag(-fstrict-aliasing) - add_cxx_compiler_flag(-Wno-ignored-attributes) + include(CheckCXXCompilerFlag) + set(PERNIX_PRIVATE_COMPILE_OPTIONS) + foreach (PERNIX_CXX_FLAG + -Wall + -Wextra + -Wshadow + -Wfloat-equal + -Wold-style-cast + -Wconversion + -fstrict-aliasing + -Wno-ignored-attributes + ) + string(MAKE_C_IDENTIFIER "PERNIX_HAS_CXX_FLAG_${PERNIX_CXX_FLAG}" PERNIX_CXX_FLAG_VARIABLE) + check_cxx_compiler_flag("${PERNIX_CXX_FLAG}" "${PERNIX_CXX_FLAG_VARIABLE}") + if (${PERNIX_CXX_FLAG_VARIABLE}) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "${PERNIX_CXX_FLAG}") + else () + message(STATUS "Compiler flag not supported: ${PERNIX_CXX_FLAG}") + endif () + endforeach () if (PERNIX_ENABLE_LTO) - add_cxx_compiler_flag(-flto=auto) - add_cxx_compiler_flag(-Wno-lto-type-mismatch) + include(CheckIPOSupported) + check_ipo_supported(RESULT PERNIX_IPO_SUPPORTED OUTPUT PERNIX_IPO_ERROR) + if (NOT PERNIX_IPO_SUPPORTED) + message(FATAL_ERROR "PERNIX_ENABLE_LTO is enabled, but IPO/LTO is not supported: ${PERNIX_IPO_ERROR}") + endif () + + check_cxx_compiler_flag("-Wno-lto-type-mismatch" PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + if (PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "-Wno-lto-type-mismatch") + endif () + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") find_program(GCC_AR gcc-ar) if (GCC_AR) @@ -106,8 +140,6 @@ else () endif () endif () -include_directories(${PROJECT_SOURCE_DIR}/include) - add_subdirectory(src) if (PERNIX_ENABLE_FORTRAN_BINDINGS) diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index e96ec30..f96a84f 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -7,6 +7,9 @@ #if defined(PERNIX_USE_SIMDE) #define SIMDE_ENABLE_NATIVE_ALIASES #undef SIMDE_X86_AVX512FP16_NATIVE +#if defined(__clang__) +#define SIMDE_X86_AVX512BF16_NATIVE +#endif // #define SIMDE_NO_NATIVE #if defined(PERNIX_BACKEND_X86) #include diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f0c7e3e..8c79ab6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,7 +1,9 @@ include(GNUInstallDirs) +include(CMakePackageConfigHelpers) file(GLOB_RECURSE PERNIX_COMMON_SOURCES + CONFIGURE_DEPENDS ./fallback/*.cpp ./pernix.cpp ${PROJECT_SOURCE_DIR}/include/pernix/*.h @@ -24,6 +26,7 @@ message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") file(GLOB_RECURSE PERNIX_X86_SOURCES + CONFIGURE_DEPENDS ./x86/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/x86/*.h ) @@ -31,6 +34,7 @@ if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") file(GLOB_RECURSE PERNIX_ARM64_NEON_SOURCES + CONFIGURE_DEPENDS ./arm64/neon/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/arm64/neon/*.h ) @@ -38,6 +42,7 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") file(GLOB_RECURSE PERNIX_ARM64_SVE_SOURCES + CONFIGURE_DEPENDS ./arm64/sve/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve/*.h ) @@ -45,6 +50,7 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") file(GLOB_RECURSE PERNIX_ARM64_SVE2_SOURCES + CONFIGURE_DEPENDS ./arm64/sve2/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve2/*.h ) @@ -52,13 +58,20 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") endif () add_library(pernix SHARED ${PERNIX_SOURCES}) +add_library(pernix::pernix ALIAS pernix) set_target_properties(pernix PROPERTIES OUTPUT_NAME "pernix" VERSION ${NORMALIZED_VERSION} ) +target_compile_features(pernix PUBLIC cxx_std_20) +target_compile_options(pernix PRIVATE ${PERNIX_PRIVATE_COMPILE_OPTIONS}) target_include_directories(pernix PUBLIC $ + $ ) +if (PERNIX_ENABLE_LTO) + set_target_properties(pernix PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE) +endif () if (PERNIX_USE_SIMDE) target_link_libraries(pernix PUBLIC simde::simde) @@ -85,23 +98,50 @@ configure_file( if (PERNIX_ENABLE_INSTALL) install(TARGETS pernix + EXPORT pernixTargets LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) - install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/pernix ${PROJECT_BINARY_DIR}/include/pernix + install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/pernix" DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} FILES_MATCHING PATTERN "*.*h" ) + if (PERNIX_USE_SIMDE AND PERNIX_BUNDLE_SIMDE_FOR_INSTALL) + install(DIRECTORY "${simde_SOURCE_DIR}/simde" + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + FILES_MATCHING PATTERN "*.h" + ) + endif () + + configure_package_config_file( + "${PROJECT_SOURCE_DIR}/cmake/pernixConfig.cmake.in" + "${PROJECT_BINARY_DIR}/pernixConfig.cmake" + INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) + write_basic_package_version_file( + "${PROJECT_BINARY_DIR}/pernixConfigVersion.cmake" + VERSION ${NORMALIZED_VERSION} + COMPATIBILITY SameMajorVersion + ) + install( + FILES + "${PROJECT_BINARY_DIR}/pernixConfig.cmake" + "${PROJECT_BINARY_DIR}/pernixConfigVersion.cmake" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) + install( + EXPORT pernixTargets + FILE pernixTargets.cmake + NAMESPACE pernix:: + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) install( FILES ${PROJECT_BINARY_DIR}/pernix.pc DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig" ) - - add_custom_target(uninstall COMMAND xargs rm -vf < ${PROJECT_BINARY_DIR}/install_manifest.txt) endif () if (PERNIX_ENABLE_DOXYGEN) @@ -127,7 +167,7 @@ if (PERNIX_ENABLE_DOXYGEN) WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} COMMENT "Generating documentation with Doxygen." ) - if (BENCHMARK_ENABLE_INSTALL AND BENCHMARK_INSTALL_DOCS) + if (PERNIX_ENABLE_INSTALL AND PERNIX_INSTALL_DOCS) install(DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/html/" DESTINATION ${CMAKE_INSTALL_DOCDIR}) endif () From 17807c8ca2a62158d86e2154214e05ed0acb5736 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 20:11:11 +0200 Subject: [PATCH 05/14] Refactor inline attributes in AVX512 compression and decompression headers for consistency and clarity --- include/pernix/x86/avx512vbmi/compression.h | 62 ++++++++-------- include/pernix/x86/avx512vbmi/decompression.h | 72 +++++++++---------- include/pernix/x86/avx512vbmi/packing.h | 18 ++--- include/pernix/x86/avx512vbmi/tables.h | 66 +++++++++-------- 4 files changed, 108 insertions(+), 110 deletions(-) diff --git a/include/pernix/x86/avx512vbmi/compression.h b/include/pernix/x86/avx512vbmi/compression.h index e1cb4f0..9621e12 100644 --- a/include/pernix/x86/avx512vbmi/compression.h +++ b/include/pernix/x86/avx512vbmi/compression.h @@ -135,7 +135,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - input += 64; + input += 64; output += 8 * BIT_WIDTH; } } @@ -153,7 +153,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri const __m256i packed = m256::mm256_pack_epi8_avx512vbmi_1to8(make_m256i_from_2x128(converted1, converted2)); mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } @@ -165,7 +165,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(converted); mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -183,8 +183,8 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -210,7 +210,7 @@ template const __m512i packed = m512::mm512_pack_epi16_avx512vbmi_9to16(make_m512i_from_2x256(converted1, converted2)); mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } } @@ -223,7 +223,7 @@ template const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16(converted); mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -235,7 +235,7 @@ template const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -253,8 +253,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -272,7 +272,7 @@ template const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(packed_input); mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } } @@ -284,7 +284,7 @@ template const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -301,8 +301,8 @@ template template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_64 = elements_per_block / 64; @@ -347,7 +347,7 @@ template mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - input += 64; + input += 64; output += 8 * BIT_WIDTH; } } @@ -373,7 +373,7 @@ template mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } @@ -390,7 +390,7 @@ template mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -416,8 +416,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -450,7 +450,7 @@ template mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } } @@ -468,7 +468,7 @@ template mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -480,7 +480,7 @@ template const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -498,8 +498,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -520,7 +520,7 @@ template const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(make_m512i_from_2x256(quantized1, quantized2)); mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } } @@ -532,7 +532,7 @@ template const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(quantized); mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -546,7 +546,7 @@ template return 0; } -} // namespace internal +} // namespace internal /** * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. @@ -621,7 +621,7 @@ int mm512_compress_blocks_avx512vbmi(const float_t* __restrict__ input, const fl for (uint32_t block = 0; block < blocks; ++block) { mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; block_output += BLOCK_SIZE; } @@ -647,13 +647,13 @@ int mm512_compress_blocks_avx512vbmi(const double_t* __restrict__ input, const d for (uint32_t block = 0; block < blocks; ++block) { mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; block_output += BLOCK_SIZE; } return 0; } -} // namespace pernix +} // namespace pernix #ifdef __cplusplus namespace pernix { @@ -719,7 +719,7 @@ int mm512_compress_blocks_f64_avx512vbmi(uint8_t bit_width, const double_t* __re #ifdef __cplusplus } -} // namespace pernix +} // namespace pernix #endif #endif // PERNIX_AVX512VBMI_COMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/decompression.h b/include/pernix/x86/avx512vbmi/decompression.h index 6320240..08abc35 100644 --- a/include/pernix/x86/avx512vbmi/decompression.h +++ b/include/pernix/x86/avx512vbmi/decompression.h @@ -16,20 +16,20 @@ namespace internal { /** * @brief Dequantize sixteen integer values to floats. */ -[[gnu::always_inline]] inline __m512 mm512_dequantize_epi32(const __m512i& input, const __m512& scale) { +__always_inline __m512 mm512_dequantize_epi32(const __m512i& input, const __m512& scale) { const __m512 converted = _mm512_cvtepi32_ps(input); return _mm512_mul_ps(converted, scale); } -[[gnu::always_inline]] inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m512d& scale) { +__always_inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m512d& scale) { const __m512d converted = _mm512_cvtepi64_pd(input); return _mm512_mul_pd(converted, scale); } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const uint32_t iterations_64 = elements_per_block / 64; @@ -61,7 +61,7 @@ template _mm512_storeu_ps(output + 48, dequantized4); output += 64; - input += 8 * BIT_WIDTH; + input += 8 * BIT_WIDTH; } } @@ -79,7 +79,7 @@ template _mm512_storeu_ps(output + 16, dequantized2); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } if constexpr (iterations_16 > 0) { @@ -93,7 +93,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -112,8 +112,8 @@ template template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const uint32_t iterations_64 = elements_per_block / 64; @@ -162,7 +162,7 @@ template _mm512_storeu_pd(output + 56, dequantized8); output += 64; - input += 8 * BIT_WIDTH; + input += 8 * BIT_WIDTH; } if constexpr (iterations_32 > 0) { @@ -188,7 +188,7 @@ template _mm512_storeu_pd(output + 24, dequantized4); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } if constexpr (iterations_16 > 0) { @@ -205,7 +205,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -230,8 +230,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -258,7 +258,7 @@ template _mm512_storeu_ps(output + 16, dequantized2); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } } @@ -272,7 +272,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (iterations_8 > 0) { @@ -285,7 +285,7 @@ template _mm256_storeu_ps(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -303,8 +303,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -336,7 +336,7 @@ template _mm512_storeu_pd(output + 24, dequantized4); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } } @@ -354,7 +354,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (iterations_8 > 0) { @@ -368,7 +368,7 @@ template _mm512_storeu_pd(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -387,8 +387,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -409,7 +409,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } } @@ -422,7 +422,7 @@ template _mm256_storeu_ps(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -439,8 +439,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -465,7 +465,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } } @@ -480,7 +480,7 @@ template _mm512_storeu_pd(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -496,7 +496,7 @@ template return 0; } -} // namespace internal +} // namespace internal /** * @brief Decompress a single 512\-bit block using AVX-512 and AVX-512-VBMI instructions. @@ -572,7 +572,7 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const for (uint32_t block = 0; block < blocks; ++block) { mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } @@ -599,15 +599,15 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const for (uint32_t block = 0; block < blocks; ++block) { mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; } -} // namespace pernix +} // namespace pernix -#ifdef __cplusplus namespace pernix { +#ifdef __cplusplus extern "C" { #endif /** @@ -669,7 +669,7 @@ int mm512_decompress_blocks_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __r #ifdef __cplusplus } -} // namespace pernix #endif +} // namespace pernix #endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/packing.h b/include/pernix/x86/avx512vbmi/packing.h index c9f9db9..ba3b132 100644 --- a/include/pernix/x86/avx512vbmi/packing.h +++ b/include/pernix/x86/avx512vbmi/packing.h @@ -11,7 +11,7 @@ namespace m128 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i& input) { +__always_inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -48,7 +48,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i& input) { +__always_inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -93,7 +93,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i& input) { +__always_inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i& input) { using tables = pack_tables_avx512_24; const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); @@ -117,7 +117,7 @@ namespace m256 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i& input) { +__always_inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -154,7 +154,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i& input) { +__always_inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -199,7 +199,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i& input) { +__always_inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i& input) { using tables = pack_tables_avx512_24; const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); @@ -223,7 +223,7 @@ namespace m512 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i& input) { +__always_inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -260,7 +260,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i& input) { +__always_inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -305,7 +305,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i& input) { +__always_inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i& input) { using tables = pack_tables_avx512_24; const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); diff --git a/include/pernix/x86/avx512vbmi/tables.h b/include/pernix/x86/avx512vbmi/tables.h index 9115625..4d66727 100644 --- a/include/pernix/x86/avx512vbmi/tables.h +++ b/include/pernix/x86/avx512vbmi/tables.h @@ -9,9 +9,8 @@ #include namespace pernix::internal { - template -[[gnu::always_inline]] static inline Vec load_table(const std::array& table) { +static __always_inline Vec load_table(const std::array& table) { static_assert(sizeof(table) >= sizeof(Vec), "table is smaller than requested SIMD vector"); if constexpr (std::is_same_v) { return _mm512_load_si512(static_cast(table.data())); @@ -529,13 +528,13 @@ struct pack_tables_avx512_16 { // clang-format on } - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_permute3() { return load_table(permute3); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } - [[gnu::always_inline]] static inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } }; template @@ -591,7 +590,7 @@ struct pack_tables_avx512_24 { return plan; } - static inline constexpr std::array word_plans = [] { + static constexpr std::array word_plans = [] { std::array plans{}; for (uint32_t i = 0; i < 16; ++i) { plans[i] = create_plan(i); @@ -600,7 +599,7 @@ struct pack_tables_avx512_24 { }(); template - [[gnu::always_inline]] static constexpr std::array make_table(Getter getter) { + static __always_inline constexpr std::array make_table(Getter getter) { std::array values{}; for (uint32_t i = 0; i < 16; ++i) { values[i] = getter(word_plans[i]); @@ -608,26 +607,26 @@ struct pack_tables_avx512_24 { return values; } - alignas(64) static inline constexpr auto permute1 = make_table([](const word_plan& p) { return p.left_index1; }); + alignas(64) static constexpr auto permute1 = make_table([](const word_plan& p) { return p.left_index1; }); - alignas(64) static inline constexpr auto permute2 = make_table([](const word_plan& p) { return p.left_index2; }); + alignas(64) static constexpr auto permute2 = make_table([](const word_plan& p) { return p.left_index2; }); - alignas(64) static inline constexpr auto permute3 = make_table([](const word_plan& p) { return p.right_index; }); + alignas(64) static constexpr auto permute3 = make_table([](const word_plan& p) { return p.right_index; }); - alignas(64) static inline constexpr auto shift1 = make_table([](const word_plan& p) { return p.left_shift1; }); + alignas(64) static constexpr auto shift1 = make_table([](const word_plan& p) { return p.left_shift1; }); - alignas(64) static inline constexpr auto shift2 = make_table([](const word_plan& p) { return p.left_shift2; }); + alignas(64) static constexpr auto shift2 = make_table([](const word_plan& p) { return p.left_shift2; }); - alignas(64) static inline constexpr auto shift3 = make_table([](const word_plan& p) { return p.right_shift; }); + alignas(64) static constexpr auto shift3 = make_table([](const word_plan& p) { return p.right_shift; }); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_permute3() { return load_table(permute3); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } - [[gnu::always_inline]] static inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } }; template @@ -693,11 +692,11 @@ struct unpack_tables_avx512_8 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } }; template @@ -768,11 +767,11 @@ struct unpack_tables_avx512_16 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } }; template @@ -813,10 +812,9 @@ struct unpack_tables_avx512_24 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute() { return load_table(permute); } - [[gnu::always_inline]] static inline Vec get_shift() { return load_table(shift); } + static __always_inline Vec get_permute() { return load_table(permute); } + static __always_inline Vec get_shift() { return load_table(shift); } }; - -} // namespace pernix::internal +} // namespace pernix::internal #endif // PERNIX_AVX512VBMI_TABLES_H From 552cea37204a02b36085becafc9da57be3a74b64 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 20:18:14 +0200 Subject: [PATCH 06/14] Refactor ARM64 NEON and SVE compression/decompression headers: organize includes, introduce namespaced utility functions, add templates for bit-width handling, and update function signatures for consistency. --- include/pernix/arm64/neon/compression.h | 124 ++++++++++++++++++---- include/pernix/arm64/neon/decompression.h | 120 +++++++++++++++++---- include/pernix/arm64/neon/packing.h | 2 - include/pernix/arm64/neon/unpacking.h | 10 -- include/pernix/arm64/sve/compression.h | 124 ++++++++++++++++++---- include/pernix/arm64/sve/decompression.h | 120 +++++++++++++++++---- include/pernix/arm64/sve/packing.h | 2 - include/pernix/arm64/sve/unpacking.h | 2 - 8 files changed, 408 insertions(+), 96 deletions(-) diff --git a/include/pernix/arm64/neon/compression.h b/include/pernix/arm64/neon/compression.h index 65ea786..6e49348 100644 --- a/include/pernix/arm64/neon/compression.h +++ b/include/pernix/arm64/neon/compression.h @@ -2,58 +2,140 @@ #define PERNIX_ARM64_NEON_COMPRESSION_H #include +#include #include #include -namespace pernix { +namespace pernix::arm64::neon { namespace internal { -template -inline constexpr bool neon_compression_unimplemented_v = false; +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} } // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_block(const float_t*, float_t, uint8_t*) { - static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); - return -1; +__always_inline int neon_compress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_block(const double_t*, double_t, uint8_t*) { - static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); - return -1; +__always_inline int neon_compress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { - static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); - return -1; +int neon_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { - static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); - return -1; +int neon_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } #ifdef __cplusplus extern "C" { #endif -int neon_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); -int neon_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); -int neon_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, +int neon_compress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + + +int neon_compress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + +int neon_compress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int neon_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); + +int neon_compress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_COMPRESSION_H diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index 0f8d79e..44744ce 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -2,42 +2,119 @@ #define PERNIX_ARM64_NEON_DECOMPRESSION_H #include +#include #include #include -namespace pernix { +namespace pernix::arm64::neon { namespace internal { -template -inline constexpr bool neon_decompression_unimplemented_v = false; +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} } // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_block(const uint8_t*, float_t, float_t*) { - static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); - return -1; +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_decompress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_block(const uint8_t*, double_t, double_t*) { - static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); - return -1; +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_decompress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { - static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); - return -1; +int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { - static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); - return -1; +int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } #ifdef __cplusplus @@ -45,15 +122,20 @@ extern "C" { #endif int neon_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); -int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); + + +int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + int neon_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); + +int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/neon/packing.h b/include/pernix/arm64/neon/packing.h index c1c8119..538b5a8 100644 --- a/include/pernix/arm64/neon/packing.h +++ b/include/pernix/arm64/neon/packing.h @@ -4,8 +4,6 @@ #include namespace pernix::arm64::neon::internal { -template -inline constexpr bool packing_unimplemented_v = false; } // namespace pernix::arm64::neon::internal #endif // PERNIX_ARM64_NEON_PACKING_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index 32dc5de..ea22b24 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -4,17 +4,7 @@ #include namespace pernix::arm64::neon::internal { - -template -inline constexpr bool unpacking_unimplemented_v = false; - -namespace b64 { - -} // namespace b64 - - } // namespace pernix::arm64::neon::internal - #endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve/compression.h b/include/pernix/arm64/sve/compression.h index f8abfed..cf83ce0 100644 --- a/include/pernix/arm64/sve/compression.h +++ b/include/pernix/arm64/sve/compression.h @@ -2,58 +2,140 @@ #define PERNIX_ARM64_SVE_COMPRESSION_H #include +#include #include #include -namespace pernix { +namespace pernix::arm64::sve { namespace internal { -template -inline constexpr bool sve_compression_unimplemented_v = false; +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} } // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_block(const float_t*, float_t, uint8_t*) { - static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); - return -1; +__always_inline int sve_compress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_compress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_block(const double_t*, double_t, uint8_t*) { - static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); - return -1; +__always_inline int sve_compress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_compress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { - static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); - return -1; +int sve_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { - static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); - return -1; +int sve_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } #ifdef __cplusplus extern "C" { #endif -int sve_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); -int sve_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); -int sve_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, +int sve_compress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + + +int sve_compress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + +int sve_compress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int sve_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); + +int sve_compress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::sve #endif // PERNIX_ARM64_SVE_COMPRESSION_H diff --git a/include/pernix/arm64/sve/decompression.h b/include/pernix/arm64/sve/decompression.h index 784c0c8..052a3e4 100644 --- a/include/pernix/arm64/sve/decompression.h +++ b/include/pernix/arm64/sve/decompression.h @@ -2,42 +2,119 @@ #define PERNIX_ARM64_SVE_DECOMPRESSION_H #include +#include #include #include -namespace pernix { +namespace pernix::arm64::sve { namespace internal { -template -inline constexpr bool sve_decompression_unimplemented_v = false; +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} } // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_block(const uint8_t*, float_t, float_t*) { - static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); - return -1; +__always_inline int sve_decompress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_decompress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_block(const uint8_t*, double_t, double_t*) { - static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); - return -1; +__always_inline int sve_decompress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_decompress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { - static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); - return -1; +int sve_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { - static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); - return -1; +int sve_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } #ifdef __cplusplus @@ -45,15 +122,20 @@ extern "C" { #endif int sve_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); -int sve_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); + + +int sve_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + int sve_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int sve_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); + +int sve_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::sve #endif // PERNIX_ARM64_SVE_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve/packing.h b/include/pernix/arm64/sve/packing.h index fce21ca..ab57b4f 100644 --- a/include/pernix/arm64/sve/packing.h +++ b/include/pernix/arm64/sve/packing.h @@ -4,8 +4,6 @@ #include namespace pernix::arm64::sve::internal { -template -inline constexpr bool packing_unimplemented_v = false; } // namespace pernix::arm64::sve::internal #endif // PERNIX_ARM64_SVE_PACKING_H diff --git a/include/pernix/arm64/sve/unpacking.h b/include/pernix/arm64/sve/unpacking.h index 3de49ca..2565ab7 100644 --- a/include/pernix/arm64/sve/unpacking.h +++ b/include/pernix/arm64/sve/unpacking.h @@ -4,8 +4,6 @@ #include namespace pernix::arm64::sve::internal { -template -inline constexpr bool unpacking_unimplemented_v = false; } // namespace pernix::arm64::sve::internal #endif // PERNIX_ARM64_SVE_UNPACKING_H From 23553012600b6866bdda71d4a7c551310d29ea10 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 21:49:41 +0200 Subject: [PATCH 07/14] WIP: Implement NEON decompression functions with common utilities and templates --- include/pernix/arm64/neon/common.h | 106 ++++++++++++++++++++++ include/pernix/arm64/neon/decompression.h | 106 +++++++++++++++++++--- include/pernix/arm64/neon/unpacking.h | 20 +++- include/pernix/simd_compat.h | 2 +- 4 files changed, 219 insertions(+), 15 deletions(-) create mode 100644 include/pernix/arm64/neon/common.h diff --git a/include/pernix/arm64/neon/common.h b/include/pernix/arm64/neon/common.h new file mode 100644 index 0000000..6efa777 --- /dev/null +++ b/include/pernix/arm64/neon/common.h @@ -0,0 +1,106 @@ +#ifndef PERNIX_ARM64_NEON_COMMON_H +#define PERNIX_ARM64_NEON_COMMON_H + +#include +#include + +namespace pernix::arm64::neon::internal { +static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { + const uint32_t tail_bits = remaining_elements * bit_width; + const uint32_t tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; +} + +__always_inline int32x4x4_t neon_convert_int8x16_int32x4x2_t(const int8x16_t& input) { + const int16x8_t s16_lo = vmovl_s8(vget_low_s8(input)); + const int16x8_t s16_hi = vmovl_s8(vget_high_s8(input)); + + return {{ + vmovl_s16(vget_low_s16(s16_lo)), + vmovl_s16(vget_high_s16(s16_lo)), + vmovl_s16(vget_low_s16(s16_hi)), + vmovl_s16(vget_high_s16(s16_hi)), + }}; +} + +__always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t& input, const float32x4_t& scale) { + return {{ + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[2]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[3]), scale), + }}; +} + +__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_elements) { + uint8_t buffer[16] = {0}; + std::memcpy(buffer, input, tail_elements * sizeof(uint8_t)); + return vld1q_u8(buffer); +} + +__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t* input, const uint32_t tail_elements) { + uint16_t buffer[8] = {0}; + std::memcpy(buffer, input, tail_elements * sizeof(uint16_t)); + return vld1q_u16(buffer); +} + +__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t* input, const uint32_t tail_elements) { + uint32_t buffer[4] = {0}; + std::memcpy(buffer, input, tail_elements * sizeof(uint32_t)); + return vld1q_u32(buffer); +} + +__always_inline float32x4_t neon_load_tail_elements_f32(const uint8_t* input, const uint32_t tail_elements) { + float32_t buffer[4] = {0.0f}; + std::memcpy(buffer, input, tail_elements * sizeof(float32_t)); + return vld1q_f32(buffer); +} + +__always_inline float64x2_t neon_load_tail_elements_f64(const uint8_t* input, const uint32_t tail_elements) { + float64_t buffer[2] = {0.0}; + std::memcpy(buffer, input, tail_elements * sizeof(float64_t)); + return vld1q_f64(buffer); +} + +__always_inline void neon_store_tail_elements_int8(uint8_t* output, const uint8x16x4_t& data, const uint32_t tail_elements) { + uint8_t buffer[16 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u8(buffer + i * 16, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint8_t)); +} + +__always_inline void neon_store_tail_elements_int16(uint16_t* output, const uint16x8x4_t& data, const uint32_t tail_elements) { + uint16_t buffer[8 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u16(buffer + i * 8, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint16_t)); +} + +__always_inline void neon_store_tail_elements_int32(uint32_t* output, const uint32x4x4_t& data, const uint32_t tail_elements) { + uint32_t buffer[4 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint32_t)); +} + +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4x4_t& data, const uint32_t tail_elements) { + float32_t buffer[16 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x4_t& data, const uint32_t tail_elements) { + float64_t buffer[8 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} +} // namespace pernix::arm64::neon::internal + +#endif //PERNIX_ARM64_NEON_COMMON_H \ No newline at end of file diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index 44744ce..90ed78b 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -13,48 +14,129 @@ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_16 = elements_per_block / 16; + constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16; + + const float32x4_t scale_v = vdupq_n_f32(scale); + + for (uint32_t i = 0; i < iterations_16; ++i) { + const uint8x16_t source = vld1q_u8(input); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); + + const int32x4x4_t converted = neon_convert_int8x16_int32x4x2_t(unpacked); + const float32x4x4_t dequantized = neon_dequantize_epi32(converted, scale_v); + + for (uint32_t j = 0; j < 4; ++j) { + vst1q_f32(output, dequantized.val[j]); + output += 4; + } + + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); + + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x2_t(tail_unpacked); + const float32x4x4_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); + } + + return 0; } template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + + for (uint32_t i = 0; i < iterations_8; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + for (uint32_t i = 0; i < iterations_4; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + + for (uint32_t i = 0; i < iterations_8; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + for (uint32_t i = 0; i < iterations_4; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_2 = elements_per_block / 2; + constexpr uint32_t remaining_elements = elements_per_block - iterations_2 * 2; + + for (uint32_t i = 0; i < iterations_2; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } } // namespace internal diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index ea22b24..c1ae78f 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -3,8 +3,24 @@ #include -namespace pernix::arm64::neon::internal { -} // namespace pernix::arm64::neon::internal +namespace pernix::arm64::neon::internal::b128 { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline int8x16_t neon_unpack_epi8_1to8(const int8x16_t& input) { + return input; +} +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline int16x8_t neon_unpack_epi8_9to16(const int16x8_t& input) { + return input; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline int32x4_t neon_unpack_epi8_17to24(const int32x4_t& input) { + return input; +} +} // namespace pernix::arm64::neon::internal::b128 #endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index f96a84f..07d5110 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -36,7 +36,7 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #include -#elif defined(__aarch64__) +#elif defined(__aarch64__) || defined(__arm64ec__) #include #endif From 50d14a6dcdd647bf46e6ed0bc37e03a57926c342 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Mon, 25 May 2026 18:37:03 +0200 Subject: [PATCH 08/14] WIP: implement NEON decompression functions --- tests/arm64/neon/.gitkeep => .clangd | 0 CMakeLists.txt | 12 ++ include/pernix/arm64/neon/common.h | 46 ++++- include/pernix/arm64/neon/decompression.h | 78 +++++++- include/pernix/arm64/neon/unpacking.h | 81 ++++++++- include/pernix/arm64/tables.h | 212 ++++++++++++++++++++++ src/CMakeLists.txt | 12 -- tests/CMakeLists.txt | 42 ++++- tests/arm64/neon/decompression_tests.cpp | 38 ++++ 9 files changed, 487 insertions(+), 34 deletions(-) rename tests/arm64/neon/.gitkeep => .clangd (100%) create mode 100644 include/pernix/arm64/tables.h create mode 100644 tests/arm64/neon/decompression_tests.cpp diff --git a/tests/arm64/neon/.gitkeep b/.clangd similarity index 100% rename from tests/arm64/neon/.gitkeep rename to .clangd diff --git a/CMakeLists.txt b/CMakeLists.txt index b8c8208..1cbead6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,18 @@ if (NOT PERNIX_ARCH_BACKEND IN_LIST PERNIX_VALID_ARCH_BACKENDS) message(FATAL_ERROR "Unsupported PERNIX_ARCH_BACKEND='${PERNIX_ARCH_BACKEND}'. Expected one of: ${PERNIX_VALID_ARCH_BACKENDS}") endif () +set(PERNIX_SELECTED_ARCH_BACKEND "${PERNIX_ARCH_BACKEND}") +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "AUTO") + if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") + set(PERNIX_SELECTED_ARCH_BACKEND "X86") + elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(PERNIX_SELECTED_ARCH_BACKEND "ARM64_NEON") + else () + set(PERNIX_SELECTED_ARCH_BACKEND "FALLBACK") + endif () +endif () +message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") + string(TOUPPER "${PERNIX_SIMDE_PROVIDER}" PERNIX_SIMDE_PROVIDER) set(PERNIX_VALID_SIMDE_PROVIDERS AUTO PACKAGE FETCH) if (NOT PERNIX_SIMDE_PROVIDER IN_LIST PERNIX_VALID_SIMDE_PROVIDERS) diff --git a/include/pernix/arm64/neon/common.h b/include/pernix/arm64/neon/common.h index 6efa777..f843908 100644 --- a/include/pernix/arm64/neon/common.h +++ b/include/pernix/arm64/neon/common.h @@ -11,7 +11,7 @@ static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t rem return tail_bytes; } -__always_inline int32x4x4_t neon_convert_int8x16_int32x4x2_t(const int8x16_t& input) { +__always_inline int32x4x4_t neon_convert_int8x16_int32x4x4(const int8x16_t& input) { const int16x8_t s16_lo = vmovl_s8(vget_low_s8(input)); const int16x8_t s16_hi = vmovl_s8(vget_high_s8(input)); @@ -23,6 +23,13 @@ __always_inline int32x4x4_t neon_convert_int8x16_int32x4x2_t(const int8x16_t& in }}; } +__always_inline int32x4x2_t neon_convert_int16x8_int32x4x2(const int16x8_t& input) { + return {{ + vmovl_s16(vget_low_s16(input)), + vmovl_s16(vget_high_s16(input)), + }}; +} + __always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t& input, const float32x4_t& scale) { return {{ vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), @@ -32,21 +39,32 @@ __always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t& input, co }}; } -__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_elements) { +__always_inline float32x4x2_t neon_dequantize_epi32(const int32x4x2_t& input, const float32x4_t& scale) { + return {{ + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + }}; +} + +__always_inline float32x4_t neon_dequantize_epi32(const int32x4_t& input, const float32x4_t& scale) { + return vmulq_f32(vcvtq_f32_s32(input), scale); +} + +__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_bytes_count) { uint8_t buffer[16] = {0}; - std::memcpy(buffer, input, tail_elements * sizeof(uint8_t)); + std::memcpy(buffer, input, tail_bytes_count); return vld1q_u8(buffer); } -__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t* input, const uint32_t tail_elements) { +__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t* input, const uint32_t tail_bytes_count) { uint16_t buffer[8] = {0}; - std::memcpy(buffer, input, tail_elements * sizeof(uint16_t)); + std::memcpy(buffer, input, tail_bytes_count); return vld1q_u16(buffer); } -__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t* input, const uint32_t tail_elements) { +__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t* input, const uint32_t tail_bytes_count) { uint32_t buffer[4] = {0}; - std::memcpy(buffer, input, tail_elements * sizeof(uint32_t)); + std::memcpy(buffer, input, tail_bytes_count); return vld1q_u32(buffer); } @@ -94,6 +112,20 @@ __always_inline void neon_store_tail_elements_f32(float32_t* output, const float std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); } +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4x2_t& data, const uint32_t tail_elements) { + float32_t buffer[8 * 2]; + for (uint32_t i = 0; i < 2; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4_t& data, const uint32_t tail_elements) { + float32_t buffer[4]; + vst1q_f32(buffer, data); + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + __always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x4_t& data, const uint32_t tail_elements) { float64_t buffer[8 * 4]; for (uint32_t i = 0; i < 4; ++i) { diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index 90ed78b..6c46f1a 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -25,7 +25,7 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input const uint8x16_t source = vld1q_u8(input); const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); - const int32x4x4_t converted = neon_convert_int8x16_int32x4x2_t(unpacked); + const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); const float32x4x4_t dequantized = neon_dequantize_epi32(converted, scale_v); for (uint32_t j = 0; j < 4; ++j) { @@ -40,7 +40,7 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); - const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x2_t(tail_unpacked); + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); const float32x4x4_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); @@ -58,13 +58,34 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + const float32x4_t scale_v = vdupq_n_f32(scale); + for (uint32_t i = 0; i < iterations_8; ++i) { - static_assert(true, "Not yet implemented"); + const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); + const int16x8_t unpacked = b128::neon_unpack_epi8_9to16(source); + + const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); + const float32x4x2_t dequantized = neon_dequantize_epi32(converted, scale_v); + + for (uint32_t j = 0; j < 2; ++j) { + vst1q_f32(output, dequantized.val[j]); + output += 4; + } + + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi8_9to16(tail_source); + + const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); + const float32x4x2_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); } + + return 0; } template @@ -76,13 +97,52 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp constexpr uint32_t iterations_4 = elements_per_block / 4; constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + const float32x4_t scale_v = vdupq_n_f32(scale); + for (uint32_t i = 0; i < iterations_4; ++i) { - static_assert(true, "Not yet implemented"); + const uint32_t group_bit_start = i * 4u * BIT_WIDTH; + const uint8_t* group_input = input + group_bit_start / 8u; + const uint32x4_t source = vld1q_u32(reinterpret_cast(group_input)); + + int32x4_t unpacked; + if constexpr (BIT_WIDTH % 2 == 0) { + unpacked = b128::neon_unpack_epi8_17to24(source); + } else { + if (i % 2 == 0) { + unpacked = b128::neon_unpack_epi8_17to24(source); + } else { + unpacked = b128::neon_unpack_epi8_17to24(source); + } + } + + const float32x4_t dequantized = neon_dequantize_epi32(unpacked, scale_v); + + vst1q_f32(output, dequantized); + + output += 4; } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; + + constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + + int32x4_t tail_unpacked; + if constexpr (tail_bit_offset == 0) { + tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + } else { + tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + } + + const float32x4_t tail_dequantized = neon_dequantize_epi32(tail_unpacked, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); } + + return 0; } template @@ -101,6 +161,8 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input if constexpr (remaining_elements > 0) { static_assert(true, "Not yet implemented"); } + + return 0; } template @@ -119,6 +181,8 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu if constexpr (remaining_elements > 0) { static_assert(true, "Not yet implemented"); } + + return 0; } template @@ -137,6 +201,8 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp if constexpr (remaining_elements > 0) { static_assert(true, "Not yet implemented"); } + + return 0; } } // namespace internal diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index c1ae78f..663eb9d 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -2,24 +2,91 @@ #define PERNIX_ARM64_NEON_UNPACKING_H #include +#include + +using namespace pernix::arm64::internal; namespace pernix::arm64::neon::internal::b128 { template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline int8x16_t neon_unpack_epi8_1to8(const int8x16_t& input) { - return input; +__always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t& input) { + if constexpr (BIT_WIDTH == 8) { + return vreinterpretq_s8_u8(input); + } else { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + + uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute2.data())); + + shifted = vorrq_u8(shifted, vshlq_u8(permuted2_u8, vld1q_s8(tables::shift2.data()))); + } + + constexpr int shift = 8 - BIT_WIDTH; + shifted = vshlq_n_u8(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s8(vreinterpretq_s8_u8(shifted), vdupq_n_s8(-shift)); + } else { + return vreinterpretq_s8_u8(vshlq_u8(shifted, vdupq_n_s8(-shift))); + } + } } template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline int16x8_t neon_unpack_epi8_9to16(const int16x8_t& input) { - return input; +__always_inline int16x8_t neon_unpack_epi8_9to16(const uint16x8_t& input) { + if constexpr (BIT_WIDTH == 16) { + return vreinterpretq_s16_u16(input); + } else { + using tables = table_unpacking; + + const uint8x16_t input_u8 = vreinterpretq_u8_u16(input); + + const uint8x16_t permuted1_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute1.data())); + + uint16x8_t shifted = vshlq_u16(vreinterpretq_u16_u8(permuted1_u8), vld1q_s16(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute2.data())); + + const uint16x8_t shifted2 = vshlq_u16(vreinterpretq_u16_u8(permuted2_u8), vld1q_s16(tables::shift2.data())); + + shifted = vorrq_u16(shifted, shifted2); + } + + constexpr int shift = 16 - BIT_WIDTH; + shifted = vshlq_n_u16(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s16(vreinterpretq_s16_u16(shifted), vdupq_n_s16(-shift)); + } else { + return vreinterpretq_s16_u16(vshlq_u16(shifted, vdupq_n_s16(-shift))); + } + } } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline int32x4_t neon_unpack_epi8_17to24(const int32x4_t& input) { - return input; +__always_inline int32x4_t neon_unpack_epi8_17to24(const uint32x4_t& input) { + using tables = table_unpacking; + + const uint8x16_t input_8 = vreinterpretq_u8_u32(input); + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input_8, vld1q_u8(tables::permute.data())); + + const uint32x4_t value = vshlq_u32(vreinterpretq_u32_u8(permuted_u8), vld1q_s32(tables::shift.data())); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return vshrq_n_s32(vreinterpretq_s32_u32(vshlq_n_u32(value, sign_shift)), sign_shift); + } else { + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); + } } } // namespace pernix::arm64::neon::internal::b128 diff --git a/include/pernix/arm64/tables.h b/include/pernix/arm64/tables.h new file mode 100644 index 0000000..233bbc4 --- /dev/null +++ b/include/pernix/arm64/tables.h @@ -0,0 +1,212 @@ +#ifndef PERNIX_ARM64_TABLES_H +#define PERNIX_ARM64_TABLES_H + +#include +#include +#include + +namespace pernix::arm64::internal { +namespace detail { +inline constexpr std::size_t neon_vector_width = 128; +inline constexpr uint8_t inactive_lane = 0xff; + +template +constexpr bool table_indices_are_valid(const std::array& table) { + for (const uint8_t index : table) { + if (index != inactive_lane && index >= Elements) { + return false; + } + } + + return true; +} + +template +constexpr std::array make_primary_permute() { + static_assert(LANE_BITS % 8 == 0); + + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t base = entry * lane_bytes; + + for (std::size_t lane_byte = 0; lane_byte < lane_bytes; ++lane_byte) { + table[base + lane_byte] = static_cast(first_byte + lane_byte); + } + } + + return table; +} + +template +constexpr std::array make_spill_permute() { + static_assert(LANE_BITS % 8 == 0); + + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t bit_offset = bit_start % 8; + const std::size_t base = entry * lane_bytes; + + if (bit_offset + BIT_WIDTH > LANE_BITS) { + table[base] = static_cast(first_byte + lane_bytes); + } + } + + return table; +} + +template +constexpr std::array make_shift_right() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + + table[entry] = -static_cast(bit_offset); + } + + return table; +} + +template +constexpr std::array make_shift_left_for_spill() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + const bool spills = bit_offset + BIT_WIDTH > LANE_BITS; + + table[entry] = spills ? static_cast(LANE_BITS - bit_offset) : 0; + } + + return table; +} + +template +constexpr std::array make_contiguous_permute_32() { + static_assert(ELEMENTS % 4 == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / 4; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + const std::size_t bit_end = bit_start + BIT_WIDTH - 1; + const std::size_t first_byte = bit_start / 8; + const std::size_t last_byte = bit_end / 8; + const std::size_t base = entry * 4; + + for (std::size_t byte = first_byte; byte <= last_byte; ++byte) { + table[base + (byte - first_byte)] = static_cast(byte); + } + } + + return table; +} + +template +constexpr std::array make_shift_right_32() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + + table[entry] = -static_cast(bit_start % 8u); + } + + return table; +} +} // namespace detail + +template +struct table_unpacking; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && VECTOR_WIDTH == detail::neon_vector_width) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 8; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 16); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && VECTOR_WIDTH == detail::neon_vector_width) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 16; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = + detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 8); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && VECTOR_WIDTH == detail::neon_vector_width && START_BIT_OFFSET < 8) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 32; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute = + detail::make_contiguous_permute_32(); + alignas(64) static constexpr std::array shift = + detail::make_shift_right_32(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 4); + static_assert(detail::table_indices_are_valid(permute)); +}; +} // namespace pernix::arm64::internal + +#endif // PERNIX_ARM64_TABLES_H diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8c79ab6..57d2f0a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -11,18 +11,6 @@ file(GLOB_RECURSE set(PERNIX_SOURCES ${PERNIX_COMMON_SOURCES}) -set(PERNIX_SELECTED_ARCH_BACKEND "${PERNIX_ARCH_BACKEND}") -if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "AUTO") - if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") - set(PERNIX_SELECTED_ARCH_BACKEND "X86") - elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") - set(PERNIX_SELECTED_ARCH_BACKEND "ARM64_NEON") - else () - set(PERNIX_SELECTED_ARCH_BACKEND "FALLBACK") - endif () -endif () -message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") - if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") file(GLOB_RECURSE PERNIX_X86_SOURCES diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 370ec3b..b5ffe57 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,12 +2,50 @@ find_package(PkgConfig) pkg_search_module(GTEST REQUIRED gtest) include(CheckCXXCompilerFlag) +file(GLOB + PERNIX_ROOT_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp +) + file(GLOB_RECURSE - SOURCE_FILES + PERNIX_FALLBACK_TEST_SOURCES CONFIGURE_DEPENDS - *.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fallback/*.cpp ) +set(SOURCE_FILES ${PERNIX_ROOT_TEST_SOURCES} ${PERNIX_FALLBACK_TEST_SOURCES}) + +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") + file(GLOB_RECURSE + PERNIX_X86_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/x86/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_X86_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") + file(GLOB_RECURSE + PERNIX_ARM64_NEON_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/neon/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_NEON_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") + file(GLOB_RECURSE + PERNIX_ARM64_SVE_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/sve/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_SVE_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") + file(GLOB_RECURSE + PERNIX_ARM64_SVE2_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/sve2/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_SVE2_TEST_SOURCES}) +endif () + file(GLOB_RECURSE HEADER_FILES CONFIGURE_DEPENDS diff --git a/tests/arm64/neon/decompression_tests.cpp b/tests/arm64/neon/decompression_tests.cpp new file mode 100644 index 0000000..69229be --- /dev/null +++ b/tests/arm64/neon/decompression_tests.cpp @@ -0,0 +1,38 @@ +#include +#include + +#ifdef PERNIX_BACKEND_ARM64_NEON + +using namespace pernix::arm64::neon; + +TYPED_TEST(DecompressionTest, NeonDecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + neon_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +TYPED_TEST(DecompressionTest64, NeonDecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + neon_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +#endif \ No newline at end of file From 9eca7046bce1aa7be90f45420e698a4984f62af2 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 26 May 2026 00:04:55 +0200 Subject: [PATCH 09/14] WIP: implement NEON decompression functions --- include/pernix/arm64/neon/common.h | 68 ++++++++- include/pernix/arm64/neon/decompression.h | 162 +++++++++++++++------- include/pernix/arm64/neon/unpacking.h | 11 +- include/pernix/arm64/tables.h | 4 +- 4 files changed, 189 insertions(+), 56 deletions(-) diff --git a/include/pernix/arm64/neon/common.h b/include/pernix/arm64/neon/common.h index f843908..8e517fa 100644 --- a/include/pernix/arm64/neon/common.h +++ b/include/pernix/arm64/neon/common.h @@ -2,9 +2,14 @@ #define PERNIX_ARM64_NEON_COMMON_H #include + #include namespace pernix::arm64::neon::internal { +struct float64x2x8_t { + float64x2_t val[8]; +}; + static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { const uint32_t tail_bits = remaining_elements * bit_width; const uint32_t tail_bytes = (tail_bits + 7u) / 8u; @@ -50,6 +55,47 @@ __always_inline float32x4_t neon_dequantize_epi32(const int32x4_t& input, const return vmulq_f32(vcvtq_f32_s32(input), scale); } +__always_inline float64x2_t neon_dequantize_epi32_f64(const int32x2_t& input, const float64x2_t& scale) { + return vmulq_f64(vcvtq_f64_s64(vmovl_s32(input)), scale); +} + +__always_inline float64x2x2_t neon_dequantize_epi32_f64(const int32x4_t& input, const float64x2_t& scale) { + return {{ + neon_dequantize_epi32_f64(vget_low_s32(input), scale), + neon_dequantize_epi32_f64(vget_high_s32(input), scale), + }}; +} + +__always_inline float64x2x4_t neon_dequantize_epi32_f64(const int32x4x2_t& input, const float64x2_t& scale) { + const float64x2x2_t dequantized_low = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized_high = neon_dequantize_epi32_f64(input.val[1], scale); + + return {{ + dequantized_low.val[0], + dequantized_low.val[1], + dequantized_high.val[0], + dequantized_high.val[1], + }}; +} + +__always_inline float64x2x8_t neon_dequantize_epi32_f64(const int32x4x4_t& input, const float64x2_t& scale) { + const float64x2x2_t dequantized0 = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized1 = neon_dequantize_epi32_f64(input.val[1], scale); + const float64x2x2_t dequantized2 = neon_dequantize_epi32_f64(input.val[2], scale); + const float64x2x2_t dequantized3 = neon_dequantize_epi32_f64(input.val[3], scale); + + return {{ + dequantized0.val[0], + dequantized0.val[1], + dequantized1.val[0], + dequantized1.val[1], + dequantized2.val[0], + dequantized2.val[1], + dequantized3.val[0], + dequantized3.val[1], + }}; +} + __always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_bytes_count) { uint8_t buffer[16] = {0}; std::memcpy(buffer, input, tail_bytes_count); @@ -127,12 +173,28 @@ __always_inline void neon_store_tail_elements_f32(float32_t* output, const float } __always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x4_t& data, const uint32_t tail_elements) { - float64_t buffer[8 * 4]; + float64_t buffer[2 * 4]; for (uint32_t i = 0; i < 4; ++i) { vst1q_f64(buffer + i * 2, data.val[i]); } std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); } -} // namespace pernix::arm64::neon::internal -#endif //PERNIX_ARM64_NEON_COMMON_H \ No newline at end of file +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x2_t& data, const uint32_t tail_elements) { + float64_t buffer[2 * 2]; + for (uint32_t i = 0; i < 2; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} + +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x8_t& data, const uint32_t tail_elements) { + float64_t buffer[2 * 8]; + for (uint32_t i = 0; i < 8; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_COMMON_H diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index 6c46f1a..cfc051e 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -1,9 +1,9 @@ #ifndef PERNIX_ARM64_NEON_DECOMPRESSION_H #define PERNIX_ARM64_NEON_DECOMPRESSION_H -#include -#include #include +#include +#include #include #include @@ -12,8 +12,7 @@ namespace pernix::arm64::neon { namespace internal { template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -51,8 +50,7 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; @@ -90,8 +88,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_4 = elements_per_block / 4; @@ -123,12 +120,12 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp } if constexpr (remaining_elements > 0) { - constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; - const uint8_t* tail_input = input + tail_bit_start / 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; - const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); int32x4_t tail_unpacked; if constexpr (tail_bit_offset == 0) { @@ -147,19 +144,37 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + constexpr uint32_t iterations_16 = elements_per_block / 16; + constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16; - for (uint32_t i = 0; i < iterations_8; ++i) { - static_assert(true, "Not yet implemented"); + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_16; ++i) { + const uint8x16_t source = vld1q_u8(input); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); + + const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); + const float64x2x8_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); + + for (uint32_t j = 0; j < 8; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); + + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); + const float64x2x8_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); } return 0; @@ -167,19 +182,37 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_4 = elements_per_block / 4; - constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; - for (uint32_t i = 0; i < iterations_4; ++i) { - static_assert(true, "Not yet implemented"); + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_8; ++i) { + const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); + const int16x8_t unpacked = b128::neon_unpack_epi8_9to16(source); + + const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); + const float64x2x4_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); + + for (uint32_t j = 0; j < 4; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi8_9to16(tail_source); + + const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); + const float64x2x4_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); } return 0; @@ -187,29 +220,65 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_2 = elements_per_block / 2; - constexpr uint32_t remaining_elements = elements_per_block - iterations_2 * 2; + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_4; ++i) { + const uint32_t group_bit_start = i * 4u * BIT_WIDTH; + const uint8_t* group_input = input + group_bit_start / 8u; + const uint32x4_t source = vld1q_u32(reinterpret_cast(group_input)); + + int32x4_t unpacked; + if constexpr (BIT_WIDTH % 2 == 0) { + unpacked = b128::neon_unpack_epi8_17to24(source); + } else { + if (i % 2 == 0) { + unpacked = b128::neon_unpack_epi8_17to24(source); + } else { + unpacked = b128::neon_unpack_epi8_17to24(source); + } + } + + const float64x2x2_t dequantized = neon_dequantize_epi32_f64(unpacked, scale_v); - for (uint32_t i = 0; i < iterations_2; ++i) { - static_assert(true, "Not yet implemented"); + for (uint32_t j = 0; j < 2; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; + + constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + + int32x4_t tail_unpacked; + if constexpr (tail_bit_offset == 0) { + tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + } else { + tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + } + + const float64x2x2_t tail_dequantized = neon_dequantize_epi32_f64(tail_unpacked, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); } return 0; } -} // namespace internal +} // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return internal::neon_decompress_block_1to8(input, scale, output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { @@ -222,8 +291,7 @@ __always_inline int neon_decompress_block(const uint8_t* __restrict__ input, con template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return internal::neon_decompress_block_1to8(input, scale, output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { @@ -236,14 +304,13 @@ __always_inline int neon_decompress_block(const uint8_t* __restrict__ input, con template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { +int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { const uint8_t* block_input = input; float_t* block_output = output; for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } @@ -252,14 +319,13 @@ int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { +int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { const uint8_t* block_input = input; double_t* block_output = output; for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; @@ -271,19 +337,17 @@ extern "C" { int neon_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output); +int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); int neon_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output, uint32_t blocks); +int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix::arm64::neon +} // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index 663eb9d..b1f8585 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -1,8 +1,8 @@ #ifndef PERNIX_ARM64_NEON_UNPACKING_H #define PERNIX_ARM64_NEON_UNPACKING_H -#include #include +#include using namespace pernix::arm64::internal; @@ -12,6 +12,13 @@ template __always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t& input) { if constexpr (BIT_WIDTH == 8) { return vreinterpretq_s8_u8(input); + } else if constexpr (BIT_WIDTH == 1) { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + const uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + return vreinterpretq_s8_u8(vandq_u8(shifted, vdupq_n_u8(1))); } else { using tables = table_unpacking; @@ -88,6 +95,6 @@ __always_inline int32x4_t neon_unpack_epi8_17to24(const uint32x4_t& input) { return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); } } -} // namespace pernix::arm64::neon::internal::b128 +} // namespace pernix::arm64::neon::internal::b128 #endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/tables.h b/include/pernix/arm64/tables.h index 233bbc4..60e1dfe 100644 --- a/include/pernix/arm64/tables.h +++ b/include/pernix/arm64/tables.h @@ -134,7 +134,7 @@ constexpr std::array make_shift_right_32() { return table; } -} // namespace detail +} // namespace detail template struct table_unpacking; @@ -207,6 +207,6 @@ struct table_unpacking { static_assert(SHIFT_ELEMENTS == 4); static_assert(detail::table_indices_are_valid(permute)); }; -} // namespace pernix::arm64::internal +} // namespace pernix::arm64::internal #endif // PERNIX_ARM64_TABLES_H From af90ba8ac32167d689c220feb0dca14314affdf8 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Thu, 28 May 2026 13:22:46 +0200 Subject: [PATCH 10/14] WIP: Implement SVE2 decompression functions and update related headers for ARM64 --- include/pernix/arm64/neon/decompression.h | 36 +-- include/pernix/arm64/{ => neon}/tables.h | 26 +- include/pernix/arm64/neon/unpacking.h | 10 +- include/pernix/arm64/sve2/decompression.h | 346 ++++++++++++++++++++-- include/pernix/arm64/sve2/tables.h | 143 +++++++++ include/pernix/arm64/sve2/unpacking.h | 87 +++++- include/pernix/simd_compat.h | 5 + src/arm64/neon/decompression.cpp | 140 ++++++++- src/arm64/sve2/decompression.cpp | 76 ++++- tests/arm64/sve2/.gitkeep | 0 tests/arm64/sve2/decompression_tests.cpp | 38 +++ 11 files changed, 833 insertions(+), 74 deletions(-) rename include/pernix/arm64/{ => neon}/tables.h (94%) create mode 100644 include/pernix/arm64/sve2/tables.h delete mode 100644 tests/arm64/sve2/.gitkeep create mode 100644 tests/arm64/sve2/decompression_tests.cpp diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index cfc051e..583948f 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -60,7 +60,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu for (uint32_t i = 0; i < iterations_8; ++i) { const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); - const int16x8_t unpacked = b128::neon_unpack_epi8_9to16(source); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16(source); const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); const float32x4x2_t dequantized = neon_dequantize_epi32(converted, scale_v); @@ -75,7 +75,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu if constexpr (remaining_elements > 0) { const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); - const int16x8_t tail_unpacked = b128::neon_unpack_epi8_9to16(tail_source); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16(tail_source); const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); const float32x4x2_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); @@ -103,12 +103,12 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp int32x4_t unpacked; if constexpr (BIT_WIDTH % 2 == 0) { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } else { if (i % 2 == 0) { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } else { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } } @@ -129,9 +129,9 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp int32x4_t tail_unpacked; if constexpr (tail_bit_offset == 0) { - tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } else { - tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } const float32x4_t tail_dequantized = neon_dequantize_epi32(tail_unpacked, scale_v); @@ -192,7 +192,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu for (uint32_t i = 0; i < iterations_8; ++i) { const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); - const int16x8_t unpacked = b128::neon_unpack_epi8_9to16(source); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16(source); const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); const float64x2x4_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); @@ -207,7 +207,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu if constexpr (remaining_elements > 0) { const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); - const int16x8_t tail_unpacked = b128::neon_unpack_epi8_9to16(tail_source); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16(tail_source); const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); const float64x2x4_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); @@ -235,12 +235,12 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp int32x4_t unpacked; if constexpr (BIT_WIDTH % 2 == 0) { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } else { if (i % 2 == 0) { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } else { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } } @@ -262,9 +262,9 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp int32x4_t tail_unpacked; if constexpr (tail_bit_offset == 0) { - tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } else { - tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } const float64x2x2_t tail_dequantized = neon_dequantize_epi32_f64(tail_unpacked, scale_v); @@ -274,7 +274,7 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp return 0; } -} // namespace internal +} // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) @@ -310,7 +310,7 @@ int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scal for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } @@ -325,7 +325,7 @@ int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t sca for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; @@ -348,6 +348,6 @@ int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ in #ifdef __cplusplus } #endif -} // namespace pernix::arm64::neon +} // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/tables.h b/include/pernix/arm64/neon/tables.h similarity index 94% rename from include/pernix/arm64/tables.h rename to include/pernix/arm64/neon/tables.h index 60e1dfe..d085551 100644 --- a/include/pernix/arm64/tables.h +++ b/include/pernix/arm64/neon/tables.h @@ -1,24 +1,21 @@ -#ifndef PERNIX_ARM64_TABLES_H -#define PERNIX_ARM64_TABLES_H +#ifndef PERNIX_ARM64_NEON_TABLES_H +#define PERNIX_ARM64_NEON_TABLES_H +#include #include #include #include -namespace pernix::arm64::internal { +namespace pernix::arm64::neon::internal { namespace detail { inline constexpr std::size_t neon_vector_width = 128; inline constexpr uint8_t inactive_lane = 0xff; template constexpr bool table_indices_are_valid(const std::array& table) { - for (const uint8_t index : table) { - if (index != inactive_lane && index >= Elements) { - return false; - } - } - - return true; + return std::ranges::all_of(table, [](const uint8_t index) { + return index == inactive_lane || index < Elements; + }); } template @@ -134,7 +131,7 @@ constexpr std::array make_shift_right_32() { return table; } -} // namespace detail +} // namespace detail template struct table_unpacking; @@ -207,6 +204,9 @@ struct table_unpacking { static_assert(SHIFT_ELEMENTS == 4); static_assert(detail::table_indices_are_valid(permute)); }; -} // namespace pernix::arm64::internal -#endif // PERNIX_ARM64_TABLES_H +template +struct table_packing; +} // namespace pernix::arm64::internal + +#endif // PERNIX_ARM64_NEON_TABLES_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index b1f8585..6ac0e20 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -1,10 +1,10 @@ #ifndef PERNIX_ARM64_NEON_UNPACKING_H #define PERNIX_ARM64_NEON_UNPACKING_H -#include +#include #include -using namespace pernix::arm64::internal; +using namespace pernix::arm64::neon::internal; namespace pernix::arm64::neon::internal::b128 { template @@ -45,7 +45,7 @@ __always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t& input) { template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline int16x8_t neon_unpack_epi8_9to16(const uint16x8_t& input) { +__always_inline int16x8_t neon_unpack_epi16_9to16(const uint16x8_t& input) { if constexpr (BIT_WIDTH == 16) { return vreinterpretq_s16_u16(input); } else { @@ -78,7 +78,7 @@ __always_inline int16x8_t neon_unpack_epi8_9to16(const uint16x8_t& input) { template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline int32x4_t neon_unpack_epi8_17to24(const uint32x4_t& input) { +__always_inline int32x4_t neon_unpack_epi32_17to24(const uint32x4_t& input) { using tables = table_unpacking; const uint8x16_t input_8 = vreinterpretq_u8_u32(input); @@ -95,6 +95,6 @@ __always_inline int32x4_t neon_unpack_epi8_17to24(const uint32x4_t& input) { return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); } } -} // namespace pernix::arm64::neon::internal::b128 +} // namespace pernix::arm64::neon::internal::b128 #endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve2/decompression.h b/include/pernix/arm64/sve2/decompression.h index c27f08a..198a8fc 100644 --- a/include/pernix/arm64/sve2/decompression.h +++ b/include/pernix/arm64/sve2/decompression.h @@ -1,43 +1,352 @@ #ifndef PERNIX_ARM64_SVE2_DECOMPRESSION_H #define PERNIX_ARM64_SVE2_DECOMPRESSION_H +#include +#include #include +#include #include #include +#include -namespace pernix { +namespace pernix::arm64::sve2 { namespace internal { -template -inline constexpr bool sve2_decompression_unimplemented_v = false; -} // namespace internal +template +[[nodiscard]] __always_inline constexpr uint32_t packed_bytes(const uint32_t elements) { + return (elements * BIT_WIDTH + 7) / 8; +} + +[[nodiscard]] __always_inline svuint8_t sve2_load_packed_bytes(const uint8_t* __restrict__ input, const uint32_t bytes) { + const svbool_t pg = svwhilelt_b8(uint64_t{0}, static_cast(bytes)); + return svld1_u8(pg, input); +} + +template +__always_inline void sve2_store_dequantized_i8_f32(svint8_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + alignas(64) std::vector temp(svcntb()); + + svst1_s8(svptrue_b8(), temp.data(), values); + + uint32_t offset = 0; + while (offset < count) { + const svbool_t pg = svwhilelt_b32(static_cast(offset), static_cast(count)); + + svfloat32_t dequantized; + + if constexpr (SIGN_VALUES) { + const svint32_t widened = svld1sb_s32(pg, temp.data() + offset); + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); + } else { + const svuint32_t widened = svld1ub_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + } + + svst1_f32(pg, output + offset, dequantized); + + offset += static_cast(svcntw()); + } +} + +template +__always_inline void sve2_store_dequantized_i8_f64(svint8_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcntb()); + + svst1_s8(svptrue_b8(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template +__always_inline void sve2_store_dequantized_i16_f32(svint16_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + alignas(64) std::vector temp(svcnth()); + + svst1_s16(svptrue_b16(), temp.data(), values); + + uint32_t offset = 0; + while (offset < count) { + const svbool_t pg = svwhilelt_b32(static_cast(offset), static_cast(count)); + + svfloat32_t dequantized; + if constexpr (SIGN_VALUES) { + const svint32_t widened = svld1sh_s32(pg, temp.data() + offset); + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); + } else { + const svuint32_t widened = svld1uh_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + } + + svst1_f32(pg, output + offset, dequantized); + + offset += static_cast(svcntw()); + } +} + +template +__always_inline void sve2_store_dequantized_i16_f64(svint16_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcnth()); + + svst1_s16(svptrue_b16(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template +__always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + const svbool_t pg = svwhilelt_b32(uint64_t{0}, static_cast(count)); + + svfloat32_t dequantized; + if constexpr (SIGN_VALUES) { + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, values), scale_v); + } else { + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, svreinterpret_u32_s32(values)), scale_v); + } + + svst1_f32(pg, output, dequantized); +} template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_block(const uint8_t*, float_t, float_t*) { - static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntb()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint8_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint8_t spill_shift = svdup_n_u8(0); + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint8_t unpacked = sve2_unpack_epi8_1to8(source, permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i8_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcnth()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint16_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint16_t spill_shift = svdup_n_u16(0); + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint16_t unpacked = + sve2_unpack_epi16_9to16(svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i16_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntw()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint8_t* chunk_input = input + input_bit_offset / 8; + const uint32_t bit_offset = input_bit_offset % 8; + const uint32_t bytes = (bit_offset + count * BIT_WIDTH + 7u) / 8u; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + svint32_t unpacked; + if (bit_offset == 0) { + unpacked = sve2_unpack_epi32_17to24(source); + } else { + unpacked = sve2_unpack_epi32_17to24(source); + } + + sve2_store_dequantized_i32_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__, double_t, double_t* __restrict__) { return -1; } template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_block(const uint8_t*, double_t, double_t*) { - static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcnth()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint16_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint16_t spill_shift = svdup_n_u16(0); + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint16_t unpacked = + sve2_unpack_epi16_9to16(svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i16_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__, double_t, double_t* __restrict__) { return -1; } +} // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { - static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); - return -1; +int sve2_decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve2_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve2_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve2_decompress_block_17to24(input, scale, output); + } } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { - static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); - return -1; +int sve2_decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve2_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve2_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve2_decompress_block_17to24(input, scale, output); + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + if constexpr (BIT_WIDTH > 8) { + return -1; + } else { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + if constexpr (BIT_WIDTH > 8) { + return -1; + } else { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } } #ifdef __cplusplus @@ -45,15 +354,18 @@ extern "C" { #endif int sve2_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + int sve2_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); + int sve2_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); + int sve2_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::sve2 #endif // PERNIX_ARM64_SVE2_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve2/tables.h b/include/pernix/arm64/sve2/tables.h new file mode 100644 index 0000000..813e042 --- /dev/null +++ b/include/pernix/arm64/sve2/tables.h @@ -0,0 +1,143 @@ +#ifndef PERNIX_ARM64_SVE2_TABLES_H +#define PERNIX_ARM64_SVE2_TABLES_H + +#include + +#include +#include + +namespace pernix::arm64::sve2::internal { +template +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svbool_t pg_b8() { return svptrue_b8(); } + + static svbool_t pg_b16() { return svptrue_b16(); } + + static svbool_t pg_b32() { return svptrue_b32(); } +}; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast((lane * BIT_WIDTH) / 8u); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint8_t spill_permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast((lane * BIT_WIDTH) / 8u + 1u); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint8_t shift() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast((lane * BIT_WIDTH) % 8u); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint8_t spill_shift() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast(8u - ((lane * BIT_WIDTH) % 8u)); + } + + return svld1_u8(svptrue_b8(), table.data()); + } +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + const uint32_t element = lane / 2u; + const uint32_t byte = lane % 2u; + const uint32_t first = (element * BIT_WIDTH) / 8u; + + table[lane] = static_cast(first + byte); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint8_t spill_permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + const uint32_t element = lane / 2u; + const uint32_t byte = lane % 2u; + const uint32_t first = (element * BIT_WIDTH) / 8u; + + table[lane] = static_cast(first + 2u + byte); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint16_t shift() { + std::vector table(svcnth()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast((lane * BIT_WIDTH) % 8u); + } + + return svld1_u16(svptrue_b16(), table.data()); + } + + static svuint16_t spill_shift() { + std::vector table(svcnth()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + const uint32_t bit_offset = (lane * BIT_WIDTH) % 8u; + table[lane] = bit_offset + BIT_WIDTH > 16u ? static_cast(16u - bit_offset) : uint16_t{16}; + } + + return svld1_u16(svptrue_b16(), table.data()); + } +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && START_BIT_OFFSET < 8) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + const uint32_t element = lane / 4u; + const uint32_t byte = lane % 4u; + const uint32_t first = (START_BIT_OFFSET + element * BIT_WIDTH) / 8u; + + table[lane] = static_cast(first + byte); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint32_t shift() { + std::vector table(svcntw()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = (START_BIT_OFFSET + lane * BIT_WIDTH) % 8u; + } + + return svld1_u32(svptrue_b32(), table.data()); + } +}; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_TABLES_H diff --git a/include/pernix/arm64/sve2/unpacking.h b/include/pernix/arm64/sve2/unpacking.h index d654b5e..326901f 100644 --- a/include/pernix/arm64/sve2/unpacking.h +++ b/include/pernix/arm64/sve2/unpacking.h @@ -3,9 +3,90 @@ #include +#include "tables.h" + namespace pernix::arm64::sve2::internal { -template -inline constexpr bool unpacking_unimplemented_v = false; -} // namespace pernix::arm64::sve2::internal +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline svint8_t sve2_unpack_epi8_1to8(const svuint8_t input, const svuint8_t permute, const svuint8_t shift, + const svuint8_t spill_permute, const svuint8_t spill_shift) { + if constexpr (BIT_WIDTH == 8) { + return svreinterpret_s8(input); + } else { + const svbool_t pg = svptrue_b8(); + + const svuint8_t permuted = svtbl_u8(input, permute); + svuint8_t unpacked = svlsr_u8_x(pg, permuted, shift); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const svuint8_t spill_permuted_values = svtbl_u8(input, spill_permute); + const svuint8_t spill_shifted = svlsl_u8_x(pg, spill_permuted_values, spill_shift); + unpacked = svorr_u8_x(pg, unpacked, spill_shifted); + } + + if constexpr (BIT_WIDTH == 1) { + unpacked = svand_n_u8_x(pg, unpacked, 1); + return svreinterpret_s8(unpacked); + } else { + constexpr int sign_shift = 8 - BIT_WIDTH; + + unpacked = svlsl_n_u8_x(pg, unpacked, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s8_x(pg, svreinterpret_s8_u8(unpacked), sign_shift); + } else { + return svreinterpret_s8_u8(svlsr_n_u8_x(pg, unpacked, sign_shift)); + } + } + } +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline svint16_t sve2_unpack_epi16_9to16(const svuint16_t input, const svuint8_t permute, const svuint16_t shift, + const svuint8_t spill_permute, const svuint16_t spill_shift) { + if constexpr (BIT_WIDTH == 16) { + return svreinterpret_s16(input); + } else { + const svbool_t pg = svptrue_b16(); + + const svuint8_t permuted = svtbl_u8(svreinterpret_u8_u16(input), permute); + svuint16_t shifted = svlsr_u16_x(pg, svreinterpret_u16_u8(permuted), shift); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const svuint8_t spill_permuted_values = svtbl_u8(svreinterpret_u8_u16(input), spill_permute); + const svuint16_t spill_shifted = svlsl_u16_x(pg, svreinterpret_u16_u8(spill_permuted_values), spill_shift); + shifted = svorr_u16_x(pg, shifted, spill_shifted); + } + + constexpr int sign_shift = 16 - BIT_WIDTH; + shifted = svlsl_n_u16_x(pg, shifted, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s16_x(pg, svreinterpret_s16_u16(shifted), sign_shift); + } else { + return svreinterpret_s16_u16(svlsr_n_u16_x(pg, shifted, sign_shift)); + } + } +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline svint32_t sve2_unpack_epi32_17to24(const svuint8_t input) { + using table = table_unpacking; + + const svbool_t pg = svptrue_b32(); + const svuint8_t permuted = svtbl_u8(input, table::permute()); + const svuint32_t unpacked = svlsr_u32_x(pg, svreinterpret_u32_u8(permuted), table::shift()); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return svasr_n_s32_x(pg, svreinterpret_s32_u32(svlsl_n_u32_x(pg, unpacked, sign_shift)), sign_shift); + } else { + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + return svreinterpret_s32_u32(svand_n_u32_x(pg, unpacked, mask)); + } +} +} // namespace pernix::arm64::sve2::internal #endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index 07d5110..509eb13 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -37,8 +37,13 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #include #elif defined(__aarch64__) || defined(__arm64ec__) +#ifdef __ARM_FEATURE_SVE +#include +#endif +#ifdef __ARM_NEON #include #endif +#endif #ifndef __always_inline #if defined(__GNUC__) || defined(__clang__) diff --git a/src/arm64/neon/decompression.cpp b/src/arm64/neon/decompression.cpp index 3cb7fc7..a89f763 100644 --- a/src/arm64/neon/decompression.cpp +++ b/src/arm64/neon/decompression.cpp @@ -2,20 +2,142 @@ namespace pernix { extern "C" { -int neon_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { - return -1; +#define PERNIX_NEON_DECOMPRESS_BLOCK_CASE(N) \ + case N: \ + return arm64::neon::neon_decompress_block(input, scale, output); + +#define PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(N) \ + case N: \ + return arm64::neon::neon_decompress_blocks(input, scale, output, blocks); + +int neon_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(24) + default: + return -1; + } } -int neon_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { - return -1; +int neon_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(24) + default: + return -1; + } } -int neon_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { - return -1; +int neon_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(24) + default: + return -1; + } } -int neon_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { - return -1; +int neon_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output, const uint32_t blocks) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(24) + default: + return -1; + } } + +#undef PERNIX_NEON_DECOMPRESS_BLOCK_CASE +#undef PERNIX_NEON_DECOMPRESS_BLOCKS_CASE } -} // namespace pernix +} // namespace pernix diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp index 8d170f6..66bbd97 100644 --- a/src/arm64/sve2/decompression.cpp +++ b/src/arm64/sve2/decompression.cpp @@ -2,20 +2,78 @@ namespace pernix { extern "C" { -int sve2_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { - return -1; +#define PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(N) \ + case N: \ + return arm64::sve2::sve2_decompress_block(input, scale, output); + +#define PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(N) \ + case N: \ + return arm64::sve2::sve2_decompress_blocks(input, scale, output, blocks); + +int sve2_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + default: + return -1; + } } -int sve2_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { - return -1; +int sve2_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + default: + return -1; + } } -int sve2_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { - return -1; +int sve2_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + default: + return -1; + } } -int sve2_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { - return -1; +int sve2_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output, const uint32_t blocks) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + default: + return -1; + } } + +#undef PERNIX_SVE2_DECOMPRESS_BLOCK_CASE +#undef PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE } -} // namespace pernix +} // namespace pernix diff --git a/tests/arm64/sve2/.gitkeep b/tests/arm64/sve2/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/tests/arm64/sve2/decompression_tests.cpp b/tests/arm64/sve2/decompression_tests.cpp new file mode 100644 index 0000000..82cb11f --- /dev/null +++ b/tests/arm64/sve2/decompression_tests.cpp @@ -0,0 +1,38 @@ +#include +#include + +#ifdef PERNIX_BACKEND_ARM64_SVE2 + +using namespace pernix::arm64::sve2; + +TYPED_TEST(DecompressionTest, SVE2DecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + sve2_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +TYPED_TEST(DecompressionTest64, SVE2DecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + sve2_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +#endif \ No newline at end of file From 0f0c1782b66c6687948714a6a53f6183b633d972 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Thu, 28 May 2026 21:32:10 +0200 Subject: [PATCH 11/14] WIP: Implement SVE2 decompression functions --- include/pernix/arm64/sve2/decompression.h | 122 +++++++++++++++++----- 1 file changed, 94 insertions(+), 28 deletions(-) diff --git a/include/pernix/arm64/sve2/decompression.h b/include/pernix/arm64/sve2/decompression.h index 198a8fc..2128ff2 100644 --- a/include/pernix/arm64/sve2/decompression.h +++ b/include/pernix/arm64/sve2/decompression.h @@ -122,6 +122,22 @@ __always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfl svst1_f32(pg, output, dequantized); } +template +__always_inline void sve2_store_dequantized_i32_f64(svint32_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcntw()); + + svst1_s32(svptrue_b32(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -238,8 +254,39 @@ __always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ inp template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__, double_t, double_t* __restrict__) { - return -1; +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntb()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint8_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint8_t spill_shift = svdup_n_u8(0); + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint8_t unpacked = sve2_unpack_epi8_1to8(source, permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i8_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; } template @@ -282,8 +329,35 @@ __always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ inpu template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__, double_t, double_t* __restrict__) { - return -1; +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntw()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint8_t* chunk_input = input + input_bit_offset / 8; + const uint32_t bit_offset = input_bit_offset % 8; + const uint32_t bytes = (bit_offset + count * BIT_WIDTH + 7u) / 8u; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + svint32_t unpacked; + if (bit_offset == 0) { + unpacked = sve2_unpack_epi32_17to24(source); + } else { + unpacked = sve2_unpack_epi32_17to24(source); + } + + sve2_store_dequantized_i32_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; } } // namespace internal @@ -314,39 +388,31 @@ int sve2_decompress_block(const uint8_t* __restrict__ input, const double_t scal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int sve2_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - if constexpr (BIT_WIDTH > 8) { - return -1; - } else { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - sve2_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + const uint8_t* block_input = input; + float_t* block_output = output; - return 0; + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int sve2_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - if constexpr (BIT_WIDTH > 8) { - return -1; - } else { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - sve2_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + const uint8_t* block_input = input; + double_t* block_output = output; - return 0; + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } + + return 0; } #ifdef __cplusplus From 95c6475d8092f96dd3417c376bcda8e03ba9e94e Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Thu, 28 May 2026 21:59:45 +0200 Subject: [PATCH 12/14] WIP: Extend SVE2 decompression functions --- include/pernix/arm64/sve2/tables.h | 115 +++++++++++------------------ src/arm64/sve2/decompression.cpp | 64 ++++++++++++++++ 2 files changed, 109 insertions(+), 70 deletions(-) diff --git a/include/pernix/arm64/sve2/tables.h b/include/pernix/arm64/sve2/tables.h index 813e042..897fa9b 100644 --- a/include/pernix/arm64/sve2/tables.h +++ b/include/pernix/arm64/sve2/tables.h @@ -4,7 +4,6 @@ #include #include -#include namespace pernix::arm64::sve2::internal { template @@ -24,39 +23,23 @@ struct table_unpacking { static constexpr uint8_t bit_width = BIT_WIDTH; static svuint8_t permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast((lane * BIT_WIDTH) / 8u); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svlsr_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 3); } static svuint8_t spill_permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast((lane * BIT_WIDTH) / 8u + 1u); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 1); } static svuint8_t shift() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast((lane * BIT_WIDTH) % 8u); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svand_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 7); } static svuint8_t spill_shift() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast(8u - ((lane * BIT_WIDTH) % 8u)); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svsub_u8_x(pg, svdup_n_u8(8), shift()); } }; @@ -66,48 +49,39 @@ struct table_unpacking { static constexpr uint8_t bit_width = BIT_WIDTH; static svuint8_t permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - const uint32_t element = lane / 2u; - const uint32_t byte = lane % 2u; - const uint32_t first = (element * BIT_WIDTH) / 8u; - - table[lane] = static_cast(first + byte); + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 1); + const svuint8_t byte = svand_n_u8_x(pg, lane, 1); + + svuint8_t first; + if constexpr (BIT_WIDTH == 16) { + first = svlsl_n_u8_x(pg, elem, 1); + } else { + constexpr uint8_t extra_bits = BIT_WIDTH - 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low = svlsr_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), 3); + first = svadd_u8_x(pg, elem, svadd_u8_x(pg, high, low)); } - return svld1_u8(svptrue_b8(), table.data()); + return svadd_u8_x(pg, first, byte); } static svuint8_t spill_permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - const uint32_t element = lane / 2u; - const uint32_t byte = lane % 2u; - const uint32_t first = (element * BIT_WIDTH) / 8u; - - table[lane] = static_cast(first + 2u + byte); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 2); } static svuint16_t shift() { - std::vector table(svcnth()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast((lane * BIT_WIDTH) % 8u); - } - - return svld1_u16(svptrue_b16(), table.data()); + const svbool_t pg = svptrue_b16(); + return svand_n_u16_x(pg, svmul_n_u16_x(pg, svindex_u16(0, 1), BIT_WIDTH), 7); } static svuint16_t spill_shift() { - std::vector table(svcnth()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - const uint32_t bit_offset = (lane * BIT_WIDTH) % 8u; - table[lane] = bit_offset + BIT_WIDTH > 16u ? static_cast(16u - bit_offset) : uint16_t{16}; - } - - return svld1_u16(svptrue_b16(), table.data()); + const svbool_t pg = svptrue_b16(); + const svuint16_t bit_shift = shift(); + const svuint16_t spill = svsub_u16_x(pg, svdup_n_u16(16), bit_shift); + return svsel_u16(svcmpgt_n_u16(pg, bit_shift, 16u - BIT_WIDTH), spill, svdup_n_u16(16)); } }; @@ -117,25 +91,26 @@ struct table_unpacking { static constexpr uint8_t bit_width = BIT_WIDTH; static svuint8_t permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - const uint32_t element = lane / 4u; - const uint32_t byte = lane % 4u; - const uint32_t first = (START_BIT_OFFSET + element * BIT_WIDTH) / 8u; - - table[lane] = static_cast(first + byte); + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 2); + const svuint8_t byte = svand_n_u8_x(pg, lane, 3); + + svuint8_t first = svmul_n_u8_x(pg, elem, BIT_WIDTH / 8u); + if constexpr (BIT_WIDTH % 8u != 0) { + constexpr uint8_t extra_bits = BIT_WIDTH % 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low_bits = + svadd_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), START_BIT_OFFSET); + first = svadd_u8_x(pg, first, svadd_u8_x(pg, high, svlsr_n_u8_x(pg, low_bits, 3))); } - return svld1_u8(svptrue_b8(), table.data()); + return svadd_u8_x(pg, first, byte); } static svuint32_t shift() { - std::vector table(svcntw()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = (START_BIT_OFFSET + lane * BIT_WIDTH) % 8u; - } - - return svld1_u32(svptrue_b32(), table.data()); + const svbool_t pg = svptrue_b32(); + return svand_n_u32_x(pg, svadd_n_u32_x(pg, svmul_n_u32_x(pg, svindex_u32(0, 1), BIT_WIDTH), START_BIT_OFFSET), 7); } }; } // namespace pernix::arm64::sve2::internal diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp index 66bbd97..8429e97 100644 --- a/src/arm64/sve2/decompression.cpp +++ b/src/arm64/sve2/decompression.cpp @@ -20,6 +20,22 @@ int sve2_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ i PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(24) default: return -1; } @@ -36,6 +52,22 @@ int sve2_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(24) default: return -1; } @@ -52,6 +84,22 @@ int sve2_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(24) default: return -1; } @@ -68,6 +116,22 @@ int sve2_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restric PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(24) default: return -1; } From 5523a3d4a6ad8bef90ecb04082294a76bda8c76e Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Sat, 13 Jun 2026 14:20:50 +0200 Subject: [PATCH 13/14] feat: add runtime backend dispatch with configurable block sizes --- CMakeLists.txt | 152 ++-- cmake/pernixConfig.cmake.in | 18 + include/pernix/arm64/neon/common.h | 200 ----- include/pernix/arm64/neon/unpacking.h | 100 --- include/pernix/arm64/sve/compression.h | 141 ---- include/pernix/arm64/sve/decompression.h | 141 ---- include/pernix/arm64/sve/packing.h | 9 - include/pernix/arm64/sve/unpacking.h | 9 - include/pernix/arm64/sve2/tables.h | 118 --- include/pernix/arm64/sve2/unpacking.h | 92 --- include/pernix/compat.h | 24 + include/pernix/detection.h | 87 --- include/pernix/fallback/decompression.h | 286 -------- include/pernix/pernix.h | 602 +-------------- include/pernix/pernix.hpp | 153 ++++ include/pernix/x86/avx512vbmi/compat.h | 387 ---------- include/pernix/x86/avx512vbmi/packing.h | 327 --------- include/pernix/x86/avx512vbmi/unpacking.h | 500 ------------- src/CMakeLists.txt | 210 +++++- src/arm64/neon/compression.cpp | 29 +- src/arm64/neon/decompression.cpp | 310 ++++---- src/arm64/sve/compression.cpp | 21 - src/arm64/sve/decompression.cpp | 21 - src/arm64/sve2/compression.cpp | 29 +- src/arm64/sve2/decompression.cpp | 310 ++++---- src/dispatch/cpu_features_arm.cpp | 23 + src/dispatch/cpu_features_x86.cpp | 95 +++ src/dispatch/select.cpp | 684 ++++++++++++++++++ src/fallback/compression.cpp | 149 ---- src/fallback/decompression.cpp | 150 ---- src/fallback/fallback_compression.cpp | 194 +++++ src/fallback/fallback_decompression.cpp | 201 +++++ src/internal/pernix/arm64/neon/common.h | 223 ++++++ .../internal}/pernix/arm64/neon/compression.h | 40 +- .../pernix/arm64/neon/decompression.h | 147 ++-- .../internal}/pernix/arm64/neon/packing.h | 0 .../internal}/pernix/arm64/neon/tables.h | 0 src/internal/pernix/arm64/neon/unpacking.h | 101 +++ .../internal}/pernix/arm64/sve2/compression.h | 35 +- .../pernix/arm64/sve2/decompression.h | 161 +++-- .../internal}/pernix/arm64/sve2/packing.h | 4 +- src/internal/pernix/arm64/sve2/tables.h | 119 +++ src/internal/pernix/arm64/sve2/unpacking.h | 94 +++ src/internal/pernix/dispatch/cpu_features.h | 26 + src/internal/pernix/dispatch/kernel.h | 27 + src/internal/pernix/dispatch/select.h | 159 ++++ .../pernix/fallback/avx2_compression.h | 89 +-- .../pernix/fallback/avx2_decompression.h | 248 +++++++ .../internal}/pernix/simd_compat.h | 11 +- .../pernix/x86/avx2/avx2_compression.h | 368 +++++----- .../pernix/x86/avx2/avx2_decompression.h | 98 +-- .../internal/pernix/x86/avx2/avx2_tables.h | 341 ++++----- .../x86/avx512vbmi/avx512vbmi_compression.h | 113 +-- .../x86/avx512vbmi/avx512vbmi_decompression.h | 112 +-- src/internal/pernix/x86/avx512vbmi/compat.h | 387 ++++++++++ src/internal/pernix/x86/avx512vbmi/packing.h | 331 +++++++++ .../internal}/pernix/x86/avx512vbmi/tables.h | 533 +++++++------- .../pernix/x86/avx512vbmi/unpacking.h | 500 +++++++++++++ .../pernix/x86/bmi2/bmi2_compression.h | 240 +++--- .../pernix/x86/bmi2/bmi2_decompression.h | 96 +-- {include => src/internal}/pernix/x86/utils.h | 0 src/pernix.cpp | 293 +++----- src/x86/avx2/avx2_compression.cpp | 193 +++++ src/x86/avx2/avx2_decompression.cpp | 201 +++++ src/x86/avx2/compression.cpp | 154 ---- src/x86/avx2/decompression.cpp | 153 ---- src/x86/avx512vbmi/avx512vbmi_compression.cpp | 193 +++++ .../avx512vbmi/avx512vbmi_decompression.cpp | 201 +++++ src/x86/avx512vbmi/compression.cpp | 153 ---- src/x86/avx512vbmi/decompression.cpp | 153 ---- src/x86/bmi2/bmi2_compression.cpp | 193 +++++ src/x86/bmi2/bmi2_decompression.cpp | 201 +++++ src/x86/bmi2/compression.cpp | 153 ---- src/x86/bmi2/decompression.cpp | 153 ---- tests/CMakeLists.txt | 60 +- tests/arm64/neon/decompression_tests.cpp | 38 - tests/arm64/sve/.gitkeep | 0 tests/arm64/sve2/decompression_tests.cpp | 38 - tests/fallback/compression_tests.cpp | 34 - tests/fallback/decompression_tests.cpp | 32 - tests/fallback/edge_tests.cpp | 44 -- tests/fallback_tests.cpp | 316 ++++++++ tests/include/testset.h | 34 +- tests/simd_tests.cpp | 188 +++++ tests/x86/avx2/compression_tests.cpp | 46 -- tests/x86/avx2/decompression_tests.cpp | 36 - tests/x86/avx512vbmi/compression_tests.cpp | 46 -- tests/x86/avx512vbmi/decompression_tests.cpp | 36 - tests/x86/bmi2/compression_tests.cpp | 46 -- tests/x86/bmi2/decompression_tests.cpp | 36 - 90 files changed, 7158 insertions(+), 6641 deletions(-) create mode 100644 cmake/pernixConfig.cmake.in delete mode 100644 include/pernix/arm64/neon/common.h delete mode 100644 include/pernix/arm64/neon/unpacking.h delete mode 100644 include/pernix/arm64/sve/compression.h delete mode 100644 include/pernix/arm64/sve/decompression.h delete mode 100644 include/pernix/arm64/sve/packing.h delete mode 100644 include/pernix/arm64/sve/unpacking.h delete mode 100644 include/pernix/arm64/sve2/tables.h delete mode 100644 include/pernix/arm64/sve2/unpacking.h create mode 100644 include/pernix/compat.h delete mode 100644 include/pernix/detection.h delete mode 100644 include/pernix/fallback/decompression.h create mode 100644 include/pernix/pernix.hpp delete mode 100644 include/pernix/x86/avx512vbmi/compat.h delete mode 100644 include/pernix/x86/avx512vbmi/packing.h delete mode 100644 include/pernix/x86/avx512vbmi/unpacking.h delete mode 100644 src/arm64/sve/compression.cpp delete mode 100644 src/arm64/sve/decompression.cpp create mode 100644 src/dispatch/cpu_features_arm.cpp create mode 100644 src/dispatch/cpu_features_x86.cpp create mode 100644 src/dispatch/select.cpp delete mode 100644 src/fallback/compression.cpp delete mode 100644 src/fallback/decompression.cpp create mode 100644 src/fallback/fallback_compression.cpp create mode 100644 src/fallback/fallback_decompression.cpp create mode 100644 src/internal/pernix/arm64/neon/common.h rename {include => src/internal}/pernix/arm64/neon/compression.h (79%) rename {include => src/internal}/pernix/arm64/neon/decompression.h (77%) rename {include => src/internal}/pernix/arm64/neon/packing.h (100%) rename {include => src/internal}/pernix/arm64/neon/tables.h (100%) create mode 100644 src/internal/pernix/arm64/neon/unpacking.h rename {include => src/internal}/pernix/arm64/sve2/compression.h (52%) rename {include => src/internal}/pernix/arm64/sve2/decompression.h (76%) rename {include => src/internal}/pernix/arm64/sve2/packing.h (74%) create mode 100644 src/internal/pernix/arm64/sve2/tables.h create mode 100644 src/internal/pernix/arm64/sve2/unpacking.h create mode 100644 src/internal/pernix/dispatch/cpu_features.h create mode 100644 src/internal/pernix/dispatch/kernel.h create mode 100644 src/internal/pernix/dispatch/select.h rename include/pernix/fallback/compression.h => src/internal/pernix/fallback/avx2_compression.h (73%) create mode 100644 src/internal/pernix/fallback/avx2_decompression.h rename {include => src/internal}/pernix/simd_compat.h (82%) rename include/pernix/x86/avx2/compression.h => src/internal/pernix/x86/avx2/avx2_compression.h (67%) rename include/pernix/x86/avx2/decompression.h => src/internal/pernix/x86/avx2/avx2_decompression.h (79%) rename include/pernix/x86/avx2/tables.h => src/internal/pernix/x86/avx2/avx2_tables.h (68%) rename include/pernix/x86/avx512vbmi/compression.h => src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h (88%) rename include/pernix/x86/avx512vbmi/decompression.h => src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h (87%) create mode 100644 src/internal/pernix/x86/avx512vbmi/compat.h create mode 100644 src/internal/pernix/x86/avx512vbmi/packing.h rename {include => src/internal}/pernix/x86/avx512vbmi/tables.h (51%) create mode 100644 src/internal/pernix/x86/avx512vbmi/unpacking.h rename include/pernix/x86/bmi2/compression.h => src/internal/pernix/x86/bmi2/bmi2_compression.h (56%) rename include/pernix/x86/bmi2/decompression.h => src/internal/pernix/x86/bmi2/bmi2_decompression.h (80%) rename {include => src/internal}/pernix/x86/utils.h (100%) create mode 100644 src/x86/avx2/avx2_compression.cpp create mode 100644 src/x86/avx2/avx2_decompression.cpp delete mode 100644 src/x86/avx2/compression.cpp delete mode 100644 src/x86/avx2/decompression.cpp create mode 100644 src/x86/avx512vbmi/avx512vbmi_compression.cpp create mode 100644 src/x86/avx512vbmi/avx512vbmi_decompression.cpp delete mode 100644 src/x86/avx512vbmi/compression.cpp delete mode 100644 src/x86/avx512vbmi/decompression.cpp create mode 100644 src/x86/bmi2/bmi2_compression.cpp create mode 100644 src/x86/bmi2/bmi2_decompression.cpp delete mode 100644 src/x86/bmi2/compression.cpp delete mode 100644 src/x86/bmi2/decompression.cpp delete mode 100644 tests/arm64/neon/decompression_tests.cpp delete mode 100644 tests/arm64/sve/.gitkeep delete mode 100644 tests/arm64/sve2/decompression_tests.cpp delete mode 100644 tests/fallback/compression_tests.cpp delete mode 100644 tests/fallback/decompression_tests.cpp delete mode 100644 tests/fallback/edge_tests.cpp create mode 100644 tests/fallback_tests.cpp create mode 100644 tests/simd_tests.cpp delete mode 100644 tests/x86/avx2/compression_tests.cpp delete mode 100644 tests/x86/avx2/decompression_tests.cpp delete mode 100644 tests/x86/avx512vbmi/compression_tests.cpp delete mode 100644 tests/x86/avx512vbmi/decompression_tests.cpp delete mode 100644 tests/x86/bmi2/compression_tests.cpp delete mode 100644 tests/x86/bmi2/decompression_tests.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 1cbead6..c6af942 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.30) +cmake_minimum_required(VERSION 3.24) project(pernix VERSION 0.1.0 LANGUAGES CXX C) option(PERNIX_ENABLE_TESTS "Enable tests for pernix" on) @@ -8,44 +8,19 @@ option(PERNIX_ENABLE_INSTALL "Enable installation of pernix" on) option(PERNIX_ENABLE_DOXYGEN "Build documentation with Doxygen." off) option(PERNIX_INSTALL_DOCS "Install documentation with pernix" off) -option(PERNIX_DISABLE_BMI2 "Disable BMI2 optimizations" off) -option(PERNIX_DISABLE_AVX2 "Disable AVX2 optimizations" off) -option(PERNIX_DISABLE_AVX512 "Disable AVX512 optimizations" off) +option(PERNIX_ENABLE_X86_BMI2 "Build x86 BMI2 backend" ON) +option(PERNIX_ENABLE_X86_AVX2 "Build x86 AVX2 backend" ON) +option(PERNIX_ENABLE_X86_AVX512VBMI "Build x86 AVX512-VBMI backend" ON) +option(PERNIX_ENABLE_ARM64_NEON "Build arm64 NEON backend" ON) +option(PERNIX_ENABLE_ARM64_SVE2 "Build arm64 SVE2 backend" ON) option(PERNIX_USE_SIMDE "Use SIMDe library for portable SIMD support" off) -set(PERNIX_SIMDE_PROVIDER "AUTO" CACHE STRING "SIMDe provider when PERNIX_USE_SIMDE is enabled (AUTO, PACKAGE, FETCH)") -set_property(CACHE PERNIX_SIMDE_PROVIDER PROPERTY STRINGS AUTO PACKAGE FETCH) -set(PERNIX_ARCH_BACKEND "AUTO" CACHE STRING "Pernix architecture backend (AUTO, FALLBACK, X86, ARM64_NEON, ARM64_SVE, ARM64_SVE2)") -set_property(CACHE PERNIX_ARCH_BACKEND PROPERTY STRINGS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) option(PERNIX_ENABLE_FORTRAN_BINDINGS "Build Fortran bindings for pernix" off) -string(TOUPPER "${PERNIX_ARCH_BACKEND}" PERNIX_ARCH_BACKEND) -set(PERNIX_VALID_ARCH_BACKENDS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) -if (NOT PERNIX_ARCH_BACKEND IN_LIST PERNIX_VALID_ARCH_BACKENDS) - message(FATAL_ERROR "Unsupported PERNIX_ARCH_BACKEND='${PERNIX_ARCH_BACKEND}'. Expected one of: ${PERNIX_VALID_ARCH_BACKENDS}") -endif () - -set(PERNIX_SELECTED_ARCH_BACKEND "${PERNIX_ARCH_BACKEND}") -if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "AUTO") - if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") - set(PERNIX_SELECTED_ARCH_BACKEND "X86") - elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") - set(PERNIX_SELECTED_ARCH_BACKEND "ARM64_NEON") - else () - set(PERNIX_SELECTED_ARCH_BACKEND "FALLBACK") - endif () -endif () -message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") - -string(TOUPPER "${PERNIX_SIMDE_PROVIDER}" PERNIX_SIMDE_PROVIDER) -set(PERNIX_VALID_SIMDE_PROVIDERS AUTO PACKAGE FETCH) -if (NOT PERNIX_SIMDE_PROVIDER IN_LIST PERNIX_VALID_SIMDE_PROVIDERS) - message(FATAL_ERROR "Unsupported PERNIX_SIMDE_PROVIDER='${PERNIX_SIMDE_PROVIDER}'. Expected one of: ${PERNIX_VALID_SIMDE_PROVIDERS}") -endif () - list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") +#[===[ set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL OFF) if (PERNIX_USE_SIMDE) if (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "PACKAGE") @@ -75,6 +50,7 @@ if (PERNIX_USE_SIMDE) message(FATAL_ERROR "PERNIX_USE_SIMDE is enabled, but simde::simde was not found. Set PERNIX_SIMDE_PROVIDER=FETCH or install SIMDe's CMake package.") endif () endif () +]===] include(CTest) include(GitVersion) @@ -94,64 +70,74 @@ else () endif () message(STATUS "Pernix version: ${VERSION}, normalized to ${NORMALIZED_VERSION}") -if (MSVC) - message(FATAL_ERROR "MSVC compiler is not supported") -else () - include(CheckCXXCompilerFlag) - set(PERNIX_PRIVATE_COMPILE_OPTIONS) - foreach (PERNIX_CXX_FLAG - -Wall - -Wextra - -Wshadow - -Wfloat-equal - -Wold-style-cast - -Wconversion - -fstrict-aliasing - -Wno-ignored-attributes - ) - string(MAKE_C_IDENTIFIER "PERNIX_HAS_CXX_FLAG_${PERNIX_CXX_FLAG}" PERNIX_CXX_FLAG_VARIABLE) - check_cxx_compiler_flag("${PERNIX_CXX_FLAG}" "${PERNIX_CXX_FLAG_VARIABLE}") - if (${PERNIX_CXX_FLAG_VARIABLE}) - list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "${PERNIX_CXX_FLAG}") - else () - message(STATUS "Compiler flag not supported: ${PERNIX_CXX_FLAG}") - endif () - endforeach () +if (NOT CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|Intel") + message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") +endif () - if (PERNIX_ENABLE_LTO) - include(CheckIPOSupported) - check_ipo_supported(RESULT PERNIX_IPO_SUPPORTED OUTPUT PERNIX_IPO_ERROR) - if (NOT PERNIX_IPO_SUPPORTED) - message(FATAL_ERROR "PERNIX_ENABLE_LTO is enabled, but IPO/LTO is not supported: ${PERNIX_IPO_ERROR}") - endif () +include(CheckCXXCompilerFlag) +set(PERNIX_PRIVATE_COMPILE_OPTIONS) +foreach (PERNIX_CXX_FLAG + -Wall + -Wextra + -Wshadow + -Wfloat-equal + -Wold-style-cast + -Wconversion + -fstrict-aliasing + -Wno-ignored-attributes +) + string(MAKE_C_IDENTIFIER "PERNIX_HAS_CXX_FLAG_${PERNIX_CXX_FLAG}" PERNIX_CXX_FLAG_VARIABLE) + check_cxx_compiler_flag("${PERNIX_CXX_FLAG}" "${PERNIX_CXX_FLAG_VARIABLE}") + if (${PERNIX_CXX_FLAG_VARIABLE}) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "${PERNIX_CXX_FLAG}") + else () + message(STATUS "Compiler flag not supported: ${PERNIX_CXX_FLAG}") + endif () +endforeach () - check_cxx_compiler_flag("-Wno-lto-type-mismatch" PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) - if (PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) - list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "-Wno-lto-type-mismatch") - endif () +if (PERNIX_ENABLE_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT PERNIX_IPO_SUPPORTED OUTPUT PERNIX_IPO_ERROR) + if (NOT PERNIX_IPO_SUPPORTED) + message(FATAL_ERROR "PERNIX_ENABLE_LTO is enabled, but IPO/LTO is not supported: ${PERNIX_IPO_ERROR}") + endif () - if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - find_program(GCC_AR gcc-ar) - if (GCC_AR) - set(CMAKE_AR ${GCC_AR}) - endif () - find_program(GCC_RANLIB gcc-ranlib) - if (GCC_RANLIB) - set(CMAKE_RANLIB ${GCC_RANLIB}) - endif () - elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") - find_program(LLVM_AR llvm-ar) - if (LLVM_AR) - set(CMAKE_AR ${LLVM_AR}) - endif () - find_program(LLVM_RANLIB llvm-ranlib) - if (LLVM_RANLIB) - set(CMAKE_RANLIB ${LLVM_RANLIB}) - endif () + check_cxx_compiler_flag("-Wno-lto-type-mismatch" PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + if (PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "-Wno-lto-type-mismatch") + endif () + + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + find_program(GCC_AR gcc-ar) + if (GCC_AR) + set(CMAKE_AR ${GCC_AR}) + endif () + find_program(GCC_RANLIB gcc-ranlib) + if (GCC_RANLIB) + set(CMAKE_RANLIB ${GCC_RANLIB}) + endif () + elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + find_program(LLVM_AR llvm-ar) + if (LLVM_AR) + set(CMAKE_AR ${LLVM_AR}) + endif () + find_program(LLVM_RANLIB llvm-ranlib) + if (LLVM_RANLIB) + set(CMAKE_RANLIB ${LLVM_RANLIB}) endif () endif () endif () +set(PERNIX_TARGET_IS_X86 OFF) +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") + set(PERNIX_TARGET_IS_X86 ON) +endif () + +set(PERNIX_TARGET_IS_ARM64 OFF) +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(PERNIX_TARGET_IS_ARM64 ON) +endif () + add_subdirectory(src) if (PERNIX_ENABLE_FORTRAN_BINDINGS) diff --git a/cmake/pernixConfig.cmake.in b/cmake/pernixConfig.cmake.in new file mode 100644 index 0000000..fe3eb88 --- /dev/null +++ b/cmake/pernixConfig.cmake.in @@ -0,0 +1,18 @@ +@PACKAGE_INIT@ + +if (@PERNIX_USE_SIMDE@) + find_package(simde CONFIG QUIET) + if (NOT TARGET simde::simde AND @PERNIX_BUNDLE_SIMDE_FOR_INSTALL@) + add_library(simde::simde INTERFACE IMPORTED) + target_include_directories(simde::simde INTERFACE "${PACKAGE_PREFIX_DIR}/include") + endif () + if (NOT TARGET simde::simde) + set(pernix_FOUND FALSE) + set(pernix_NOT_FOUND_MESSAGE "pernix was built with SIMDe support, but simde::simde was not found. Install SIMDe's CMake package or use a Pernix package built with bundled SIMDe headers.") + return() + endif () +endif () + +include("${CMAKE_CURRENT_LIST_DIR}/pernixTargets.cmake") + +check_required_components(pernix) diff --git a/include/pernix/arm64/neon/common.h b/include/pernix/arm64/neon/common.h deleted file mode 100644 index 8e517fa..0000000 --- a/include/pernix/arm64/neon/common.h +++ /dev/null @@ -1,200 +0,0 @@ -#ifndef PERNIX_ARM64_NEON_COMMON_H -#define PERNIX_ARM64_NEON_COMMON_H - -#include - -#include - -namespace pernix::arm64::neon::internal { -struct float64x2x8_t { - float64x2_t val[8]; -}; - -static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - return tail_bytes; -} - -__always_inline int32x4x4_t neon_convert_int8x16_int32x4x4(const int8x16_t& input) { - const int16x8_t s16_lo = vmovl_s8(vget_low_s8(input)); - const int16x8_t s16_hi = vmovl_s8(vget_high_s8(input)); - - return {{ - vmovl_s16(vget_low_s16(s16_lo)), - vmovl_s16(vget_high_s16(s16_lo)), - vmovl_s16(vget_low_s16(s16_hi)), - vmovl_s16(vget_high_s16(s16_hi)), - }}; -} - -__always_inline int32x4x2_t neon_convert_int16x8_int32x4x2(const int16x8_t& input) { - return {{ - vmovl_s16(vget_low_s16(input)), - vmovl_s16(vget_high_s16(input)), - }}; -} - -__always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t& input, const float32x4_t& scale) { - return {{ - vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), - vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), - vmulq_f32(vcvtq_f32_s32(input.val[2]), scale), - vmulq_f32(vcvtq_f32_s32(input.val[3]), scale), - }}; -} - -__always_inline float32x4x2_t neon_dequantize_epi32(const int32x4x2_t& input, const float32x4_t& scale) { - return {{ - vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), - vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), - }}; -} - -__always_inline float32x4_t neon_dequantize_epi32(const int32x4_t& input, const float32x4_t& scale) { - return vmulq_f32(vcvtq_f32_s32(input), scale); -} - -__always_inline float64x2_t neon_dequantize_epi32_f64(const int32x2_t& input, const float64x2_t& scale) { - return vmulq_f64(vcvtq_f64_s64(vmovl_s32(input)), scale); -} - -__always_inline float64x2x2_t neon_dequantize_epi32_f64(const int32x4_t& input, const float64x2_t& scale) { - return {{ - neon_dequantize_epi32_f64(vget_low_s32(input), scale), - neon_dequantize_epi32_f64(vget_high_s32(input), scale), - }}; -} - -__always_inline float64x2x4_t neon_dequantize_epi32_f64(const int32x4x2_t& input, const float64x2_t& scale) { - const float64x2x2_t dequantized_low = neon_dequantize_epi32_f64(input.val[0], scale); - const float64x2x2_t dequantized_high = neon_dequantize_epi32_f64(input.val[1], scale); - - return {{ - dequantized_low.val[0], - dequantized_low.val[1], - dequantized_high.val[0], - dequantized_high.val[1], - }}; -} - -__always_inline float64x2x8_t neon_dequantize_epi32_f64(const int32x4x4_t& input, const float64x2_t& scale) { - const float64x2x2_t dequantized0 = neon_dequantize_epi32_f64(input.val[0], scale); - const float64x2x2_t dequantized1 = neon_dequantize_epi32_f64(input.val[1], scale); - const float64x2x2_t dequantized2 = neon_dequantize_epi32_f64(input.val[2], scale); - const float64x2x2_t dequantized3 = neon_dequantize_epi32_f64(input.val[3], scale); - - return {{ - dequantized0.val[0], - dequantized0.val[1], - dequantized1.val[0], - dequantized1.val[1], - dequantized2.val[0], - dequantized2.val[1], - dequantized3.val[0], - dequantized3.val[1], - }}; -} - -__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_bytes_count) { - uint8_t buffer[16] = {0}; - std::memcpy(buffer, input, tail_bytes_count); - return vld1q_u8(buffer); -} - -__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t* input, const uint32_t tail_bytes_count) { - uint16_t buffer[8] = {0}; - std::memcpy(buffer, input, tail_bytes_count); - return vld1q_u16(buffer); -} - -__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t* input, const uint32_t tail_bytes_count) { - uint32_t buffer[4] = {0}; - std::memcpy(buffer, input, tail_bytes_count); - return vld1q_u32(buffer); -} - -__always_inline float32x4_t neon_load_tail_elements_f32(const uint8_t* input, const uint32_t tail_elements) { - float32_t buffer[4] = {0.0f}; - std::memcpy(buffer, input, tail_elements * sizeof(float32_t)); - return vld1q_f32(buffer); -} - -__always_inline float64x2_t neon_load_tail_elements_f64(const uint8_t* input, const uint32_t tail_elements) { - float64_t buffer[2] = {0.0}; - std::memcpy(buffer, input, tail_elements * sizeof(float64_t)); - return vld1q_f64(buffer); -} - -__always_inline void neon_store_tail_elements_int8(uint8_t* output, const uint8x16x4_t& data, const uint32_t tail_elements) { - uint8_t buffer[16 * 4]; - for (uint32_t i = 0; i < 4; ++i) { - vst1q_u8(buffer + i * 16, data.val[i]); - } - std::memcpy(output, buffer, tail_elements * sizeof(uint8_t)); -} - -__always_inline void neon_store_tail_elements_int16(uint16_t* output, const uint16x8x4_t& data, const uint32_t tail_elements) { - uint16_t buffer[8 * 4]; - for (uint32_t i = 0; i < 4; ++i) { - vst1q_u16(buffer + i * 8, data.val[i]); - } - std::memcpy(output, buffer, tail_elements * sizeof(uint16_t)); -} - -__always_inline void neon_store_tail_elements_int32(uint32_t* output, const uint32x4x4_t& data, const uint32_t tail_elements) { - uint32_t buffer[4 * 4]; - for (uint32_t i = 0; i < 4; ++i) { - vst1q_u32(buffer + i * 4, data.val[i]); - } - std::memcpy(output, buffer, tail_elements * sizeof(uint32_t)); -} - -__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4x4_t& data, const uint32_t tail_elements) { - float32_t buffer[16 * 4]; - for (uint32_t i = 0; i < 4; ++i) { - vst1q_f32(buffer + i * 4, data.val[i]); - } - std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); -} - -__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4x2_t& data, const uint32_t tail_elements) { - float32_t buffer[8 * 2]; - for (uint32_t i = 0; i < 2; ++i) { - vst1q_f32(buffer + i * 4, data.val[i]); - } - std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); -} - -__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4_t& data, const uint32_t tail_elements) { - float32_t buffer[4]; - vst1q_f32(buffer, data); - std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); -} - -__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x4_t& data, const uint32_t tail_elements) { - float64_t buffer[2 * 4]; - for (uint32_t i = 0; i < 4; ++i) { - vst1q_f64(buffer + i * 2, data.val[i]); - } - std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); -} - -__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x2_t& data, const uint32_t tail_elements) { - float64_t buffer[2 * 2]; - for (uint32_t i = 0; i < 2; ++i) { - vst1q_f64(buffer + i * 2, data.val[i]); - } - std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); -} - -__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x8_t& data, const uint32_t tail_elements) { - float64_t buffer[2 * 8]; - for (uint32_t i = 0; i < 8; ++i) { - vst1q_f64(buffer + i * 2, data.val[i]); - } - std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); -} -} // namespace pernix::arm64::neon::internal - -#endif // PERNIX_ARM64_NEON_COMMON_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h deleted file mode 100644 index 6ac0e20..0000000 --- a/include/pernix/arm64/neon/unpacking.h +++ /dev/null @@ -1,100 +0,0 @@ -#ifndef PERNIX_ARM64_NEON_UNPACKING_H -#define PERNIX_ARM64_NEON_UNPACKING_H - -#include -#include - -using namespace pernix::arm64::neon::internal; - -namespace pernix::arm64::neon::internal::b128 { -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t& input) { - if constexpr (BIT_WIDTH == 8) { - return vreinterpretq_s8_u8(input); - } else if constexpr (BIT_WIDTH == 1) { - using tables = table_unpacking; - - const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); - const uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); - - return vreinterpretq_s8_u8(vandq_u8(shifted, vdupq_n_u8(1))); - } else { - using tables = table_unpacking; - - const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); - - uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); - - if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { - const uint8x16_t permuted2_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute2.data())); - - shifted = vorrq_u8(shifted, vshlq_u8(permuted2_u8, vld1q_s8(tables::shift2.data()))); - } - - constexpr int shift = 8 - BIT_WIDTH; - shifted = vshlq_n_u8(shifted, shift); - - if constexpr (SIGN_VALUES) { - return vshlq_s8(vreinterpretq_s8_u8(shifted), vdupq_n_s8(-shift)); - } else { - return vreinterpretq_s8_u8(vshlq_u8(shifted, vdupq_n_s8(-shift))); - } - } -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline int16x8_t neon_unpack_epi16_9to16(const uint16x8_t& input) { - if constexpr (BIT_WIDTH == 16) { - return vreinterpretq_s16_u16(input); - } else { - using tables = table_unpacking; - - const uint8x16_t input_u8 = vreinterpretq_u8_u16(input); - - const uint8x16_t permuted1_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute1.data())); - - uint16x8_t shifted = vshlq_u16(vreinterpretq_u16_u8(permuted1_u8), vld1q_s16(tables::shift1.data())); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const uint8x16_t permuted2_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute2.data())); - - const uint16x8_t shifted2 = vshlq_u16(vreinterpretq_u16_u8(permuted2_u8), vld1q_s16(tables::shift2.data())); - - shifted = vorrq_u16(shifted, shifted2); - } - - constexpr int shift = 16 - BIT_WIDTH; - shifted = vshlq_n_u16(shifted, shift); - - if constexpr (SIGN_VALUES) { - return vshlq_s16(vreinterpretq_s16_u16(shifted), vdupq_n_s16(-shift)); - } else { - return vreinterpretq_s16_u16(vshlq_u16(shifted, vdupq_n_s16(-shift))); - } - } -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline int32x4_t neon_unpack_epi32_17to24(const uint32x4_t& input) { - using tables = table_unpacking; - - const uint8x16_t input_8 = vreinterpretq_u8_u32(input); - - const uint8x16_t permuted_u8 = vqtbl1q_u8(input_8, vld1q_u8(tables::permute.data())); - - const uint32x4_t value = vshlq_u32(vreinterpretq_u32_u8(permuted_u8), vld1q_s32(tables::shift.data())); - - if constexpr (SIGN_VALUES) { - constexpr int sign_shift = 32 - BIT_WIDTH; - return vshrq_n_s32(vreinterpretq_s32_u32(vshlq_n_u32(value, sign_shift)), sign_shift); - } else { - constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; - return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); - } -} -} // namespace pernix::arm64::neon::internal::b128 - -#endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve/compression.h b/include/pernix/arm64/sve/compression.h deleted file mode 100644 index cf83ce0..0000000 --- a/include/pernix/arm64/sve/compression.h +++ /dev/null @@ -1,141 +0,0 @@ -#ifndef PERNIX_ARM64_SVE_COMPRESSION_H -#define PERNIX_ARM64_SVE_COMPRESSION_H - -#include -#include - -#include -#include - -namespace pernix::arm64::sve { -namespace internal { -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} -} // namespace internal - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_compress_block(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::sve_compress_block_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::sve_compress_block_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::sve_compress_block_17to24(input, scale, output); - } - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_compress_block(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::sve_compress_block_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::sve_compress_block_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::sve_compress_block_17to24(input, scale, output); - } - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - sve_compress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - sve_compress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - return 0; -} - -#ifdef __cplusplus -extern "C" { -#endif - -int sve_compress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - - -int sve_compress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output); - -int sve_compress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -int sve_compress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output, uint32_t blocks); - -#ifdef __cplusplus -} -#endif -} // namespace pernix::arm64::sve - -#endif // PERNIX_ARM64_SVE_COMPRESSION_H diff --git a/include/pernix/arm64/sve/decompression.h b/include/pernix/arm64/sve/decompression.h deleted file mode 100644 index 052a3e4..0000000 --- a/include/pernix/arm64/sve/decompression.h +++ /dev/null @@ -1,141 +0,0 @@ -#ifndef PERNIX_ARM64_SVE_DECOMPRESSION_H -#define PERNIX_ARM64_SVE_DECOMPRESSION_H - -#include -#include - -#include -#include - -namespace pernix::arm64::sve { -namespace internal { -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} -} // namespace internal - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_decompress_block(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::sve_decompress_block_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::sve_decompress_block_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::sve_decompress_block_17to24(input, scale, output); - } - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve_decompress_block(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::sve_decompress_block_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::sve_decompress_block_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::sve_decompress_block_17to24(input, scale, output); - } - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - sve_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - sve_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - return 0; -} - -#ifdef __cplusplus -extern "C" { -#endif - -int sve_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - - -int sve_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output); - -int sve_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -int sve_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output, uint32_t blocks); - -#ifdef __cplusplus -} -#endif -} // namespace pernix::arm64::sve - -#endif // PERNIX_ARM64_SVE_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve/packing.h b/include/pernix/arm64/sve/packing.h deleted file mode 100644 index ab57b4f..0000000 --- a/include/pernix/arm64/sve/packing.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef PERNIX_ARM64_SVE_PACKING_H -#define PERNIX_ARM64_SVE_PACKING_H - -#include - -namespace pernix::arm64::sve::internal { -} // namespace pernix::arm64::sve::internal - -#endif // PERNIX_ARM64_SVE_PACKING_H diff --git a/include/pernix/arm64/sve/unpacking.h b/include/pernix/arm64/sve/unpacking.h deleted file mode 100644 index 2565ab7..0000000 --- a/include/pernix/arm64/sve/unpacking.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef PERNIX_ARM64_SVE_UNPACKING_H -#define PERNIX_ARM64_SVE_UNPACKING_H - -#include - -namespace pernix::arm64::sve::internal { -} // namespace pernix::arm64::sve::internal - -#endif // PERNIX_ARM64_SVE_UNPACKING_H diff --git a/include/pernix/arm64/sve2/tables.h b/include/pernix/arm64/sve2/tables.h deleted file mode 100644 index 897fa9b..0000000 --- a/include/pernix/arm64/sve2/tables.h +++ /dev/null @@ -1,118 +0,0 @@ -#ifndef PERNIX_ARM64_SVE2_TABLES_H -#define PERNIX_ARM64_SVE2_TABLES_H - -#include - -#include - -namespace pernix::arm64::sve2::internal { -template -struct table_unpacking { - static constexpr uint8_t bit_width = BIT_WIDTH; - - static svbool_t pg_b8() { return svptrue_b8(); } - - static svbool_t pg_b16() { return svptrue_b16(); } - - static svbool_t pg_b32() { return svptrue_b32(); } -}; - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -struct table_unpacking { - static constexpr uint8_t bit_width = BIT_WIDTH; - - static svuint8_t permute() { - const svbool_t pg = svptrue_b8(); - return svlsr_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 3); - } - - static svuint8_t spill_permute() { - const svbool_t pg = svptrue_b8(); - return svadd_n_u8_x(pg, permute(), 1); - } - - static svuint8_t shift() { - const svbool_t pg = svptrue_b8(); - return svand_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 7); - } - - static svuint8_t spill_shift() { - const svbool_t pg = svptrue_b8(); - return svsub_u8_x(pg, svdup_n_u8(8), shift()); - } -}; - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -struct table_unpacking { - static constexpr uint8_t bit_width = BIT_WIDTH; - - static svuint8_t permute() { - const svbool_t pg = svptrue_b8(); - const svuint8_t lane = svindex_u8(0, 1); - const svuint8_t elem = svlsr_n_u8_x(pg, lane, 1); - const svuint8_t byte = svand_n_u8_x(pg, lane, 1); - - svuint8_t first; - if constexpr (BIT_WIDTH == 16) { - first = svlsl_n_u8_x(pg, elem, 1); - } else { - constexpr uint8_t extra_bits = BIT_WIDTH - 8u; - const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); - const svuint8_t low = svlsr_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), 3); - first = svadd_u8_x(pg, elem, svadd_u8_x(pg, high, low)); - } - - return svadd_u8_x(pg, first, byte); - } - - static svuint8_t spill_permute() { - const svbool_t pg = svptrue_b8(); - return svadd_n_u8_x(pg, permute(), 2); - } - - static svuint16_t shift() { - const svbool_t pg = svptrue_b16(); - return svand_n_u16_x(pg, svmul_n_u16_x(pg, svindex_u16(0, 1), BIT_WIDTH), 7); - } - - static svuint16_t spill_shift() { - const svbool_t pg = svptrue_b16(); - const svuint16_t bit_shift = shift(); - const svuint16_t spill = svsub_u16_x(pg, svdup_n_u16(16), bit_shift); - return svsel_u16(svcmpgt_n_u16(pg, bit_shift, 16u - BIT_WIDTH), spill, svdup_n_u16(16)); - } -}; - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && START_BIT_OFFSET < 8) -struct table_unpacking { - static constexpr uint8_t bit_width = BIT_WIDTH; - - static svuint8_t permute() { - const svbool_t pg = svptrue_b8(); - const svuint8_t lane = svindex_u8(0, 1); - const svuint8_t elem = svlsr_n_u8_x(pg, lane, 2); - const svuint8_t byte = svand_n_u8_x(pg, lane, 3); - - svuint8_t first = svmul_n_u8_x(pg, elem, BIT_WIDTH / 8u); - if constexpr (BIT_WIDTH % 8u != 0) { - constexpr uint8_t extra_bits = BIT_WIDTH % 8u; - const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); - const svuint8_t low_bits = - svadd_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), START_BIT_OFFSET); - first = svadd_u8_x(pg, first, svadd_u8_x(pg, high, svlsr_n_u8_x(pg, low_bits, 3))); - } - - return svadd_u8_x(pg, first, byte); - } - - static svuint32_t shift() { - const svbool_t pg = svptrue_b32(); - return svand_n_u32_x(pg, svadd_n_u32_x(pg, svmul_n_u32_x(pg, svindex_u32(0, 1), BIT_WIDTH), START_BIT_OFFSET), 7); - } -}; -} // namespace pernix::arm64::sve2::internal - -#endif // PERNIX_ARM64_SVE2_TABLES_H diff --git a/include/pernix/arm64/sve2/unpacking.h b/include/pernix/arm64/sve2/unpacking.h deleted file mode 100644 index 326901f..0000000 --- a/include/pernix/arm64/sve2/unpacking.h +++ /dev/null @@ -1,92 +0,0 @@ -#ifndef PERNIX_ARM64_SVE2_UNPACKING_H -#define PERNIX_ARM64_SVE2_UNPACKING_H - -#include - -#include "tables.h" - -namespace pernix::arm64::sve2::internal { -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline svint8_t sve2_unpack_epi8_1to8(const svuint8_t input, const svuint8_t permute, const svuint8_t shift, - const svuint8_t spill_permute, const svuint8_t spill_shift) { - if constexpr (BIT_WIDTH == 8) { - return svreinterpret_s8(input); - } else { - const svbool_t pg = svptrue_b8(); - - const svuint8_t permuted = svtbl_u8(input, permute); - svuint8_t unpacked = svlsr_u8_x(pg, permuted, shift); - - if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { - const svuint8_t spill_permuted_values = svtbl_u8(input, spill_permute); - const svuint8_t spill_shifted = svlsl_u8_x(pg, spill_permuted_values, spill_shift); - unpacked = svorr_u8_x(pg, unpacked, spill_shifted); - } - - if constexpr (BIT_WIDTH == 1) { - unpacked = svand_n_u8_x(pg, unpacked, 1); - return svreinterpret_s8(unpacked); - } else { - constexpr int sign_shift = 8 - BIT_WIDTH; - - unpacked = svlsl_n_u8_x(pg, unpacked, sign_shift); - - if constexpr (SIGN_VALUES) { - return svasr_n_s8_x(pg, svreinterpret_s8_u8(unpacked), sign_shift); - } else { - return svreinterpret_s8_u8(svlsr_n_u8_x(pg, unpacked, sign_shift)); - } - } - } -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline svint16_t sve2_unpack_epi16_9to16(const svuint16_t input, const svuint8_t permute, const svuint16_t shift, - const svuint8_t spill_permute, const svuint16_t spill_shift) { - if constexpr (BIT_WIDTH == 16) { - return svreinterpret_s16(input); - } else { - const svbool_t pg = svptrue_b16(); - - const svuint8_t permuted = svtbl_u8(svreinterpret_u8_u16(input), permute); - svuint16_t shifted = svlsr_u16_x(pg, svreinterpret_u16_u8(permuted), shift); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const svuint8_t spill_permuted_values = svtbl_u8(svreinterpret_u8_u16(input), spill_permute); - const svuint16_t spill_shifted = svlsl_u16_x(pg, svreinterpret_u16_u8(spill_permuted_values), spill_shift); - shifted = svorr_u16_x(pg, shifted, spill_shifted); - } - - constexpr int sign_shift = 16 - BIT_WIDTH; - shifted = svlsl_n_u16_x(pg, shifted, sign_shift); - - if constexpr (SIGN_VALUES) { - return svasr_n_s16_x(pg, svreinterpret_s16_u16(shifted), sign_shift); - } else { - return svreinterpret_s16_u16(svlsr_n_u16_x(pg, shifted, sign_shift)); - } - } -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline svint32_t sve2_unpack_epi32_17to24(const svuint8_t input) { - using table = table_unpacking; - - const svbool_t pg = svptrue_b32(); - const svuint8_t permuted = svtbl_u8(input, table::permute()); - const svuint32_t unpacked = svlsr_u32_x(pg, svreinterpret_u32_u8(permuted), table::shift()); - - if constexpr (SIGN_VALUES) { - constexpr int sign_shift = 32 - BIT_WIDTH; - return svasr_n_s32_x(pg, svreinterpret_s32_u32(svlsl_n_u32_x(pg, unpacked, sign_shift)), sign_shift); - } else { - constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; - return svreinterpret_s32_u32(svand_n_u32_x(pg, unpacked, mask)); - } -} -} // namespace pernix::arm64::sve2::internal - -#endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/include/pernix/compat.h b/include/pernix/compat.h new file mode 100644 index 0000000..42544ee --- /dev/null +++ b/include/pernix/compat.h @@ -0,0 +1,24 @@ +#ifndef PERNIX_COMPAT_H +#define PERNIX_COMPAT_H + +#ifndef __always_inline +#if defined(__GNUC__) || defined(__clang__) +#define __always_inline inline __attribute__((always_inline)) +#elif defined(_MSC_VER) +#define __always_inline __forceinline +#else +#define __always_inline inline +#endif +#endif + +#if defined(_WIN32) && defined(PERNIX_SHARED) +#if defined(PERNIX_BUILD_LIB) +#define PERNIX_API __declspec(dllexport) +#else +#define PERNIX_API __declspec(dllimport) +#endif +#else +#define PERNIX_API +#endif + +#endif //PERNIX_COMPAT_H diff --git a/include/pernix/detection.h b/include/pernix/detection.h deleted file mode 100644 index fa9cd44..0000000 --- a/include/pernix/detection.h +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef PERNIX_HELPER_H -#define PERNIX_HELPER_H - -#include - -// Internal capability tiers used to derive the public PERNIX_* feature macros. -#define PERNIX_MACHINE_ID_GENERIC 0 -#define PERNIX_MACHINE_ID_V2 1 -#define PERNIX_MACHINE_ID_V3 2 -#define PERNIX_MACHINE_ID_V4 3 -#define PERNIX_MACHINE_ID_V4_VBMI 4 - -#if defined(PERNIX_BACKEND_ARM64_NEON) -#define PERNIX_ARM64_NEON_ENABLED -#endif - -#if defined(PERNIX_BACKEND_ARM64_SVE) -#define PERNIX_ARM64_SVE_ENABLED -#endif - -#if defined(PERNIX_BACKEND_ARM64_SVE2) -#define PERNIX_ARM64_SVE2_ENABLED -#endif - -#if defined(PERNIX_BACKEND_X86) -// Map the compiler's enabled ISA set to the highest supported Pernix target level. -#if (__SSE3__ && __SSE4_1__ && __SSE4_2__) -#if (__AVX__ && __AVX2__ && __FMA__ && __BMI__ && __BMI2__) -#if (__AVX512BW__ && __AVX512CD__ && __AVX512DQ__ && __AVX512F__ && __AVX512VL__) -#if (__AVX512VBMI__) -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_V4_VBMI -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_V4 -#endif - -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_V3 -#endif - -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_V2 -#endif - -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC -#endif - -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC -#endif - -// Feature-selection macros consumed by the public headers. -#if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V2) -#define PERNIX_SSE_ENABLED -#endif -#if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V3) -#define PERNIX_AVX2_ENABLED -#define PERNIX_BMI2_ENABLED -#endif -#if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V4) -#define PERNIX_AVX512_ENABLED -#endif -#if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V4_VBMI) -#define PERNIX_AVX512_VBMI_ENABLED -#endif - -#if defined(PERNIX_USE_SIMDE) && defined(PERNIX_BACKEND_X86) -#define PERNIX_SSE_ENABLED -#define PERNIX_AVX2_ENABLED -#define PERNIX_BMI2_ENABLED -#define PERNIX_AVX512_ENABLED -#define PERNIX_AVX512_VBMI_ENABLED -#endif - -// Allow build systems or tests to force lower-tier implementations. -#ifdef PERNIX_DISABLE_AVX512 -#undef PERNIX_AVX512_ENABLED -#undef PERNIX_AVX512_VBMI_ENABLED -#endif -#ifdef PERNIX_DISABLE_AVX2 -#undef PERNIX_AVX2_ENABLED -#endif -#ifdef PERNIX_DISABLE_BMI2 -#undef PERNIX_BMI2_ENABLED -#endif - -#endif // PERNIX_HELPER_H diff --git a/include/pernix/fallback/decompression.h b/include/pernix/fallback/decompression.h deleted file mode 100644 index c4714e3..0000000 --- a/include/pernix/fallback/decompression.h +++ /dev/null @@ -1,286 +0,0 @@ -#ifndef PERNIX_FALLBACK_DECOMPRESSION_H -#define PERNIX_FALLBACK_DECOMPRESSION_H - -#include - -#include -#include -#include -#include - -namespace pernix { -namespace internal { -/** - * @brief Dequantize a single int32_t value to float using the provided scale. - * - * @param input input int32_t value to be dequantized. - * @param scale scaling factor used during quantization. - * @return float dequantized float value. - */ -__always_inline float dequantize_epi32(const int32_t input, const float scale) { - return static_cast(input) * scale; -} - -/** - * @brief Dequantize a single int64_t value to double using the provided scale. - * - * @param input input int64_t value to be dequantized. - * @param scale scaling factor used during quantization. - * @return double_t dequantized double value. - */ -__always_inline double_t dequantize_epi64(const int64_t input, const double_t scale) { - return static_cast(input) * scale; -} - -/** - * @brief Sign-extend a packed integer value stored in the low bits of a 32-bit word. - * - * @tparam BIT_WIDTH number of significant bits in the encoded value. - * @param value unsigned packed value. - * @return int32_t sign-extended value. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__always_inline auto sign_extend(const uint32_t value) -> int32_t { - if constexpr (BIT_WIDTH == 1) { - return static_cast(value & 1U); - } - - constexpr uint32_t sign_bit = uint32_t{1} << (BIT_WIDTH - 1); - constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1; - const uint32_t masked = value & mask; - return static_cast((static_cast(masked ^ sign_bit)) - static_cast(sign_bit)); -} - -/** - * @brief Unpack bit-packed values from a typed input span into signed 32-bit integers. - * - * @tparam T unsigned integer type used to read the source buffer. - * @tparam BIT_WIDTH bit width per packed value. - * @tparam SIGN_VALUES whether to sign-extend unpacked values. - * @param input pointer to the typed packed input buffer. - * @param bit_offset starting bit offset in the first input word. - * @param elements number of values to unpack. - * @return std::vector unpacked values. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) -__always_inline auto unpack_epi32_fallback_inner(const uint8_t* __restrict__ input, const uint8_t bit_offset, const std::size_t elements) - -> std::vector { - constexpr uint32_t bits_in_type = sizeof(T) * 8; - constexpr uint32_t bitmask = BIT_WIDTH == bits_in_type ? std::numeric_limits::max() : (1U << BIT_WIDTH) - 1U; - - std::vector output(elements); - - std::size_t idx = 0; - uint8_t bits_in_buffer = 8 - bit_offset; - uint64_t buffer = static_cast(input[idx++]) >> bit_offset; - -#pragma GCC unroll 64 - for (uint32_t i = 0; i < elements; i++) { - while (BIT_WIDTH > bits_in_buffer) { - const auto next_value = static_cast(input[idx++]) << bits_in_buffer; - buffer |= next_value; - bits_in_buffer += 8; - } - - const uint32_t raw_value = static_cast(buffer & bitmask); - if constexpr (SIGN_VALUES) { - output[i] = sign_extend(raw_value); - } else { - output[i] = static_cast(raw_value); - } - - buffer >>= BIT_WIDTH; - bits_in_buffer -= BIT_WIDTH; - } - - return output; -} - -/** - * @brief Unpack packed int32_t values from the input buffer using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the packed data. - * @param elements number of elements to unpack. - * @return std::vector unpacked int32_t values. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__always_inline auto unpack_epi32_fallback(const uint8_t* __restrict__ input, const std::size_t elements) -> std::vector { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return unpack_epi32_fallback_inner(input, 0, elements); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return unpack_epi32_fallback_inner(input, 0, elements); - } else { - return unpack_epi32_fallback_inner(input, 0, elements); - } -} -} // namespace internal - -/** - * @brief Decompress a single 512\-bit block using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_block_fallback(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - const std::vector block_values = internal::unpack_epi32_fallback(input, elements_per_block); - -#pragma GCC unroll 512 - for (uint32_t i = 0; i < elements_per_block; i++) { - output[i] = internal::dequantize_epi32(block_values[i], scale); - } - - return 0; -} - -/** - * @brief Decompress a single block to double values using the fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_block_fallback(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - const std::vector block_values = internal::unpack_epi32_fallback(input, elements_per_block); - -#pragma GCC unroll 512 - for (uint32_t i = 0; i < elements_per_block; i++) { - output[i] = internal::dequantize_epi64(block_values[i], scale); - } - - return 0; -} - -/** - * @brief Decompress multiple 512\-bit blocks using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int decompress_blocks_fallback(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - decompress_block_fallback(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} - -/** - * @brief Decompress multiple blocks to double values using the fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @param blocks number of blocks to decompress. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_blocks_fallback(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - decompress_block_fallback(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Decompress a single 512-bit block using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - */ -int decompress_block_fallback(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single 512-bit block using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - */ -int decompress_block_fallback_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief Decompress multiple 512-bit blocks using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - */ -int decompress_blocks_fallback(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Decompress multiple 512-bit blocks using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - */ -int decompress_blocks_fallback_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_FALLBACK_DECOMPRESSION_H diff --git a/include/pernix/pernix.h b/include/pernix/pernix.h index 4998a60..7a64af6 100644 --- a/include/pernix/pernix.h +++ b/include/pernix/pernix.h @@ -1,582 +1,58 @@ #ifndef PERNIX_H #define PERNIX_H -#include +#include +#include +#include -// Include architecture-specific headers based on detected capabilities -// AVX2 -#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) -#include -#include - -// BMI2: Needs AVX2 as well -#ifdef PERNIX_BMI2_ENABLED -#include -#include -#endif // PERNIX_BMI2_ENABLED - -// AVX512 VBMI: Needs AVX2 as well -#ifdef PERNIX_AVX512_VBMI_ENABLED -#include -#include -#endif // PERNIX_AVX512_VBMI_ENABLED - -#endif // PERNIX_BACKEND_X86 && PERNIX_AVX2_ENABLED - -#ifdef PERNIX_BACKEND_ARM64_NEON -#include -#include -#endif - -#ifdef PERNIX_BACKEND_ARM64_SVE -#include -#include -#endif - -#ifdef PERNIX_BACKEND_ARM64_SVE2 -#include -#include -#endif - -// Fallback (non-SIMD) implementations -#include -#include - -namespace pernix { -/** - * @brief Compress a single block of floating-point data into a bit-packed format using the specified bit width and scale. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * - * @return int status code (0 for success). - * - * @note This function will dispatch to the best available implementation based on detected CPU features at compile time. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single block of double-precision values into a bit-packed representation. - * - * @tparam BIT_WIDTH bit width per quantized value (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the input block. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress multiple blocks of single-precision values. - * - * @tparam BIT_WIDTH bit width per quantized value (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (must be a multiple of 32). - * - * @param input pointer to the first input value. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @param blocks number of consecutive blocks to process. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, uint32_t blocks); - -/** - * @brief Compress multiple blocks of double-precision values. - * - * @tparam BIT_WIDTH bit width per quantized value (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (must be a multiple of 32). - * - * @param input pointer to the first input value. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @param blocks number of consecutive blocks to process. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, uint32_t blocks); - -/** - * @brief Decompress a single block of packed values into single-precision values. - * - * @tparam BIT_WIDTH bit width per packed value (1 to 24). - * @tparam SIGN_VALUES true for signed values, false for unsigned values. - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the packed input block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single block of packed values into double-precision values. - * - * @tparam BIT_WIDTH bit width per packed value (1 to 24). - * @tparam SIGN_VALUES true for signed values, false for unsigned values. - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the packed input block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief Decompress multiple blocks of packed values into single-precision values. - * - * @tparam BIT_WIDTH bit width per packed value (1 to 24). - * @tparam SIGN_VALUES true for signed values, false for unsigned values. - * @tparam BLOCK_SIZE size of each block in bytes (must be a multiple of 32). - * - * @param input pointer to the first packed block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @param blocks number of consecutive blocks to process. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); - -/** - * @brief Decompress multiple blocks of packed values into double-precision values. - * - * @tparam BIT_WIDTH bit width per packed value (1 to 24). - * @tparam SIGN_VALUES true for signed values, false for unsigned values. - * @tparam BLOCK_SIZE size of each block in bytes (must be a multiple of 32). - * - * @param input pointer to the first packed block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @param blocks number of consecutive blocks to process. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, uint32_t blocks); - -// Use the best available implementation based on detected CPU features at compile time. -#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) -#ifdef PERNIX_AVX512_VBMI_ENABLED -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return mm512_compress_block_avx512vbmi(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return mm512_compress_block_avx512vbmi(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return mm512_compress_blocks_avx512vbmi(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return mm512_compress_blocks_avx512vbmi(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return mm512_decompress_block_avx512vbmi(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return mm512_decompress_block_avx512vbmi(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return mm512_decompress_blocks_avx512vbmi(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return mm512_decompress_blocks_avx512vbmi(input, scale, output, blocks); -} -#else -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return mm256_compress_block_avx2(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return mm256_compress_block_avx2(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return mm256_compress_blocks_avx2(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return mm256_compress_blocks_avx2(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return mm256_decompress_block_avx2(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return mm256_decompress_block_avx2(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return mm256_decompress_blocks_avx2(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return mm256_decompress_blocks_avx2(input, scale, output, blocks); -} +#if defined(__cplusplus) +extern "C" { #endif -#elif defined(PERNIX_BACKEND_ARM64_NEON) -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return neon_compress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return neon_compress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return neon_compress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return neon_compress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return neon_decompress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return neon_decompress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return neon_decompress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return neon_decompress_blocks(input, scale, output, blocks); -} -#elif defined(PERNIX_BACKEND_ARM64_SVE) -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return sve_compress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return sve_compress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return sve_compress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return sve_compress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return sve_decompress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return sve_decompress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return sve_decompress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return sve_decompress_blocks(input, scale, output, blocks); -} -#elif defined(PERNIX_BACKEND_ARM64_SVE2) -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return sve2_compress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return sve2_compress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return sve2_compress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return sve2_compress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return sve2_decompress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return sve2_decompress_block(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return sve2_decompress_blocks(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return sve2_decompress_blocks(input, scale, output, blocks); -} -#else -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return compress_block_fallback(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return compress_block_fallback(input, scale, output); -} -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return compress_blocks_fallback(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return compress_blocks_fallback(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return decompress_block_fallback(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return decompress_block_fallback(input, scale, output); -} +typedef enum pernix_status { + PERNIX_STATUS_OK = 0, + PERNIX_STATUS_INVALID_ARGUMENT = -1, + PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH = -2, + PERNIX_STATUS_UNSUPPORTED_BACKEND = -3, + PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE = -4 +} pernix_status; -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return decompress_blocks_fallback(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return decompress_blocks_fallback(input, scale, output, blocks); -} -#endif -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif +typedef enum pernix_backend { + PERNIX_BACKEND_AUTO = 0, + PERNIX_BACKEND_FALLBACK = 1, + PERNIX_BACKEND_X86_AVX2 = 2, + PERNIX_BACKEND_X86_BMI2 = 3, + PERNIX_BACKEND_X86_AVX512_VBMI = 4, + PERNIX_BACKEND_ARM64_NEON = 5, + PERNIX_BACKEND_ARM64_SVE = 6 +} pernix_backend; -/** - * @brief C ABI wrapper for compressing one single-precision block. - * - * @param bit_width bit width per quantized value (8 to 16). - * @param input pointer to the input block. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); +PERNIX_API pernix_status pernix_compress_block_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + float scale, void* output); -/** - * @brief C ABI wrapper for compressing one double-precision block. - * - * @param bit_width bit width per quantized value (8 to 16). - * @param input pointer to the input block. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); +PERNIX_API pernix_status pernix_compress_blocks_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + float scale, void* output, uint32_t blocks); -/** - * @brief C ABI wrapper for compressing multiple single-precision blocks. - * - * @param bit_width bit width per quantized value (8 to 16). - * @param input pointer to the first input value. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @param blocks number of consecutive blocks to process. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, uint32_t blocks); +PERNIX_API pernix_status pernix_decompress_block_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + float scale, void* output, bool sign_values); -/** - * @brief C ABI wrapper for compressing multiple double-precision blocks. - * - * @param bit_width bit width per quantized value (8 to 16). - * @param input pointer to the first input value. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @param blocks number of consecutive blocks to process. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); +PERNIX_API pernix_status pernix_decompress_blocks_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + float scale, void* output, uint32_t blocks, bool sign_values); -/** - * @brief C ABI wrapper for decompressing one single-precision block. - * - * @param bit_width bit width per packed value (1 to 24). - * @param input pointer to the packed input block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); +PERNIX_API pernix_status pernix_compress_block_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + double scale, void* output); -/** - * @brief C ABI wrapper for decompressing one double-precision block. - * - * @param bit_width bit width per packed value (1 to 24). - * @param input pointer to the packed input block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); +PERNIX_API pernix_status pernix_compress_blocks_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + double scale, void* output, uint32_t blocks); -/** - * @brief C ABI wrapper for decompressing multiple single-precision blocks. - * - * @param bit_width bit width per packed value (1 to 24). - * @param input pointer to the first packed block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @param blocks number of consecutive blocks to process. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); +PERNIX_API pernix_status pernix_decompress_block_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + double scale, void* output, bool sign_values); -/** - * @brief C ABI wrapper for decompressing multiple double-precision blocks. - * - * @param bit_width bit width per packed value (1 to 24). - * @param input pointer to the first packed block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @param blocks number of consecutive blocks to process. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); +PERNIX_API pernix_status pernix_decompress_blocks_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + double scale, void* output, uint32_t blocks, bool sign_values); -#ifdef __cplusplus +#if defined(__cplusplus) } -} // namespace pernix #endif -#endif // PERNIX_H +#endif //PERNIX_H diff --git a/include/pernix/pernix.hpp b/include/pernix/pernix.hpp new file mode 100644 index 0000000..67301bb --- /dev/null +++ b/include/pernix/pernix.hpp @@ -0,0 +1,153 @@ +#ifndef PERNIX_HPP +#define PERNIX_HPP +#include + +namespace pernix { +enum class Backend { + Auto = PERNIX_BACKEND_AUTO, + Fallback = PERNIX_BACKEND_FALLBACK, + X86Avx2 = PERNIX_BACKEND_X86_AVX2, + X86Bmi2 = PERNIX_BACKEND_X86_BMI2, + X86Avx512Vbmi = PERNIX_BACKEND_X86_AVX512_VBMI, + Arm64Neon = PERNIX_BACKEND_ARM64_NEON, + Arm64Sve = PERNIX_BACKEND_ARM64_SVE +}; + +__always_inline int compress_block(Backend backend, const uint8_t bit_width, const uint32_t block_size, + const std::span input, const float scale, std::span output) { + return pernix_compress_block_f32(static_cast(backend), bit_width, block_size, input.data(), scale, output.data()); +} + +__always_inline int compress_block(Backend backend, const uint8_t bit_width, const uint32_t block_size, + const std::span input, const double scale, std::span output) { + return pernix_compress_block_f64(static_cast(backend), bit_width, block_size, input.data(), scale, output.data()); +} + +__always_inline int decompress_block(Backend backend, const uint8_t bit_width, const uint32_t block_size, + const std::span input, const float scale, std::span output, + const bool sign_values = true) { + return pernix_decompress_block_f32(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), + sign_values); +} + +__always_inline int decompress_block(Backend backend, const uint8_t bit_width, const uint32_t block_size, + const std::span input, const double scale, std::span output, + const bool sign_values = true) { + return pernix_decompress_block_f64(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), + sign_values); +} + +__always_inline int compress_blocks(Backend backend, const uint8_t bit_width, const uint32_t block_size, + const std::span input, const float scale, std::span output, + const uint32_t blocks) { + return pernix_compress_blocks_f32(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), + blocks); +} + +__always_inline int compress_blocks(Backend backend, const uint8_t bit_width, const uint32_t block_size, + const std::span input, const double scale, std::span output, + const uint32_t blocks) { + return pernix_compress_blocks_f64(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), + blocks); +} + +__always_inline int decompress_blocks(Backend backend, const uint8_t bit_width, const uint32_t block_size, + const std::span input, const float scale, std::span output, + const uint32_t blocks, const bool sign_values = true) { + return pernix_decompress_blocks_f32(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), + blocks, sign_values); +} + +__always_inline int decompress_blocks(Backend backend, const uint8_t bit_width, const uint32_t block_size, + const std::span input, const double scale, std::span output, + const uint32_t blocks, const bool sign_values = true) { + return pernix_decompress_blocks_f64(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), + blocks, sign_values); +} + +// convenience overloads without backend (defaults to Auto) +__always_inline int compress_block(const uint8_t bit_width, const uint32_t block_size, const std::span input, + const float scale, const std::span output) { + return compress_block(Backend::Auto, bit_width, block_size, input, scale, output); +} + +__always_inline int compress_block(const uint8_t bit_width, const uint32_t block_size, const std::span input, + const double scale, const std::span output) { + return compress_block(Backend::Auto, bit_width, block_size, input, scale, output); +} + +__always_inline int decompress_block(const uint8_t bit_width, const uint32_t block_size, const std::span input, + const float scale, const std::span output, const bool sign_values = true) { + return decompress_block(Backend::Auto, bit_width, block_size, input, scale, output, sign_values); +} + +__always_inline int decompress_block(const uint8_t bit_width, const uint32_t block_size, const std::span input, + const double scale, const std::span output, const bool sign_values = true) { + return decompress_block(Backend::Auto, bit_width, block_size, input, scale, output, sign_values); +} + +__always_inline int compress_blocks(const uint8_t bit_width, const uint32_t block_size, const std::span input, + const float scale, const std::span output, const uint32_t blocks) { + return compress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks); +} + +__always_inline int compress_blocks(const uint8_t bit_width, const uint32_t block_size, const std::span input, + const double scale, const std::span output, const uint32_t blocks) { + return compress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks); +} + +__always_inline int decompress_blocks(const uint8_t bit_width, const uint32_t block_size, const std::span input, + const float scale, const std::span output, const uint32_t blocks, + const bool sign_values = true) { + return decompress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks, sign_values); +} + +__always_inline int decompress_blocks(const uint8_t bit_width, const uint32_t block_size, const std::span input, + const double scale, const std::span output, const uint32_t blocks, + const bool sign_values = true) { + return decompress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks, sign_values); +} + +// convenience overloads without backend and without block_size (defaults to 64) +__always_inline int compress_block(const uint8_t bit_width, const std::span input, const float scale, + const std::span output) { + return compress_block(Backend::Auto, bit_width, 64, input, scale, output); +} + +__always_inline int compress_block(const uint8_t bit_width, const std::span input, const double scale, + const std::span output) { + return compress_block(Backend::Auto, bit_width, 64, input, scale, output); +} + +__always_inline int decompress_block(const uint8_t bit_width, const std::span input, const float scale, + const std::span output, const bool sign_values = true) { + return decompress_block(Backend::Auto, bit_width, 64, input, scale, output, sign_values); +} + +__always_inline int decompress_block(const uint8_t bit_width, const std::span input, const double scale, + const std::span output, const bool sign_values = true) { + return decompress_block(Backend::Auto, bit_width, 64, input, scale, output, sign_values); +} + +__always_inline int compress_blocks(const uint8_t bit_width, const std::span input, const float scale, + const std::span output, const uint32_t blocks) { + return compress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks); +} + +__always_inline int compress_blocks(const uint8_t bit_width, const std::span input, const double scale, + const std::span output, const uint32_t blocks) { + return compress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks); +} + +__always_inline int decompress_blocks(const uint8_t bit_width, const std::span input, const float scale, + const std::span output, const uint32_t blocks, const bool sign_values = true) { + return decompress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks, sign_values); +} + +__always_inline int decompress_blocks(const uint8_t bit_width, const std::span input, const double scale, + const std::span output, const uint32_t blocks, const bool sign_values = true) { + return decompress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks, sign_values); +} +} + +#endif //PERNIX_HPP diff --git a/include/pernix/x86/avx512vbmi/compat.h b/include/pernix/x86/avx512vbmi/compat.h deleted file mode 100644 index bbc7ad4..0000000 --- a/include/pernix/x86/avx512vbmi/compat.h +++ /dev/null @@ -1,387 +0,0 @@ -#ifndef PERNIX_AVX512_COMPAT_H -#define PERNIX_AVX512_COMPAT_H - -#include -#include -#include -#include - -namespace pernix::internal { -static __always_inline __mmask8 element_mask8(const uint32_t e) { - return static_cast<__mmask8>(e >= 8 ? 0xFFu : ((1u << e) - 1u)); -} - -static __always_inline __mmask16 element_mask16(const uint32_t e) { - return static_cast<__mmask16>(e >= 16 ? 0xFFFFu : ((1u << e) - 1u)); -} - -static __always_inline __mmask32 element_mask32(const uint32_t e) { - return e >= 32 ? 0xFFFFFFFFu : (1u << e) - 1u; -} - -static __always_inline __mmask64 element_mask64(const uint32_t e) { - return e >= 64 ? 0xFFFFFFFFFFFFFFFFull : (1ull << e) - 1ull; -} - -static __always_inline __m512i mm512_loadu_elements_epi64(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); - return a; -#else - return _mm512_maskz_loadu_epi64(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m256i mm256_loadu_elements_epi64(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); - return a; -#else - return _mm256_maskz_loadu_epi64(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m128i mm_loadu_elements_epi64(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); - return a; -#else - return _mm_maskz_loadu_epi64(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m512i mm512_loadu_elements_epi32(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); - return a; -#else - return _mm512_maskz_loadu_epi32(element_mask16(e), mem_addr); -#endif -} - -static __always_inline __m256i mm256_loadu_elements_epi32(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); - return a; -#else - return _mm256_maskz_loadu_epi32(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m128i mm_loadu_elements_epi32(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); - return a; -#else - return _mm_maskz_loadu_epi32(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m512i mm512_loadu_elements_epi16(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); - return a; -#else - return _mm512_maskz_loadu_epi16(element_mask32(e), mem_addr); -#endif -} - -static __always_inline __m256i mm256_loadu_elements_epi16(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); - return a; -#else - return _mm256_maskz_loadu_epi16(element_mask16(e), mem_addr); -#endif -} - -static __always_inline __m128i mm_loadu_elements_epi16(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); - return a; -#else - return _mm_maskz_loadu_epi16(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m512i mm512_loadu_elements_epi8(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); - return a; -#else - return _mm512_maskz_loadu_epi8(element_mask64(e), mem_addr); -#endif -} - -static __always_inline __m256i mm256_loadu_elements_epi8(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); - return a; -#else - return _mm256_maskz_loadu_epi8(element_mask32(e), mem_addr); -#endif -} - -static __always_inline __m128i mm_loadu_elements_epi8(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); - return a; -#else - return _mm_maskz_loadu_epi8(element_mask16(e), mem_addr); -#endif -} - -static __always_inline void mm512_storeu_elements_epi64(void* mem_addr, const uint32_t e, const __m512i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; - _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); -#else - _mm512_mask_storeu_epi64(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_epi64(void* mem_addr, const uint32_t e, const __m256i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); -#else - _mm256_mask_storeu_epi64(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_epi64(void* mem_addr, const uint32_t e, const __m128i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); -#else - _mm_mask_storeu_epi64(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm512_storeu_elements_epi32(void* mem_addr, const uint32_t e, const __m512i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; - _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); -#else - _mm512_mask_storeu_epi32(mem_addr, element_mask16(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_epi32(void* mem_addr, const uint32_t e, const __m256i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); -#else - _mm256_mask_storeu_epi32(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_epi32(void* mem_addr, const uint32_t e, const __m128i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); -#else - _mm_mask_storeu_epi32(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm512_storeu_elements_epi16(void* mem_addr, const uint32_t e, const __m512i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; - _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); -#else - _mm512_mask_storeu_epi16(mem_addr, element_mask32(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_epi16(void* mem_addr, const uint32_t e, const __m256i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); -#else - _mm256_mask_storeu_epi16(mem_addr, element_mask16(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_epi16(void* mem_addr, const uint32_t e, const __m128i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); -#else - _mm_mask_storeu_epi16(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm512_storeu_elements_epi8(void* mem_addr, const uint32_t e, const __m512i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; - _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); -#else - _mm512_mask_storeu_epi8(mem_addr, element_mask64(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_epi8(void* mem_addr, const uint32_t e, const __m256i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); -#else - _mm256_mask_storeu_epi8(mem_addr, element_mask32(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_epi8(void* mem_addr, const uint32_t e, const __m128i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); -#else - _mm_mask_storeu_epi8(mem_addr, element_mask16(e), a); -#endif -} - -static __always_inline __m512 mm512_loadu_elements_ps(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512 a = _mm512_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); - return a; -#else - return _mm512_maskz_loadu_ps(element_mask16(e), mem_addr); -#endif -} - -static __always_inline __m256 mm256_loadu_elements_ps(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256 a = _mm256_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); - return a; -#else - return _mm256_maskz_loadu_ps(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m128 mm_loadu_elements_ps(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128 a = _mm_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); - return a; -#else - return _mm_maskz_loadu_ps(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m512d mm512_loadu_elements_pd(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512d a = _mm512_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); - return a; -#else - return _mm512_maskz_loadu_pd(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m256d mm256_loadu_elements_pd(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256d a = _mm256_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); - return a; -#else - return _mm256_maskz_loadu_pd(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m128d mm_loadu_elements_pd(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128d a = _mm_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); - return a; -#else - return _mm_maskz_loadu_pd(element_mask8(e), mem_addr); -#endif -} - -static __always_inline void mm512_storeu_elements_ps(void* mem_addr, const uint32_t e, const __m512 a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) float_t values[16]; - _mm512_storeu_ps(values, a); - std::memcpy(mem_addr, values, e * sizeof(float_t)); -#else - _mm512_mask_storeu_ps(mem_addr, element_mask16(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_ps(void* mem_addr, const uint32_t e, const __m256 a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) float_t values[8]; - _mm256_storeu_ps(values, a); - std::memcpy(mem_addr, values, e * sizeof(float_t)); -#else - _mm256_mask_storeu_ps(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_ps(void* mem_addr, const uint32_t e, const __m128 a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) float_t values[4]; - _mm_storeu_ps(values, a); - std::memcpy(mem_addr, values, e * sizeof(float_t)); -#else - _mm_mask_storeu_ps(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm512_storeu_elements_pd(void* mem_addr, const uint32_t e, const __m512d a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) double_t values[8]; - _mm512_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); -#else - _mm512_mask_storeu_pd(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_pd(void* mem_addr, const uint32_t e, const __m256d a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) double_t values[4]; - _mm256_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); -#else - _mm256_mask_storeu_pd(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_pd(void* mem_addr, const uint32_t e, const __m128d a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) double_t values[2]; - _mm_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); -#else - _mm_mask_storeu_pd(mem_addr, element_mask8(e), a); -#endif -} -} - -#endif //PERNIX_AVX512_COMPAT_H diff --git a/include/pernix/x86/avx512vbmi/packing.h b/include/pernix/x86/avx512vbmi/packing.h deleted file mode 100644 index ba3b132..0000000 --- a/include/pernix/x86/avx512vbmi/packing.h +++ /dev/null @@ -1,327 +0,0 @@ -#ifndef PERNIX_AVX512VBMI_PACKING_H -#define PERNIX_AVX512VBMI_PACKING_H - -#include -#include - -namespace pernix::internal { -namespace m128 { -/** - * @brief Pack 8 16-bit values for bit widths 9 through 16 using VBMI. - */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = pack_tables_avx512_16; - const __m128i maskv = _mm_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); - const __m128i masked = _mm_and_si128(input, maskv); - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m128i permuted1 = _mm_permutexvar_epi16(tables::get_permute1(), masked); - const __m128i permuted2 = _mm_permutexvar_epi16(tables::get_permute2(), masked); - - const __m128i shifted1 = _mm_sllv_epi16(permuted1, tables::get_shift1()); - const __m128i shifted2 = _mm_srlv_epi16(permuted2, tables::get_shift2()); - - return _mm_or_si128(shifted1, shifted2); - } else { - const auto [mask1, mask2, mask3] = tables::get_permute_masks(); - - const __m128i permuted1 = _mm_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); - const __m128i permuted2 = _mm_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); - const __m128i permuted3 = _mm_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); - - const __m128i shifted1 = _mm_sllv_epi16(permuted1, tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi16(permuted2, tables::get_shift2()); - const __m128i shifted3 = _mm_srlv_epi16(permuted3, tables::get_shift3()); - - return _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); - } - } -} - -/** - * @brief Pack 16 8-bit values for bit widths 1 through 8 using VBMI. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - const __m128i maskv = _mm_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); - const __m128i masked = _mm_and_si128(input, maskv); - - if constexpr (BIT_WIDTH == 1) { - return _mm_set1_epi16(static_cast(_mm_cmpgt_epi8_mask(masked, _mm_setzero_si128()))); - } else if constexpr (BIT_WIDTH == 2) { - const __m128i shifted = _mm_srli_epi16(masked, 6); - const __m128i combined = _mm_or_si128(masked, shifted); - - const __m128i shifted2 = _mm_srli_epi32(combined, 12); - const __m128i combined2 = _mm_or_si128(shifted2, combined); - - return _mm_cvtepi32_epi8(combined2); - } else if constexpr (BIT_WIDTH == 3) { - const __m128i even = _mm_and_si128(masked, _mm_set1_epi16(0x00FF)); - const __m128i odd = _mm_and_si128(masked, _mm_set1_epi16(0xFF00)); - - const __m128i pair6 = _mm_or_si128(even, _mm_srli_epi16(odd, 5)); - const __m128i packed12 = _mm_or_si128(pair6, _mm_srli_epi32(pair6, 10)); - - return m128::mm_pack_epi16_avx512vbmi_9to16<12>(_mm_cvtepi32_epi16(packed12)); - } else if constexpr (BIT_WIDTH == 4) { - const __m128i shifted = _mm_srli_epi16(masked, 4); - const __m128i combined = _mm_or_si128(masked, shifted); - - return _mm_cvtepi16_epi8(combined); - } else { - const __m128i even = _mm_and_si128(masked, _mm_set1_epi16(0x00FF)); - const __m128i odd = _mm_and_si128(masked, _mm_set1_epi16(0xFF00)); - - const __m128i shifted = _mm_or_si128(even, _mm_srli_epi16(odd, 8 - BIT_WIDTH)); - return mm_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); - } - } -} - -/** - * @brief Pack 4 32-bit values for bit widths 17 through 24 using VBMI. - */ -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i& input) { - using tables = pack_tables_avx512_24; - - const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); - const __m128i masked = _mm_and_si128(input, maskv); - - const __m128 permuted1 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute1()); - const __m128 permuted2 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute2()); - const __m128 permuted3 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute3()); - - const __m128i shifted1 = _mm_sllv_epi32(_mm_castps_si128(permuted1), tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi32(_mm_castps_si128(permuted2), tables::get_shift2()); - const __m128i shifted3 = _mm_srlv_epi32(_mm_castps_si128(permuted3), tables::get_shift3()); - - return _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); -} -} // namespace m128 - -namespace m256 { -/** - * @brief Pack 16 16-bit values for bit widths 9 through 16 using VBMI. - */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = pack_tables_avx512_16; - const __m256i maskv = _mm256_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); - const __m256i masked = _mm256_and_si256(input, maskv); - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m256i permuted1 = _mm256_permutexvar_epi16(tables::get_permute1(), masked); - const __m256i permuted2 = _mm256_permutexvar_epi16(tables::get_permute2(), masked); - - const __m256i shifted1 = _mm256_sllv_epi16(permuted1, tables::get_shift1()); - const __m256i shifted2 = _mm256_srlv_epi16(permuted2, tables::get_shift2()); - - return _mm256_or_si256(shifted1, shifted2); - } else { - const auto [mask1, mask2, mask3] = tables::get_permute_masks(); - - const __m256i permuted1 = _mm256_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); - const __m256i permuted2 = _mm256_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); - const __m256i permuted3 = _mm256_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); - - const __m256i shifted1 = _mm256_sllv_epi16(permuted1, tables::get_shift1()); - const __m256i shifted2 = _mm256_sllv_epi16(permuted2, tables::get_shift2()); - const __m256i shifted3 = _mm256_srlv_epi16(permuted3, tables::get_shift3()); - - return _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); - } - } -} - -/** - * @brief Pack 32 8-bit values for bit widths 1 through 8 using VBMI. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - const __m256i maskv = _mm256_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); - const __m256i masked = _mm256_and_si256(input, maskv); - - if constexpr (BIT_WIDTH == 1) { - return _mm256_set1_epi32(static_cast(_mm256_cmpgt_epi8_mask(masked, _mm256_setzero_si256()))); - } else if constexpr (BIT_WIDTH == 2) { - const __m256i shifted = _mm256_srli_epi16(masked, 6); - const __m256i combined = _mm256_or_si256(masked, shifted); - - const __m256i shifted2 = _mm256_srli_epi32(combined, 12); - const __m256i combined2 = _mm256_or_si256(shifted2, combined); - - return _mm256_castsi128_si256(_mm256_cvtepi32_epi8(combined2)); - } else if constexpr (BIT_WIDTH == 3) { - const __m256i even = _mm256_and_si256(masked, _mm256_set1_epi16(0x00FF)); - const __m256i odd = _mm256_and_si256(masked, _mm256_set1_epi16(0xFF00)); - - const __m256i pair6 = _mm256_or_si256(even, _mm256_srli_epi16(odd, 5)); - const __m256i packed12 = _mm256_or_si256(pair6, _mm256_srli_epi32(pair6, 10)); - - return m256::mm256_pack_epi16_avx512vbmi_9to16<12>(_mm256_castsi128_si256(_mm256_cvtepi32_epi16(packed12))); - } else if constexpr (BIT_WIDTH == 4) { - const __m256i shifted = _mm256_srli_epi16(masked, 4); - const __m256i combined = _mm256_or_si256(masked, shifted); - - return _mm256_castsi128_si256(_mm256_cvtepi16_epi8(combined)); - } else { - const __m256i even = _mm256_and_si256(masked, _mm256_set1_epi16(0x00FF)); - const __m256i odd = _mm256_and_si256(masked, _mm256_set1_epi16(0xFF00)); - - const __m256i shifted = _mm256_or_si256(even, _mm256_srli_epi16(odd, 8 - BIT_WIDTH)); - return mm256_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); - } - } -} - -/** - * @brief Pack 8 32-bit values for bit widths 17 through 24 using VBMI. - */ -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i& input) { - using tables = pack_tables_avx512_24; - - const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); - const __m256i masked = _mm256_and_si256(input, maskv); - - const __m256i permuted1 = _mm256_permutexvar_epi32(tables::get_permute1(), masked); - const __m256i permuted2 = _mm256_permutexvar_epi32(tables::get_permute2(), masked); - const __m256i permuted3 = _mm256_permutexvar_epi32(tables::get_permute3(), masked); - - const __m256i shifted1 = _mm256_sllv_epi32(permuted1, tables::get_shift1()); - const __m256i shifted2 = _mm256_sllv_epi32(permuted2, tables::get_shift2()); - const __m256i shifted3 = _mm256_srlv_epi32(permuted3, tables::get_shift3()); - - return _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); -} -} // namespace m256 - -namespace m512 { -/** - * @brief Pack 32 16-bit values for bit widths 9 through 16 using VBMI. - */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = pack_tables_avx512_16; - const __m512i maskv = _mm512_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); - const __m512i masked = _mm512_and_si512(input, maskv); - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m512i permuted1 = _mm512_permutexvar_epi16(tables::get_permute1(), masked); - const __m512i permuted2 = _mm512_permutexvar_epi16(tables::get_permute2(), masked); - - const __m512i shifted1 = _mm512_sllv_epi16(permuted1, tables::get_shift1()); - const __m512i shifted2 = _mm512_srlv_epi16(permuted2, tables::get_shift2()); - - return _mm512_or_si512(shifted1, shifted2); - } else { - const auto [mask1, mask2, mask3] = tables::get_permute_masks(); - - const __m512i permuted1 = _mm512_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); - const __m512i permuted2 = _mm512_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); - const __m512i permuted3 = _mm512_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); - - const __m512i shifted1 = _mm512_sllv_epi16(permuted1, tables::get_shift1()); - const __m512i shifted2 = _mm512_sllv_epi16(permuted2, tables::get_shift2()); - const __m512i shifted3 = _mm512_srlv_epi16(permuted3, tables::get_shift3()); - - return _mm512_or_si512(_mm512_or_si512(shifted1, shifted2), shifted3); - } - } -} - -/** - * @brief Pack 64 8-bit values for bit widths 1 through 8 using VBMI. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - const __m512i maskv = _mm512_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); - const __m512i masked = _mm512_and_si512(input, maskv); - - if constexpr (BIT_WIDTH == 1) { - return _mm512_set1_epi64(static_cast(_mm512_cmpgt_epi8_mask(masked, _mm512_setzero_si512()))); - } else if constexpr (BIT_WIDTH == 2) { - const __m512i shifted = _mm512_srli_epi16(masked, 6); - const __m512i combined = _mm512_or_si512(masked, shifted); - - const __m512i shifted2 = _mm512_srli_epi32(combined, 12); - const __m512i combined2 = _mm512_or_si512(shifted2, combined); - - return _mm512_castsi128_si512(_mm512_cvtepi32_epi8(combined2)); - } else if constexpr (BIT_WIDTH == 3) { - const __m512i even = _mm512_and_si512(masked, _mm512_set1_epi16(0x00FF)); - const __m512i odd = _mm512_and_si512(masked, _mm512_set1_epi16(0xFF00)); - - const __m512i pair6 = _mm512_or_si512(even, _mm512_srli_epi16(odd, 5)); - const __m512i packed12 = _mm512_or_si512(pair6, _mm512_srli_epi32(pair6, 10)); - - return _mm512_castsi256_si512(m256::mm256_pack_epi16_avx512vbmi_9to16<12>(_mm512_cvtepi32_epi16(packed12))); - } else if constexpr (BIT_WIDTH == 4) { - const __m512i shifted = _mm512_srli_epi16(masked, 4); - const __m512i combined = _mm512_or_si512(masked, shifted); - - return _mm512_castsi256_si512(_mm512_cvtepi16_epi8(combined)); - } else { - const __m512i even = _mm512_and_si512(masked, _mm512_set1_epi16(0x00FF)); - const __m512i odd = _mm512_and_si512(masked, _mm512_set1_epi16(0xFF00)); - - const __m512i shifted = _mm512_or_si512(even, _mm512_srli_epi16(odd, 8 - BIT_WIDTH)); - return mm512_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); - } - } -} - -/** - * @brief Pack 16 32-bit values for bit widths 17 through 24 using VBMI. - */ -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i& input) { - using tables = pack_tables_avx512_24; - - const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); - const __m512i masked = _mm512_and_si512(input, maskv); - - const __m512i permuted1 = _mm512_permutexvar_epi32(tables::get_permute1(), masked); - const __m512i permuted2 = _mm512_permutexvar_epi32(tables::get_permute2(), masked); - const __m512i permuted3 = _mm512_permutexvar_epi32(tables::get_permute3(), masked); - - const __m512i shifted1 = _mm512_sllv_epi32(permuted1, tables::get_shift1()); - const __m512i shifted2 = _mm512_sllv_epi32(permuted2, tables::get_shift2()); - const __m512i shifted3 = _mm512_srlv_epi32(permuted3, tables::get_shift3()); - - return _mm512_or_si512(_mm512_or_si512(shifted1, shifted2), shifted3); -} -} // namespace m512 -} // namespace pernix::internal - -#endif // PERNIX_AVX512VBMI_PACKING_H diff --git a/include/pernix/x86/avx512vbmi/unpacking.h b/include/pernix/x86/avx512vbmi/unpacking.h deleted file mode 100644 index 1cfc29c..0000000 --- a/include/pernix/x86/avx512vbmi/unpacking.h +++ /dev/null @@ -1,500 +0,0 @@ -#ifndef PERNIX_AVX512VBMI_UNPACKING_H -#define PERNIX_AVX512VBMI_UNPACKING_H - -#include -#include - -namespace pernix::internal { -namespace m128 { -constexpr __mmask16 kAlternateByteMask16 = 0xAAAAULL; - -__always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi16(0x00ff); - const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); - const __m128i high_half = _mm_srlv_epi16(a, _mm_srli_epi16(count, 8)); - return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); -} - -__always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi16(0xff00); - const __m128i low_half = _mm_sllv_epi16(a, _mm_andnot_si128(mask, count)); - const __m128i high_half = _mm_sllv_epi16(_mm_and_si128(mask, a), _mm_srli_epi16(count, 8)); - return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); -} - -__always_inline static __m128i _mm_slli_epi8(const __m128i a, const int8_t imm8) { - return _mm_sllv_epi8(a, _mm_set1_epi8(imm8)); -} - -__always_inline static __m128i _mm_srli_epi8(const __m128i a, const int imm8) { - const __m128i lo_mask = _mm_set1_epi16(0x00ff); - const __m128i hi_mask = _mm_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m128i lo = _mm_srl_epi16(_mm_and_si128(a, lo_mask), shift); - const __m128i hi = _mm_and_si128(_mm_srl_epi16(a, shift), hi_mask); - - return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); -} - -__always_inline static __m128i _mm_srai_epi8(const __m128i a, const int8_t imm8) { - const __m128i lo_mask = _mm_set1_epi16(0x00ff); - const __m128i hi_mask = _mm_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m128i hi = _mm_and_si128(_mm_sra_epi16(a, shift), hi_mask); - - const __m128i lo_as_hi = _mm_slli_epi16(_mm_and_si128(a, lo_mask), 8); - const __m128i lo = _mm_and_si128(_mm_srli_epi16(_mm_sra_epi16(lo_as_hi, shift), 8), lo_mask); - - return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m128i mm_unpack_epi8_avx512vbmi_1to8(const __m128i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - if constexpr (BIT_WIDTH == 1) { - const auto value = static_cast<__mmask16>(_mm_cvtsi128_si64(input)); - const __m128i source = _mm_movm_epi8(value); - const __m128i unpacked = _mm_abs_epi8(source); - return unpacked; - } else if constexpr (BIT_WIDTH == 2) { - __m128i values_shift0 = input; - __m128i values_shift2 = _mm_srli_epi16(values_shift0, 2); - const __m128i values_shift4 = _mm_srli_epi16(values_shift0, 4); - const __m128i values_shift6 = _mm_srli_epi16(values_shift0, 6); - - __m128i interleave_tmp = _mm_unpacklo_epi8(values_shift0, values_shift2); - values_shift0 = _mm_unpackhi_epi8(values_shift0, values_shift2); - values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); - - interleave_tmp = _mm_unpacklo_epi8(values_shift4, values_shift6); - values_shift2 = _mm_unpackhi_epi8(values_shift4, values_shift6); - values_shift2 = _mm_unpacklo_epi64(interleave_tmp, values_shift2); - - interleave_tmp = _mm_unpacklo_epi16(values_shift0, values_shift2); - values_shift0 = _mm_unpackhi_epi16(values_shift0, values_shift2); - values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); - values_shift0 = _mm_shuffle_epi32(values_shift0, 0xD8); - - values_shift0 = _mm_and_si128(values_shift0, _mm_set1_epi16(0x0303)); - - return values_shift0; - } else if constexpr (BIT_WIDTH == 4) { - __m128i values_shift0 = input; - const __m128i values_shift4 = _mm_srli_epi16(values_shift0, 4); - - const __m128i interleave_tmp = _mm_unpacklo_epi8(values_shift0, values_shift4); - values_shift0 = _mm_unpackhi_epi8(values_shift0, values_shift4); - values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); - values_shift0 = _mm_shuffle_epi32(values_shift0, 0xD8); - - values_shift0 = _mm_and_si128(values_shift0, _mm_set1_epi16(0x0F0F)); - - return values_shift0; - } else { - using tables = unpack_tables_avx512_8; - - const __m128i permuted1 = _mm_permutexvar_epi8(tables::get_permute1(), input); - const __m128i permuted2 = _mm_permutexvar_epi8(tables::get_permute2(), input); - - const __m128i shifted1 = _mm_srlv_epi8(permuted1, tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi8(permuted2, tables::get_shift2()); - - const __mmask16 spill_mask = _mm_cmpneq_epi8_mask(tables::get_shift2(), _mm_setzero_si128()); - __m128i combined = _mm_or_si128(shifted1, _mm_maskz_mov_epi8(spill_mask, shifted2)); - - constexpr uint32_t shift = 8 - BIT_WIDTH; - combined = _mm_slli_epi8(combined, shift); - if (SIGN_VALUES) { - combined = _mm_srai_epi8(combined, shift); - } else { - combined = _mm_srli_epi8(combined, shift); - } - - return combined; - } - } -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m128i mm_unpack_epi16_avx512vbmi_9to16(const __m128i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = unpack_tables_avx512_16; - - const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute1(), input); - - __m128i shifted = _mm_srlv_epi16(permuted, tables::get_shift1()); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m128i permuted2 = _mm_permutexvar_epi8(tables::get_permute2(), input); - const __m128i shifted2 = _mm_sllv_epi16(permuted2, tables::get_shift2()); - shifted = _mm_or_si128(shifted, shifted2); - } - - constexpr uint32_t shift = 16 - BIT_WIDTH; - shifted = _mm_slli_epi16(shifted, shift); - if (SIGN_VALUES) { - shifted = _mm_srai_epi16(shifted, shift); - } else { - shifted = _mm_srli_epi16(shifted, shift); - } - - return shifted; - } -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m128i mm_unpack_epi32_avx512vbmi_17to24(const __m128i& input) { - using tables = unpack_tables_avx512_24; - - const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute(), input); - - constexpr uint32_t shift = 32 - BIT_WIDTH; - __m128i shifted = _mm_sllv_epi32(permuted, tables::get_shift()); - if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { - shifted = _mm_srai_epi32(shifted, shift); - } else { - shifted = _mm_srli_epi32(shifted, shift); - } - - return shifted; -} -} // namespace m128 - -namespace m256 { -constexpr __mmask32 kAlternateByteMask32 = 0xAAAAAAAAULL; - -__always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi16(0x00ff); - const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); - const __m256i high_half = _mm256_srlv_epi16(a, _mm256_srli_epi16(count, 8)); - return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); -} - -__always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi16(0xff00); - const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count)); - const __m256i high_half = _mm256_sllv_epi16(_mm256_and_si256(mask, a), _mm256_srli_epi16(count, 8)); - return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); -} - -__always_inline static __m256i _mm256_slli_epi8(const __m256i a, const int8_t imm8) { - return _mm256_sllv_epi8(a, _mm256_set1_epi8(imm8)); -} - -__always_inline static __m256i _mm256_srli_epi8(const __m256i a, const int8_t imm8) { - const __m256i lo_mask = _mm256_set1_epi16(0x00ff); - const __m256i hi_mask = _mm256_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m256i lo = _mm256_srl_epi16(_mm256_and_si256(a, lo_mask), shift); - const __m256i hi = _mm256_and_si256(_mm256_srl_epi16(a, shift), hi_mask); - - return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); -} - -__always_inline static __m256i _mm256_srai_epi8(const __m256i a, const int8_t imm8) { - const __m256i lo_mask = _mm256_set1_epi16(0x00ff); - const __m256i hi_mask = _mm256_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m256i hi = _mm256_and_si256(_mm256_sra_epi16(a, shift), hi_mask); - - const __m256i lo_as_hi = _mm256_slli_epi16(_mm256_and_si256(a, lo_mask), 8); - const __m256i lo = _mm256_and_si256(_mm256_srli_epi16(_mm256_sra_epi16(lo_as_hi, shift), 8), lo_mask); - - return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m256i mm256_unpack_epi8_avx512vbmi_1to8(const __m256i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - if constexpr (BIT_WIDTH == 1) { - const auto value = static_cast<__mmask32>(_mm_cvtsi128_si64(_mm256_castsi256_si128(input))); - const __m256i source = _mm256_movm_epi8(value); - const __m256i unpacked = _mm256_abs_epi8(source); - return unpacked; - } else if constexpr (BIT_WIDTH == 2) { - __m256i values_shift0 = input; - __m256i values_shift2 = _mm256_srli_epi16(values_shift0, 2); - const __m256i values_shift4 = _mm256_srli_epi16(values_shift0, 4); - const __m256i values_shift6 = _mm256_srli_epi16(values_shift0, 6); - - __m256i interleave_tmp = _mm256_unpacklo_epi8(values_shift0, values_shift2); - values_shift0 = _mm256_unpackhi_epi8(values_shift0, values_shift2); - values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00000000); - - interleave_tmp = _mm256_unpacklo_epi8(values_shift4, values_shift6); - values_shift2 = _mm256_unpackhi_epi8(values_shift4, values_shift6); - values_shift2 = _mm256_shuffle_i32x4(interleave_tmp, values_shift2, 0b00000000); - - interleave_tmp = _mm256_unpacklo_epi16(values_shift0, values_shift2); - values_shift0 = _mm256_unpackhi_epi16(values_shift0, values_shift2); - values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00); - values_shift0 = _mm256_shuffle_i32x4(values_shift0, values_shift0, 0b00); - - values_shift0 = _mm256_and_si256(values_shift0, _mm256_set1_epi16(0x0303)); - - return values_shift0; - } else if constexpr (BIT_WIDTH == 4) { - __m256i values_shift0 = input; - const __m256i values_shift4 = _mm256_srli_epi16(values_shift0, 4); - - __m256i interleave_tmp = _mm256_unpacklo_epi8(values_shift0, values_shift4); - values_shift0 = _mm256_unpackhi_epi8(values_shift0, values_shift4); - values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00); - values_shift0 = _mm256_shuffle_i32x4(values_shift0, values_shift0, 0b00); - - values_shift0 = _mm256_and_si256(values_shift0, _mm256_set1_epi16(0x0F0F)); - - return values_shift0; - } else { - using tables = unpack_tables_avx512_8; - - const __m256i permuted1 = _mm256_permutexvar_epi8(tables::get_permute1(), input); - const __m256i permuted2 = _mm256_permutexvar_epi8(tables::get_permute2(), input); - - const __m256i shifted1 = _mm256_srlv_epi8(permuted1, tables::get_shift1()); - const __m256i shifted2 = _mm256_sllv_epi8(permuted2, tables::get_shift2()); - - const __mmask32 spill_mask = _mm256_cmpneq_epi8_mask(tables::get_shift2(), _mm256_setzero_si256()); - __m256i combined = _mm256_or_si256(shifted1, _mm256_maskz_mov_epi8(spill_mask, shifted2)); - - constexpr uint32_t shift = 8 - BIT_WIDTH; - combined = _mm256_slli_epi8(combined, shift); - if (SIGN_VALUES) { - combined = _mm256_srai_epi8(combined, shift); - } else { - combined = _mm256_srli_epi8(combined, shift); - } - - return combined; - } - } -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m256i mm256_unpack_epi16_avx512vbmi_9to16(const __m256i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = unpack_tables_avx512_16; - - const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute1(), input); - - __m256i shifted = _mm256_srlv_epi16(permuted, tables::get_shift1()); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m256i permuted2 = _mm256_permutexvar_epi8(tables::get_permute2(), input); - const __m256i shifted2 = _mm256_sllv_epi16(permuted2, tables::get_shift2()); - shifted = _mm256_or_si256(shifted, shifted2); - } - - constexpr uint32_t shift = 16 - BIT_WIDTH; - shifted = _mm256_slli_epi16(shifted, shift); - if (SIGN_VALUES) { - shifted = _mm256_srai_epi16(shifted, shift); - } else { - shifted = _mm256_srli_epi16(shifted, shift); - } - - return shifted; - } -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m256i mm256_unpack_epi32_avx512vbmi_17to24(const __m256i& input) { - using tables = unpack_tables_avx512_24; - - const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute(), input); - - constexpr uint32_t shift = 32 - BIT_WIDTH; - __m256i shifted = _mm256_sllv_epi32(permuted, tables::get_shift()); - if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { - shifted = _mm256_srai_epi32(shifted, shift); - } else { - shifted = _mm256_srli_epi32(shifted, shift); - } - - return shifted; -} -} // namespace m256 - -namespace m512 { -constexpr __mmask64 kAlternateByteMask64 = 0xAAAAAAAAAAAAAAAAULL; - -__always_inline static __m512i _mm512_srlv_epi8(const __m512i a, const __m512i count) { - const __m512i mask = _mm512_set1_epi16(0x00ff); - const __m512i low_half = _mm512_srlv_epi16(_mm512_and_si512(mask, a), _mm512_and_si512(mask, count)); - const __m512i high_half = _mm512_srlv_epi16(a, _mm512_srli_epi16(count, 8)); - return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half); -} - -__always_inline static __m512i _mm512_sllv_epi8(const __m512i a, const __m512i count) { - const __m512i mask = _mm512_set1_epi16(0xff00); - const __m512i low_half = _mm512_sllv_epi16(a, _mm512_andnot_si512(mask, count)); - const __m512i high_half = _mm512_sllv_epi16(_mm512_and_si512(mask, a), _mm512_srli_epi16(count, 8)); - return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half); -} - -__always_inline static __m512i _mm512_slli_epi8(const __m512i a, const int8_t imm8) { - return _mm512_sllv_epi8(a, _mm512_set1_epi8(imm8)); -} - -__always_inline static __m512i _mm512_srli_epi8(const __m512i a, const int8_t imm8) { - const __m512i lo_mask = _mm512_set1_epi16(0x00ff); - const __m512i hi_mask = _mm512_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m512i lo = _mm512_srl_epi16(_mm512_and_si512(a, lo_mask), shift); - const __m512i hi = _mm512_and_si512(_mm512_srl_epi16(a, shift), hi_mask); - - return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi); -} - -__always_inline static __m512i _mm512_srai_epi8(const __m512i a, const int8_t imm8) { - const __m512i lo_mask = _mm512_set1_epi16(0x00ff); - const __m512i hi_mask = _mm512_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m512i hi = _mm512_and_si512(_mm512_sra_epi16(a, shift), hi_mask); - - const __m512i lo_as_hi = _mm512_slli_epi16(_mm512_and_si512(a, lo_mask), 8); - const __m512i lo = _mm512_and_si512(_mm512_srli_epi16(_mm512_sra_epi16(lo_as_hi, shift), 8), lo_mask); - - return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m512i mm512_unpack_epi8_avx512vbmi_1to8(const __m512i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - if constexpr (BIT_WIDTH == 1) { - const auto value = static_cast<__mmask64>(_mm_cvtsi128_si64(_mm512_castsi512_si128(input))); - const __m512i source = _mm512_movm_epi8(value); - const __m512i unpacked = _mm512_abs_epi8(source); - return unpacked; - } else if constexpr (BIT_WIDTH == 2) { - __m512i values_shift0 = input; - __m512i values_shift2 = _mm512_srli_epi16(values_shift0, 2); - const __m512i values_shift4 = _mm512_srli_epi16(values_shift0, 4); - const __m512i values_shift6 = _mm512_srli_epi16(values_shift0, 6); - - __m512i interleave_tmp = _mm512_unpacklo_epi8(values_shift0, values_shift2); - values_shift0 = _mm512_unpackhi_epi8(values_shift0, values_shift2); - values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0b00000000); - - interleave_tmp = _mm512_unpacklo_epi8(values_shift4, values_shift6); - values_shift2 = _mm512_unpackhi_epi8(values_shift4, values_shift6); - values_shift2 = _mm512_shuffle_i32x4(interleave_tmp, values_shift2, 0b00000000); - - interleave_tmp = _mm512_unpacklo_epi16(values_shift0, values_shift2); - values_shift0 = _mm512_unpackhi_epi16(values_shift0, values_shift2); - values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0x88); - values_shift0 = _mm512_shuffle_i32x4(values_shift0, values_shift0, 0xD8); - - values_shift0 = _mm512_and_si512(values_shift0, _mm512_set1_epi16(0x0303)); - - return values_shift0; - } else if constexpr (BIT_WIDTH == 4) { - __m512i values_shift0 = input; - const __m512i values_shift4 = _mm512_srli_epi16(values_shift0, 4); - - __m512i interleave_tmp = _mm512_unpacklo_epi8(values_shift0, values_shift4); - values_shift0 = _mm512_unpackhi_epi8(values_shift0, values_shift4); - values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0x44); - values_shift0 = _mm512_shuffle_i32x4(values_shift0, values_shift0, 0xD8); - - values_shift0 = _mm512_and_si512(values_shift0, _mm512_set1_epi16(0x0F0F)); - - return values_shift0; - } else { - using tables = unpack_tables_avx512_8; - - const __m512i permuted1 = _mm512_permutexvar_epi8(tables::get_permute1(), input); - const __m512i permuted2 = _mm512_permutexvar_epi8(tables::get_permute2(), input); - - const __m512i shifted1 = _mm512_srlv_epi8(permuted1, tables::get_shift1()); - const __m512i shifted2 = _mm512_sllv_epi8(permuted2, tables::get_shift2()); - - const __mmask64 spill_mask = _mm512_cmpneq_epi8_mask(tables::get_shift2(), _mm512_setzero_si512()); - __m512i combined = _mm512_or_si512(shifted1, _mm512_maskz_mov_epi8(spill_mask, shifted2)); - - constexpr uint32_t shift = 8 - BIT_WIDTH; - combined = _mm512_slli_epi8(combined, shift); - if (SIGN_VALUES) { - combined = _mm512_srai_epi8(combined, shift); - } else { - combined = _mm512_srli_epi8(combined, shift); - } - - return combined; - } - } -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m512i mm512_unpack_epi16_avx512vbmi_9to16(const __m512i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = unpack_tables_avx512_16; - - const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute1(), input); - __m512i shifted = _mm512_srlv_epi16(permuted, tables::get_shift1()); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m512i permuted2 = _mm512_permutexvar_epi8(tables::get_permute2(), input); - const __m512i shifted2 = _mm512_sllv_epi16(permuted2, tables::get_shift2()); - shifted = _mm512_or_si512(shifted, shifted2); - } - - constexpr uint32_t shift = 16 - BIT_WIDTH; - shifted = _mm512_slli_epi16(shifted, shift); - if (SIGN_VALUES) { - shifted = _mm512_srai_epi16(shifted, shift); - } else { - shifted = _mm512_srli_epi16(shifted, shift); - } - - return shifted; - } -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m512i mm512_unpack_epi32_avx512vbmi_17to24(const __m512i& input) { - using tables = unpack_tables_avx512_24; - - const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute(), input); - __m512i shifted = _mm512_sllv_epi32(permuted, tables::get_shift()); - - constexpr uint32_t shift = 32 - BIT_WIDTH; - if constexpr (SIGN_VALUES) { - shifted = _mm512_srai_epi32(shifted, shift); - } else { - shifted = _mm512_srli_epi32(shifted, shift); - } - - return shifted; -} -} // namespace m512 -} // namespace pernix::internal - -#endif // PERNIX_AVX512VBMI_UNPACKING_H diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 57d2f0a..aab5c67 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,12 +1,204 @@ include(GNUInstallDirs) include(CMakePackageConfigHelpers) +add_library(pernix SHARED) +target_sources(pernix + PRIVATE + pernix.cpp + + fallback/fallback_compression.cpp + fallback/fallback_decompression.cpp + + dispatch/select.cpp +) + +target_compile_features(pernix PUBLIC cxx_std_20) + +set_target_properties(pernix PROPERTIES + OUTPUT_NAME "pernix" + VERSION ${NORMALIZED_VERSION} + LINKER_LANGUAGE CXX +) + +target_include_directories(pernix + PUBLIC + $ + $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/internal +) + +target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_LIB=1 +) + +if (PERNIX_USE_SIMDE) + # target_link_libraries(pernix PUBLIC simde::simde) + target_compile_definitions(pernix PUBLIC PERNIX_USE_SIMDE=1) +endif () + +if (PERNIX_TARGET_IS_X86) + target_sources(pernix + PRIVATE + dispatch/cpu_features_x86.cpp + ) + + if (MSVC) + set_source_files_properties( + dispatch/cpu_features_x86.cpp + PROPERTIES + COMPILE_OPTIONS "/arch:AVX512" + ) + else () + set_source_files_properties( + dispatch/cpu_features_x86.cpp + PROPERTIES + COMPILE_OPTIONS "-mavx;-mbmi;-mbmi2;-mavx2;-mavx512f;-mavx512bw;-mavx512vl;-mavx512dq;-mavx512cd;-mavx512vbmi" + ) + endif () +endif () + +if (PERNIX_TARGET_IS_ARM64) + target_sources(pernix + PRIVATE + dispatch/cpu_features_arm.cpp + ) + + set_target_properties(pernix PROPERTIES + COMPILE_OPTIONS "-march=armv8-a+simd+sve2" + ) +endif () + +if (PERNIX_ENABLE_X86_BMI2 AND PERNIX_TARGET_IS_X86) + target_sources(pernix + PRIVATE + x86/bmi2/bmi2_compression.cpp + x86/bmi2/bmi2_decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_X86_BMI2=1 + ) + + if (MSVC) + # BMI2 is not exposed with MSVC + else () + set_source_files_properties( + x86/bmi2/bmi2_compression.cpp + x86/bmi2/bmi2_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-mbmi;-mbmi2;-mavx2" + ) + endif () + +endif () + +if (PERNIX_ENABLE_X86_AVX2 AND PERNIX_TARGET_IS_X86) + target_sources(pernix + PRIVATE + x86/avx2/avx2_compression.cpp + x86/avx2/avx2_decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_X86_AVX2=1 + ) + + if (MSVC) + set_source_files_properties( + x86/avx2/avx2_compression.cpp + x86/avx2/avx2_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "/arch:AVX2" + ) + else () + set_source_files_properties( + x86/avx2/avx2_compression.cpp + x86/avx2/avx2_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-mavx2" + ) + endif () +endif () + +if (PERNIX_ENABLE_X86_AVX512VBMI AND PERNIX_TARGET_IS_X86) + target_sources(pernix + PRIVATE + x86/avx512vbmi/avx512vbmi_compression.cpp + x86/avx512vbmi/avx512vbmi_decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_X86_AVX512_VBMI=1 + ) + + if (MSVC) + set_source_files_properties( + x86/avx512vbmi/avx512vbmi_compression.cpp + x86/avx512vbmi/avx512vbmi_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "/arch:AVX512" + ) + else () + set_source_files_properties( + x86/avx512vbmi/avx512vbmi_compression.cpp + x86/avx512vbmi/avx512vbmi_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-mavx512f;-mavx512bw;-mavx512vl;-mavx512dq;-mavx512cd;-mavx512vbmi" + ) + endif () +endif () + +if (PERNIX_ENABLE_ARM64_NEON AND PERNIX_TARGET_IS_ARM64) + target_sources(pernix + PRIVATE + arm64/neon/compression.cpp + arm64/neon/decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_ARM64_NEON=1 + ) + + set_source_files_properties( + arm64/neon/compression.cpp + arm64/neon/decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-march=armv8-a+simd" + ) +endif () + +if (PERNIX_ENABLE_ARM64_SVE2 AND PERNIX_TARGET_IS_ARM64) + target_sources(pernix + PRIVATE + arm64/sve2/compression.cpp + arm64/sve2/decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_ARM64_SVE2=1 + ) + + set_source_files_properties( + arm64/sve2/compression.cpp + arm64/sve2/decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-march=armv8.2-a+simd+sve2" + ) +endif () + +#[===[ file(GLOB_RECURSE PERNIX_COMMON_SOURCES - CONFIGURE_DEPENDS ./fallback/*.cpp - ./pernix.cpp - ${PROJECT_SOURCE_DIR}/include/pernix/*.h + pernix.cpp + dispatch/select.cpp ) set(PERNIX_SOURCES ${PERNIX_COMMON_SOURCES}) @@ -16,7 +208,6 @@ if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") PERNIX_X86_SOURCES CONFIGURE_DEPENDS ./x86/*.cpp - ${PROJECT_SOURCE_DIR}/include/pernix/x86/*.h ) list(APPEND PERNIX_SOURCES ${PERNIX_X86_SOURCES}) elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") @@ -24,7 +215,6 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") PERNIX_ARM64_NEON_SOURCES CONFIGURE_DEPENDS ./arm64/neon/*.cpp - ${PROJECT_SOURCE_DIR}/include/pernix/arm64/neon/*.h ) list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_NEON_SOURCES}) elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") @@ -32,7 +222,6 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") PERNIX_ARM64_SVE_SOURCES CONFIGURE_DEPENDS ./arm64/sve/*.cpp - ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve/*.h ) list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE_SOURCES}) elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") @@ -40,11 +229,11 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") PERNIX_ARM64_SVE2_SOURCES CONFIGURE_DEPENDS ./arm64/sve2/*.cpp - ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve2/*.h ) list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE2_SOURCES}) endif () + add_library(pernix SHARED ${PERNIX_SOURCES}) add_library(pernix::pernix ALIAS pernix) set_target_properties(pernix PROPERTIES @@ -53,9 +242,12 @@ set_target_properties(pernix PROPERTIES ) target_compile_features(pernix PUBLIC cxx_std_20) target_compile_options(pernix PRIVATE ${PERNIX_PRIVATE_COMPILE_OPTIONS}) -target_include_directories(pernix PUBLIC +target_include_directories(pernix + PUBLIC $ $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/internal ) if (PERNIX_ENABLE_LTO) set_target_properties(pernix PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE) @@ -160,3 +352,5 @@ if (PERNIX_ENABLE_DOXYGEN) DESTINATION ${CMAKE_INSTALL_DOCDIR}) endif () endif () + +]===] \ No newline at end of file diff --git a/src/arm64/neon/compression.cpp b/src/arm64/neon/compression.cpp index 5968f79..81c1cc2 100644 --- a/src/arm64/neon/compression.cpp +++ b/src/arm64/neon/compression.cpp @@ -1,21 +1,28 @@ +#include #include -namespace pernix { -extern "C" { -int neon_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { - return -1; +namespace pernix::internal { +Kernel select_neon_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { + (void)bit_width; + (void)block_size; + return {"neon", nullptr}; } -int neon_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { - return -1; +Kernel select_neon_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { + (void)bit_width; + (void)block_size; + return {"neon", nullptr}; } -int neon_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { - return -1; +Kernel select_neon_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { + (void)bit_width; + (void)block_size; + return {"neon", nullptr}; } -int neon_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { - return -1; +Kernel select_neon_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { + (void)bit_width; + (void)block_size; + return {"neon", nullptr}; } } -} // namespace pernix diff --git a/src/arm64/neon/decompression.cpp b/src/arm64/neon/decompression.cpp index a89f763..94da2fb 100644 --- a/src/arm64/neon/decompression.cpp +++ b/src/arm64/neon/decompression.cpp @@ -1,143 +1,201 @@ +#include #include -namespace pernix { -extern "C" { -#define PERNIX_NEON_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return arm64::neon::neon_decompress_block(input, scale, output); +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("neon", &arm64::neon::neon_decompress_block); \ + return Kernel("neon", &arm64::neon::neon_decompress_block) -#define PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return arm64::neon::neon_decompress_blocks(input, scale, output, blocks); +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("neon", &arm64::neon::neon_decompress_blocks); \ + return Kernel("neon", &arm64::neon::neon_decompress_blocks) -int neon_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(1) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(2) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(3) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(4) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(5) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(6) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(7) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(8) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(9) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(10) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(11) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(12) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(13) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(14) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(15) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(16) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(17) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(18) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(19) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(20) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(21) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(22) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(23) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("neon", &arm64::neon::neon_decompress_block); \ + return Kernel("neon", &arm64::neon::neon_decompress_block) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("neon", &arm64::neon::neon_decompress_blocks); \ + return Kernel("neon", &arm64::neon::neon_decompress_blocks) + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"neon", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"neon", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"neon", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"neon", nullptr}; \ + } \ + break + +Kernel select_neon_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"neon", nullptr}; } } -int neon_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(1) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(2) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(3) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(4) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(5) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(6) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(7) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(8) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(9) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(10) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(11) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(12) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(13) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(14) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(15) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(16) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(17) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(18) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(19) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(20) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(21) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(22) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(23) - PERNIX_NEON_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; +Kernel select_neon_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"neon", nullptr}; } } -int neon_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - switch (bit_width) { - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; +Kernel select_neon_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"neon", nullptr}; } } -int neon_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; +Kernel select_neon_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"neon", nullptr}; } } -#undef PERNIX_NEON_DECOMPRESS_BLOCK_CASE -#undef PERNIX_NEON_DECOMPRESS_BLOCKS_CASE +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 } -} // namespace pernix diff --git a/src/arm64/sve/compression.cpp b/src/arm64/sve/compression.cpp deleted file mode 100644 index e973183..0000000 --- a/src/arm64/sve/compression.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include - -namespace pernix { -extern "C" { -int sve_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { - return -1; -} - -int sve_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { - return -1; -} - -int sve_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { - return -1; -} - -int sve_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { - return -1; -} -} -} // namespace pernix diff --git a/src/arm64/sve/decompression.cpp b/src/arm64/sve/decompression.cpp deleted file mode 100644 index c6d84be..0000000 --- a/src/arm64/sve/decompression.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include - -namespace pernix { -extern "C" { -int sve_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { - return -1; -} - -int sve_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { - return -1; -} - -int sve_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { - return -1; -} - -int sve_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { - return -1; -} -} -} // namespace pernix diff --git a/src/arm64/sve2/compression.cpp b/src/arm64/sve2/compression.cpp index 0a55f16..c6d8dd0 100644 --- a/src/arm64/sve2/compression.cpp +++ b/src/arm64/sve2/compression.cpp @@ -1,21 +1,28 @@ +#include #include -namespace pernix { -extern "C" { -int sve2_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { - return -1; +namespace pernix::internal { +Kernel select_sve2_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { + (void)bit_width; + (void)block_size; + return {"sve2", nullptr}; } -int sve2_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { - return -1; +Kernel select_sve2_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { + (void)bit_width; + (void)block_size; + return {"sve2", nullptr}; } -int sve2_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { - return -1; +Kernel select_sve2_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { + (void)bit_width; + (void)block_size; + return {"sve2", nullptr}; } -int sve2_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { - return -1; +Kernel select_sve2_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { + (void)bit_width; + (void)block_size; + return {"sve2", nullptr}; } } -} // namespace pernix diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp index 8429e97..ba796e0 100644 --- a/src/arm64/sve2/decompression.cpp +++ b/src/arm64/sve2/decompression.cpp @@ -1,143 +1,201 @@ +#include #include -namespace pernix { -extern "C" { -#define PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return arm64::sve2::sve2_decompress_block(input, scale, output); +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("sve2", &arm64::sve2::sve2_decompress_block); \ + return Kernel("sve2", &arm64::sve2::sve2_decompress_block) -#define PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return arm64::sve2::sve2_decompress_blocks(input, scale, output, blocks); +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("sve2", &arm64::sve2::sve2_decompress_blocks); \ + return Kernel("sve2", &arm64::sve2::sve2_decompress_blocks) -int sve2_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(1) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(2) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(3) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(4) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(5) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(9) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(10) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(11) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(12) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(13) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(14) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(15) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(16) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(17) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(18) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(19) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(20) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(21) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(22) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(23) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("sve2", &arm64::sve2::sve2_decompress_block); \ + return Kernel("sve2", &arm64::sve2::sve2_decompress_block) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("sve2", &arm64::sve2::sve2_decompress_blocks); \ + return Kernel("sve2", &arm64::sve2::sve2_decompress_blocks) + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"sve2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"sve2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"sve2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"sve2", nullptr}; \ + } \ + break + +Kernel select_sve2_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"sve2", nullptr}; } } -int sve2_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(1) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(2) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(3) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(4) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(5) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(9) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(10) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(11) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(12) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(13) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(14) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(15) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(16) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(17) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(18) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(19) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(20) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(21) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(22) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(23) - PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; +Kernel select_sve2_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"sve2", nullptr}; } } -int sve2_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - switch (bit_width) { - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; +Kernel select_sve2_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"sve2", nullptr}; } } -int sve2_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; +Kernel select_sve2_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"sve2", nullptr}; } } -#undef PERNIX_SVE2_DECOMPRESS_BLOCK_CASE -#undef PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 } -} // namespace pernix diff --git a/src/dispatch/cpu_features_arm.cpp b/src/dispatch/cpu_features_arm.cpp new file mode 100644 index 0000000..1b4ac78 --- /dev/null +++ b/src/dispatch/cpu_features_arm.cpp @@ -0,0 +1,23 @@ +#include + +namespace pernix::internal { +CpuFeatures detect_cpu_features() { + CpuFeatures features{}; + + // neon +#if defined(__aarch64__) || defined(_M_ARM64) + features.neon = true; +#elif defined(__ARM_NEON) || defined(__ARM_NEON__) + features.neon = true; +#endif + + // sve +#if defined(__aarch64__) || defined(_M_ARM64) +#ifdef __ARM_FEATURE_SVE + features.sve = true; +#endif +#endif + + return features; +} +} \ No newline at end of file diff --git a/src/dispatch/cpu_features_x86.cpp b/src/dispatch/cpu_features_x86.cpp new file mode 100644 index 0000000..f43245c --- /dev/null +++ b/src/dispatch/cpu_features_x86.cpp @@ -0,0 +1,95 @@ +#include + +#include + +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) +#if defined(_MSC_VER) +#include +#else +#include +#include +#endif + + +namespace pernix::internal { +namespace { +#if defined(_MSC_VER) + +void cpuid(int out[4], int leaf, int subleaf) { + __cpuidex(out, leaf, subleaf); +} + +std::uint64_t xgetbv(unsigned int index) { + return _xgetbv(index); +} + +#else + +void cpuid(int out[4], int leaf, int subleaf) { + __cpuid_count(leaf, subleaf, out[0], out[1], out[2], out[3]); +} + +std::uint64_t xgetbv(unsigned int index) { + return _xgetbv(index); +} + +#endif + +bool bit_set(int value, int bit) { + return (value & (1 << bit)) != 0; +} +} // namespace + +CpuFeatures detect_cpu_features() { + CpuFeatures features{}; + + int regs[4]{}; + + cpuid(regs, 1, 0); + + const bool osxsave = bit_set(regs[2], 27); + const bool avx = bit_set(regs[2], 28); + + if (!osxsave || !avx) { + return features; + } + + const std::uint64_t xcr0 = xgetbv(0); + + const bool xmm_enabled = (xcr0 & 0x2) != 0; + const bool ymm_enabled = (xcr0 & 0x4) != 0; + const bool zmm_enabled = + (xcr0 & 0x20) != 0 && + (xcr0 & 0x40) != 0 && + (xcr0 & 0x80) != 0; + + if (!xmm_enabled || !ymm_enabled) { + return features; + } + + cpuid(regs, 7, 0); + + features.avx2 = bit_set(regs[1], 5); + features.bmi2 = bit_set(regs[1], 8); + + if (zmm_enabled) { + features.avx512f = bit_set(regs[1], 16); + features.avx512dq = bit_set(regs[1], 29); + features.avx512bw = bit_set(regs[1], 30); + features.avx512vl = bit_set(regs[1], 31); + features.avx512vbmi = bit_set(regs[2], 1); + } + + return features; +} +} // namespace pernix::internal +#else + +namespace pernix::internal { +CpuFeatures detect_cpu_features() { + return {}; +} +} // namespace pernix::internal + +#endif + diff --git a/src/dispatch/select.cpp b/src/dispatch/select.cpp new file mode 100644 index 0000000..b619af1 --- /dev/null +++ b/src/dispatch/select.cpp @@ -0,0 +1,684 @@ +#include +#include + +namespace pernix::internal { +Kernel select_compress_block_f32(Backend backend, uint8_t bit_width, uint32_t block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_block_f32(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_block_f32(bit_width, block_size); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_block_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_compress_block_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_compress_block_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_compress_block_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: + return select_sve2_compress_block_f32(bit_width, block_size); +#endif + default: + return {"invalid_backend", nullptr}; + } +} + +Kernel select_compress_blocks_f32(Backend backend, uint8_t bit_width, uint32_t block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_blocks_f32(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_blocks_f32(bit_width, block_size); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_blocks_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_compress_blocks_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_compress_blocks_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_compress_blocks_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: + return select_sve2_compress_blocks_f32(bit_width, block_size); +#endif + default: + return {"invalid_backend", nullptr}; + } +} + +Kernel select_compress_block_f64(Backend backend, uint8_t bit_width, uint32_t block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_block_f64(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_block_f64(bit_width, block_size); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_block_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_compress_block_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_compress_block_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_compress_block_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: + return select_sve2_compress_block_f64(bit_width, block_size); +#endif + default: + return {"invalid_backend", nullptr}; + } +} + +Kernel select_compress_blocks_f64(Backend backend, uint8_t bit_width, uint32_t block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_blocks_f64(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_blocks_f64(bit_width, block_size); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_blocks_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_compress_blocks_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_compress_blocks_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_compress_blocks_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: + return select_sve2_compress_blocks_f64(bit_width, block_size); +#endif + default: + return {"invalid_backend", nullptr}; + } +} + +Kernel select_decompress_block_f32(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_block_f32(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_block_f32(bit_width, block_size, sign_values); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_block_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_decompress_block_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_decompress_block_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_decompress_block_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: + return select_sve2_decompress_block_f32(bit_width, block_size, sign_values); +#endif + default: + return {"invalid_backend", nullptr}; + } +} + +Kernel select_decompress_blocks_f32(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_blocks_f32(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_blocks_f32(bit_width, block_size, sign_values); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: + return select_sve2_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif + default: + return {"invalid_backend", nullptr}; + } +} + +Kernel select_decompress_block_f64(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_block_f64(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_block_f64(bit_width, block_size, sign_values); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_block_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_decompress_block_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_decompress_block_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_decompress_block_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: + return select_sve2_decompress_block_f64(bit_width, block_size, sign_values); +#endif + default: + return {"invalid_backend", nullptr}; + } +} + +Kernel select_decompress_blocks_f64(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_blocks_f64(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_blocks_f64(bit_width, block_size, sign_values); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: + return select_sve2_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif + default: + return {"invalid_backend", nullptr}; + } +} + +Kernel select_auto_compress_block_f32(uint8_t bit_width, uint32_t block_size) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) + const CpuFeatures features = get_cached_cpu_features(); +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_compress_block_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_AVX2) + if (features.avx2) { + if (auto kernel = select_avx2_compress_block_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + if (features.bmi2) { + if (auto kernel = select_bmi2_compress_block_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + if (features.neon) { + if (auto kernel = select_neon_compress_block_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + if (features.sve) { + if (auto kernel = select_sve2_compress_block_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + + return select_fallback_compress_block_f32(bit_width, block_size); +} + +Kernel select_auto_compress_blocks_f32(uint8_t bit_width, uint32_t block_size) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) + const CpuFeatures features = get_cached_cpu_features(); +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_AVX2) + if (features.avx2) { + if (auto kernel = select_avx2_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + if (features.bmi2) { + if (auto kernel = select_bmi2_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + if (features.neon) { + if (auto kernel = select_neon_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + if (features.sve) { + if (auto kernel = select_sve2_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } + } +#endif + + return select_fallback_compress_blocks_f32(bit_width, block_size); +} + +Kernel select_auto_compress_block_f64(uint8_t bit_width, uint32_t block_size) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) + const CpuFeatures features = get_cached_cpu_features(); +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_compress_block_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_AVX2) + if (features.avx2) { + if (auto kernel = select_avx2_compress_block_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + if (features.bmi2) { + if (auto kernel = select_bmi2_compress_block_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + if (features.neon) { + if (auto kernel = select_neon_compress_block_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + if (features.sve) { + if (auto kernel = select_sve2_compress_block_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + + return select_fallback_compress_block_f64(bit_width, block_size); +} + +Kernel select_auto_compress_blocks_f64(uint8_t bit_width, uint32_t block_size) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) + const CpuFeatures features = get_cached_cpu_features(); +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_AVX2) + if (features.avx2) { + if (auto kernel = select_avx2_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + if (features.bmi2) { + if (auto kernel = select_bmi2_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + if (features.neon) { + if (auto kernel = select_neon_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + if (features.sve) { + if (auto kernel = select_sve2_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } + } +#endif + + return select_fallback_compress_blocks_f64(bit_width, block_size); +} + +Kernel select_auto_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) + const CpuFeatures features = get_cached_cpu_features(); +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_AVX2) + if (features.avx2) { + if (auto kernel = select_avx2_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + if (features.bmi2) { + if (auto kernel = select_bmi2_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + if (features.neon) { + if (auto kernel = select_neon_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + if (features.sve) { + if (auto kernel = select_sve2_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + + return select_fallback_decompress_block_f32(bit_width, block_size, sign_values); +} + +Kernel select_auto_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) + const CpuFeatures features = get_cached_cpu_features(); +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_AVX2) + if (features.avx2) { + if (auto kernel = select_avx2_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + if (features.bmi2) { + if (auto kernel = select_bmi2_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + if (features.neon) { + if (auto kernel = select_neon_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + if (features.sve) { + if (auto kernel = select_sve2_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + + return select_fallback_decompress_blocks_f32(bit_width, block_size, sign_values); +} + +Kernel select_auto_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) + const CpuFeatures features = get_cached_cpu_features(); +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_AVX2) + if (features.avx2) { + if (auto kernel = select_avx2_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + if (features.bmi2) { + if (auto kernel = select_bmi2_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + if (features.neon) { + if (auto kernel = select_neon_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + if (features.sve) { + if (auto kernel = select_sve2_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + + return select_fallback_decompress_block_f64(bit_width, block_size, sign_values); +} + +Kernel select_auto_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) + const CpuFeatures features = get_cached_cpu_features(); +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_AVX2) + if (features.avx2) { + if (auto kernel = select_avx2_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + if (features.bmi2) { + if (auto kernel = select_bmi2_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + if (features.neon) { + if (auto kernel = select_neon_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + if (features.sve) { + if (auto kernel = select_sve2_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } + } +#endif + + return select_fallback_decompress_blocks_f64(bit_width, block_size, sign_values); +} +} diff --git a/src/fallback/compression.cpp b/src/fallback/compression.cpp deleted file mode 100644 index 1ed56e8..0000000 --- a/src/fallback/compression.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_COMPRESS_BLOCK_CASE(N) \ - case N: \ - return compress_block_fallback(input, scale, output); - -#define PERNIX_COMPRESS_BLOCKS_CASE(N) \ - case N: \ - return compress_blocks_fallback(input, scale, output, blocks); - -int compress_block_fallback(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int compress_block_fallback_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int compress_blocks_fallback(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int compress_blocks_fallback_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus diff --git a/src/fallback/decompression.cpp b/src/fallback/decompression.cpp deleted file mode 100644 index e43b1df..0000000 --- a/src/fallback/decompression.cpp +++ /dev/null @@ -1,150 +0,0 @@ -#include - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return decompress_block_fallback(input, scale, output); - -#define PERNIX_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return decompress_blocks_fallback(input, scale, output, blocks); - -int decompress_block_fallback(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int decompress_block_fallback_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int decompress_blocks_fallback(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int decompress_blocks_fallback_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_DECOMPRESS_BLOCK_CASE -#undef PERNIX_DECOMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus \ No newline at end of file diff --git a/src/fallback/fallback_compression.cpp b/src/fallback/fallback_compression.cpp new file mode 100644 index 0000000..eeaa34d --- /dev/null +++ b/src/fallback/fallback_compression.cpp @@ -0,0 +1,194 @@ +#include +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_COMPRESS_BLOCK_32(N, BS) \ +case N: return Kernel("fallback", &compress_block_fallback) + +#define PERNIX_CASE_COMPRESS_BLOCKS_32(N, BS) \ +case N: return Kernel("fallback", &compress_blocks_fallback) + +#define PERNIX_CASE_COMPRESS_BLOCK_64(N, BS) \ +case N: return Kernel("fallback", &compress_block_fallback) + +#define PERNIX_CASE_COMPRESS_BLOCKS_64(N, BS) \ +case N: return Kernel("fallback", &compress_blocks_fallback) + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(24, BS); \ + default: return {"fallback", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(24, BS); \ + default: return {"fallback", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(24, BS); \ + default: return {"fallback", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(24, BS); \ + default: return {"fallback", nullptr}; \ + } + +Kernel select_fallback_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"fallback", nullptr}; + } +} + +Kernel select_fallback_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"fallback", nullptr}; + } +} + +Kernel select_fallback_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"fallback", nullptr}; + } +} + +Kernel select_fallback_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"fallback", nullptr}; + } +} + +#undef PERNIX_CASE_COMPRESS_BLOCK_32 +#undef PERNIX_CASE_COMPRESS_BLOCKS_32 +#undef PERNIX_CASE_COMPRESS_BLOCK_64 +#undef PERNIX_CASE_COMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64 +} diff --git a/src/fallback/fallback_decompression.cpp b/src/fallback/fallback_decompression.cpp new file mode 100644 index 0000000..232b831 --- /dev/null +++ b/src/fallback/fallback_decompression.cpp @@ -0,0 +1,201 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("fallback", &decompress_block_fallback); \ + return Kernel("fallback", &decompress_block_fallback) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("fallback", &decompress_blocks_fallback); \ + return Kernel("fallback", &decompress_blocks_fallback) + + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("fallback", &decompress_block_fallback); \ + return Kernel("fallback", &decompress_block_fallback) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("fallback", &decompress_blocks_fallback); \ + return Kernel("fallback", &decompress_blocks_fallback) +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"fallback", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"fallback", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"fallback", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"fallback", nullptr}; \ + } \ + break + +Kernel select_fallback_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"fallback", nullptr}; + } +} + +Kernel select_fallback_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"fallback", nullptr}; + } +} + +Kernel select_fallback_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"fallback", nullptr}; + } +} + +Kernel select_fallback_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"fallback", nullptr}; + } +} + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/internal/pernix/arm64/neon/common.h b/src/internal/pernix/arm64/neon/common.h new file mode 100644 index 0000000..2677416 --- /dev/null +++ b/src/internal/pernix/arm64/neon/common.h @@ -0,0 +1,223 @@ +#ifndef PERNIX_ARM64_NEON_COMMON_H +#define PERNIX_ARM64_NEON_COMMON_H + +#include + +#include + +namespace pernix::arm64::neon::internal { + struct float64x2x8_t { + float64x2_t val[8]; + }; + + static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { + const uint32_t tail_bits = remaining_elements * bit_width; + const uint32_t tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; + } + +__always_inline int32x4x4_t neon_convert_int8x16_int32x4x4(const int8x16_t &input) { + const int16x8_t s16_lo = vmovl_s8(vget_low_s8(input)); + const int16x8_t s16_hi = vmovl_s8(vget_high_s8(input)); + + return { + { + vmovl_s16(vget_low_s16(s16_lo)), + vmovl_s16(vget_high_s16(s16_lo)), + vmovl_s16(vget_low_s16(s16_hi)), + vmovl_s16(vget_high_s16(s16_hi)), + } + }; + } + +__always_inline int32x4x2_t neon_convert_int16x8_int32x4x2(const int16x8_t &input) { + return { + { + vmovl_s16(vget_low_s16(input)), + vmovl_s16(vget_high_s16(input)), + } + }; + } + +__always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t &input, const float32x4_t &scale) { + return { + { + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[2]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[3]), scale), + } + }; + } + +__always_inline float32x4x2_t neon_dequantize_epi32(const int32x4x2_t &input, const float32x4_t &scale) { + return { + { + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + } + }; + } + +__always_inline float32x4_t neon_dequantize_epi32(const int32x4_t &input, const float32x4_t &scale) { + return vmulq_f32(vcvtq_f32_s32(input), scale); + } + +__always_inline float64x2_t neon_dequantize_epi32_f64(const int32x2_t &input, const float64x2_t &scale) { + return vmulq_f64(vcvtq_f64_s64(vmovl_s32(input)), scale); + } + +__always_inline float64x2x2_t neon_dequantize_epi32_f64(const int32x4_t &input, const float64x2_t &scale) { + return { + { + neon_dequantize_epi32_f64(vget_low_s32(input), scale), + neon_dequantize_epi32_f64(vget_high_s32(input), scale), + } + }; + } + +__always_inline float64x2x4_t neon_dequantize_epi32_f64(const int32x4x2_t &input, const float64x2_t &scale) { + const float64x2x2_t dequantized_low = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized_high = neon_dequantize_epi32_f64(input.val[1], scale); + + return { + { + dequantized_low.val[0], + dequantized_low.val[1], + dequantized_high.val[0], + dequantized_high.val[1], + } + }; + } + +__always_inline float64x2x8_t neon_dequantize_epi32_f64(const int32x4x4_t &input, const float64x2_t &scale) { + const float64x2x2_t dequantized0 = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized1 = neon_dequantize_epi32_f64(input.val[1], scale); + const float64x2x2_t dequantized2 = neon_dequantize_epi32_f64(input.val[2], scale); + const float64x2x2_t dequantized3 = neon_dequantize_epi32_f64(input.val[3], scale); + + return { + { + dequantized0.val[0], + dequantized0.val[1], + dequantized1.val[0], + dequantized1.val[1], + dequantized2.val[0], + dequantized2.val[1], + dequantized3.val[0], + dequantized3.val[1], + } + }; + } + +__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t *input, const uint32_t tail_bytes_count) { + uint8_t buffer[16] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u8(buffer); + } + +__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t *input, const uint32_t tail_bytes_count) { + uint16_t buffer[8] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u16(buffer); + } + +__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t *input, const uint32_t tail_bytes_count) { + uint32_t buffer[4] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u32(buffer); + } + +__always_inline float32x4_t neon_load_tail_elements_f32(const uint8_t *input, const uint32_t tail_elements) { + float32_t buffer[4] = {0.0f}; + std::memcpy(buffer, input, tail_elements * sizeof(float32_t)); + return vld1q_f32(buffer); + } + +__always_inline float64x2_t neon_load_tail_elements_f64(const uint8_t *input, const uint32_t tail_elements) { + float64_t buffer[2] = {0.0}; + std::memcpy(buffer, input, tail_elements * sizeof(float64_t)); + return vld1q_f64(buffer); + } + +__always_inline void neon_store_tail_elements_int8(uint8_t *output, const uint8x16x4_t &data, + const uint32_t tail_elements) { + uint8_t buffer[16 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u8(buffer + i * 16, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint8_t)); + } + +__always_inline void neon_store_tail_elements_int16(uint16_t *output, const uint16x8x4_t &data, + const uint32_t tail_elements) { + uint16_t buffer[8 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u16(buffer + i * 8, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint16_t)); + } + +__always_inline void neon_store_tail_elements_int32(uint32_t *output, const uint32x4x4_t &data, + const uint32_t tail_elements) { + uint32_t buffer[4 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint32_t)); + } + +__always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4x4_t &data, + const uint32_t tail_elements) { + float32_t buffer[16 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); + } + +__always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4x2_t &data, + const uint32_t tail_elements) { + float32_t buffer[8 * 2]; + for (uint32_t i = 0; i < 2; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); + } + +__always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4_t &data, + const uint32_t tail_elements) { + float32_t buffer[4]; + vst1q_f32(buffer, data); + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); + } + +__always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x4_t &data, + const uint32_t tail_elements) { + float64_t buffer[2 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); + } + +__always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x2_t &data, + const uint32_t tail_elements) { + float64_t buffer[2 * 2]; + for (uint32_t i = 0; i < 2; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); + } + +__always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x8_t &data, + const uint32_t tail_elements) { + float64_t buffer[2 * 8]; + for (uint32_t i = 0; i < 8; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); + } +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_COMMON_H diff --git a/include/pernix/arm64/neon/compression.h b/src/internal/pernix/arm64/neon/compression.h similarity index 79% rename from include/pernix/arm64/neon/compression.h rename to src/internal/pernix/arm64/neon/compression.h index 6e49348..f88fbb6 100644 --- a/include/pernix/arm64/neon/compression.h +++ b/src/internal/pernix/arm64/neon/compression.h @@ -9,7 +9,7 @@ namespace pernix::arm64::neon { namespace internal { -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -17,7 +17,7 @@ __always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, return -1; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -25,7 +25,7 @@ __always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, return -1; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -33,7 +33,7 @@ __always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input return -1; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { @@ -41,7 +41,7 @@ __always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, return -1; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { @@ -49,7 +49,7 @@ __always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, return -1; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { @@ -58,7 +58,7 @@ __always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input } } // namespace internal -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_compress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -72,7 +72,7 @@ __always_inline int neon_compress_block(const uint8_t* __restrict__ input, const return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_compress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { @@ -86,7 +86,7 @@ __always_inline int neon_compress_block(const uint8_t* __restrict__ input, const return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int neon_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { @@ -102,7 +102,7 @@ int neon_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int neon_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { @@ -116,26 +116,6 @@ int neon_compress_blocks(const uint8_t* __restrict__ input, const double_t scale } return 0; } - -#ifdef __cplusplus -extern "C" { -#endif - -int neon_compress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - - -int neon_compress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output); - -int neon_compress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -int neon_compress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output, uint32_t blocks); - -#ifdef __cplusplus -} -#endif } // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_COMPRESSION_H diff --git a/include/pernix/arm64/neon/decompression.h b/src/internal/pernix/arm64/neon/decompression.h similarity index 77% rename from include/pernix/arm64/neon/decompression.h rename to src/internal/pernix/arm64/neon/decompression.h index 583948f..375adc2 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/src/internal/pernix/arm64/neon/decompression.h @@ -10,9 +10,10 @@ namespace pernix::arm64::neon { namespace internal { -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -22,7 +23,9 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input for (uint32_t i = 0; i < iterations_16; ++i) { const uint8x16_t source = vld1q_u8(input); - const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8 + (source); const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); const float32x4x4_t dequantized = neon_dequantize_epi32(converted, scale_v); @@ -36,8 +39,11 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input } if constexpr (remaining_elements > 0) { - const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); - const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); + const uint8x16_t tail_source = neon_load_tail_elements_int8( + input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8 + (tail_source); const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); const float32x4x4_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); @@ -48,9 +54,10 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input return 0; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; @@ -60,7 +67,9 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu for (uint32_t i = 0; i < iterations_8; ++i) { const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); - const int16x8_t unpacked = b128::neon_unpack_epi16_9to16(source); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16 + (source); const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); const float32x4x2_t dequantized = neon_dequantize_epi32(converted, scale_v); @@ -74,8 +83,11 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu } if constexpr (remaining_elements > 0) { - const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); - const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16(tail_source); + const uint16x8_t tail_source = neon_load_tail_elements_int16( + input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16 + (tail_source); const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); const float32x4x2_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); @@ -86,9 +98,10 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu return 0; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_4 = elements_per_block / 4; @@ -131,7 +144,8 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp if constexpr (tail_bit_offset == 0) { tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } else { - tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24( + tail_source); } const float32x4_t tail_dequantized = neon_dequantize_epi32(tail_unpacked, scale_v); @@ -142,9 +156,10 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -154,7 +169,9 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input for (uint32_t i = 0; i < iterations_16; ++i) { const uint8x16_t source = vld1q_u8(input); - const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8 + (source); const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); const float64x2x8_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); @@ -168,8 +185,11 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input } if constexpr (remaining_elements > 0) { - const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); - const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); + const uint8x16_t tail_source = neon_load_tail_elements_int8( + input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8 + (tail_source); const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); const float64x2x8_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); @@ -180,9 +200,10 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input return 0; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; @@ -192,7 +213,9 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu for (uint32_t i = 0; i < iterations_8; ++i) { const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); - const int16x8_t unpacked = b128::neon_unpack_epi16_9to16(source); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16 + (source); const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); const float64x2x4_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); @@ -206,8 +229,11 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu } if constexpr (remaining_elements > 0) { - const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); - const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16(tail_source); + const uint16x8_t tail_source = neon_load_tail_elements_int16( + input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16 + (tail_source); const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); const float64x2x4_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); @@ -218,9 +244,10 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu return 0; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_4 = elements_per_block / 4; @@ -264,7 +291,8 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp if constexpr (tail_bit_offset == 0) { tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } else { - tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24( + tail_source); } const float64x2x2_t tail_dequantized = neon_dequantize_epi32_f64(tail_unpacked, scale_v); @@ -276,37 +304,46 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp } } // namespace internal -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +__always_inline int neon_decompress_block(const void* __restrict__ input, const float_t scale, + void* __restrict__ output) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::neon_decompress_block_1to8(input, scale, output); + return internal::neon_decompress_block_1to8(typed_input, scale, typed_output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::neon_decompress_block_9to16(input, scale, output); + return internal::neon_decompress_block_9to16(typed_input, scale, typed_output); } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::neon_decompress_block_17to24(input, scale, output); + return internal::neon_decompress_block_17to24(typed_input, scale, typed_output); } return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +__always_inline int neon_decompress_block(const void* __restrict__ input, const double_t scale, + void* __restrict__ output) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::neon_decompress_block_1to8(input, scale, output); + return internal::neon_decompress_block_1to8(typed_input, scale, typed_output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::neon_decompress_block_9to16(input, scale, output); + return internal::neon_decompress_block_9to16(typed_input, scale, typed_output); } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::neon_decompress_block_17to24(input, scale, output); + return internal::neon_decompress_block_17to24(typed_input, scale, typed_output); } return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; +int neon_decompress_blocks(const void* __restrict__ input, const float_t scale, void* __restrict__ output, + const uint32_t blocks) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + const uint8_t* block_input = typed_input; + float_t* block_output = typed_output; for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); @@ -317,37 +354,23 @@ int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scal return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; +int neon_decompress_blocks(const void* __restrict__ input, const double_t scale, void* __restrict__ output, + const uint32_t blocks) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + const uint8_t* block_input = typed_input; + double_t* block_output = typed_output; for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } - return 0; -} - -#ifdef __cplusplus -extern "C" { -#endif - -int neon_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); -int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -int neon_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus + return 0; } -#endif } // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/neon/packing.h b/src/internal/pernix/arm64/neon/packing.h similarity index 100% rename from include/pernix/arm64/neon/packing.h rename to src/internal/pernix/arm64/neon/packing.h diff --git a/include/pernix/arm64/neon/tables.h b/src/internal/pernix/arm64/neon/tables.h similarity index 100% rename from include/pernix/arm64/neon/tables.h rename to src/internal/pernix/arm64/neon/tables.h diff --git a/src/internal/pernix/arm64/neon/unpacking.h b/src/internal/pernix/arm64/neon/unpacking.h new file mode 100644 index 0000000..e70fbbc --- /dev/null +++ b/src/internal/pernix/arm64/neon/unpacking.h @@ -0,0 +1,101 @@ +#ifndef PERNIX_ARM64_NEON_UNPACKING_H +#define PERNIX_ARM64_NEON_UNPACKING_H + +#include +#include + +using namespace pernix::arm64::neon::internal; + +namespace pernix::arm64::neon::internal::b128 { + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t &input) { + if constexpr (BIT_WIDTH == 8) { + return vreinterpretq_s8_u8(input); + } else if constexpr (BIT_WIDTH == 1) { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + const uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + return vreinterpretq_s8_u8(vandq_u8(shifted, vdupq_n_u8(1))); + } else { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + + uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute2.data())); + + shifted = vorrq_u8(shifted, vshlq_u8(permuted2_u8, vld1q_s8(tables::shift2.data()))); + } + + constexpr int shift = 8 - BIT_WIDTH; + shifted = vshlq_n_u8(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s8(vreinterpretq_s8_u8(shifted), vdupq_n_s8(-shift)); + } else { + return vreinterpretq_s8_u8(vshlq_u8(shifted, vdupq_n_s8(-shift))); + } + } + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline int16x8_t neon_unpack_epi16_9to16(const uint16x8_t &input) { + if constexpr (BIT_WIDTH == 16) { + return vreinterpretq_s16_u16(input); + } else { + using tables = table_unpacking; + + const uint8x16_t input_u8 = vreinterpretq_u8_u16(input); + + const uint8x16_t permuted1_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute1.data())); + + uint16x8_t shifted = vshlq_u16(vreinterpretq_u16_u8(permuted1_u8), vld1q_s16(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute2.data())); + + const uint16x8_t shifted2 = vshlq_u16(vreinterpretq_u16_u8(permuted2_u8), + vld1q_s16(tables::shift2.data())); + + shifted = vorrq_u16(shifted, shifted2); + } + + constexpr int shift = 16 - BIT_WIDTH; + shifted = vshlq_n_u16(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s16(vreinterpretq_s16_u16(shifted), vdupq_n_s16(-shift)); + } else { + return vreinterpretq_s16_u16(vshlq_u16(shifted, vdupq_n_s16(-shift))); + } + } + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline int32x4_t neon_unpack_epi32_17to24(const uint32x4_t &input) { + using tables = table_unpacking; + + const uint8x16_t input_8 = vreinterpretq_u8_u32(input); + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input_8, vld1q_u8(tables::permute.data())); + + const uint32x4_t value = vshlq_u32(vreinterpretq_u32_u8(permuted_u8), vld1q_s32(tables::shift.data())); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return vshrq_n_s32(vreinterpretq_s32_u32(vshlq_n_u32(value, sign_shift)), sign_shift); + } else { + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); + } + } +} // namespace pernix::arm64::neon::internal::b128 + +#endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve2/compression.h b/src/internal/pernix/arm64/sve2/compression.h similarity index 52% rename from include/pernix/arm64/sve2/compression.h rename to src/internal/pernix/arm64/sve2/compression.h index 4e4627d..72d9229 100644 --- a/include/pernix/arm64/sve2/compression.h +++ b/src/internal/pernix/arm64/sve2/compression.h @@ -12,48 +12,37 @@ template inline constexpr bool sve2_compression_unimplemented_v = false; } // namespace internal -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int sve2_compress_block(const float_t*, float_t, uint8_t*) { - static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); return -1; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int sve2_compress_block(const double_t*, double_t, uint8_t*) { - static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); return -1; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int sve2_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { - static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); return -1; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int sve2_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { - static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); return -1; } - -#ifdef __cplusplus -extern "C" { -#endif - -int sve2_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); -int sve2_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); -int sve2_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); -int sve2_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -#endif } // namespace pernix #endif // PERNIX_ARM64_SVE2_COMPRESSION_H diff --git a/include/pernix/arm64/sve2/decompression.h b/src/internal/pernix/arm64/sve2/decompression.h similarity index 76% rename from include/pernix/arm64/sve2/decompression.h rename to src/internal/pernix/arm64/sve2/decompression.h index 2128ff2..350a196 100644 --- a/include/pernix/arm64/sve2/decompression.h +++ b/src/internal/pernix/arm64/sve2/decompression.h @@ -17,13 +17,15 @@ template return (elements * BIT_WIDTH + 7) / 8; } -[[nodiscard]] __always_inline svuint8_t sve2_load_packed_bytes(const uint8_t* __restrict__ input, const uint32_t bytes) { +[[nodiscard]] __always_inline svuint8_t sve2_load_packed_bytes(const uint8_t* __restrict__ input, + const uint32_t bytes) { const svbool_t pg = svwhilelt_b8(uint64_t{0}, static_cast(bytes)); return svld1_u8(pg, input); } template -__always_inline void sve2_store_dequantized_i8_f32(svint8_t values, const svfloat32_t scale_v, float_t* __restrict__ output, +__always_inline void sve2_store_dequantized_i8_f32(svint8_t values, const svfloat32_t scale_v, + float_t* __restrict__ output, const uint32_t count) { alignas(64) std::vector temp(svcntb()); @@ -50,7 +52,8 @@ __always_inline void sve2_store_dequantized_i8_f32(svint8_t values, const svfloa } template -__always_inline void sve2_store_dequantized_i8_f64(svint8_t values, const double_t scale, double_t* __restrict__ output, +__always_inline void sve2_store_dequantized_i8_f64(svint8_t values, const double_t scale, + double_t* __restrict__ output, const uint32_t count) { std::vector temp(svcntb()); @@ -66,7 +69,8 @@ __always_inline void sve2_store_dequantized_i8_f64(svint8_t values, const double } template -__always_inline void sve2_store_dequantized_i16_f32(svint16_t values, const svfloat32_t scale_v, float_t* __restrict__ output, +__always_inline void sve2_store_dequantized_i16_f32(svint16_t values, const svfloat32_t scale_v, + float_t* __restrict__ output, const uint32_t count) { alignas(64) std::vector temp(svcnth()); @@ -81,8 +85,9 @@ __always_inline void sve2_store_dequantized_i16_f32(svint16_t values, const svfl const svint32_t widened = svld1sh_s32(pg, temp.data() + offset); dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); } else { - const svuint32_t widened = svld1uh_u32(pg, reinterpret_cast(temp.data() + offset)); - dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + const svuint32_t widened = + svld1uh_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); } svst1_f32(pg, output + offset, dequantized); @@ -92,7 +97,8 @@ __always_inline void sve2_store_dequantized_i16_f32(svint16_t values, const svfl } template -__always_inline void sve2_store_dequantized_i16_f64(svint16_t values, const double_t scale, double_t* __restrict__ output, +__always_inline void sve2_store_dequantized_i16_f64(svint16_t values, const double_t scale, + double_t* __restrict__ output, const uint32_t count) { std::vector temp(svcnth()); @@ -108,7 +114,8 @@ __always_inline void sve2_store_dequantized_i16_f64(svint16_t values, const doub } template -__always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfloat32_t scale_v, float_t* __restrict__ output, +__always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfloat32_t scale_v, + float_t* __restrict__ output, const uint32_t count) { const svbool_t pg = svwhilelt_b32(uint64_t{0}, static_cast(count)); @@ -123,7 +130,8 @@ __always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfl } template -__always_inline void sve2_store_dequantized_i32_f64(svint32_t values, const double_t scale, double_t* __restrict__ output, +__always_inline void sve2_store_dequantized_i32_f64(svint32_t values, const double_t scale, + double_t* __restrict__ output, const uint32_t count) { std::vector temp(svcntw()); @@ -138,9 +146,10 @@ __always_inline void sve2_store_dequantized_i32_f64(svint32_t values, const doub } } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const auto lanes = static_cast(svcntb()); @@ -166,20 +175,23 @@ __always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input const uint8_t* chunk_input = input + input_bit_offset / 8; const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); - const svint8_t unpacked = sve2_unpack_epi8_1to8(source, permute, shift, spill_permute, spill_shift); + const svint8_t unpacked = sve2_unpack_epi8_1to8 + (source, permute, shift, spill_permute, spill_shift); sve2_store_dequantized_i8_f32(unpacked, scale_v, output + processed_elements, count); processed_elements += count; - input_bit_offset += count * BIT_WIDTH; + input_bit_offset += count * BIT_WIDTH; } return 0; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const auto lanes = static_cast(svcnth()); @@ -204,22 +216,25 @@ __always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ inpu const uint32_t bytes = packed_bytes(count); const uint8_t* chunk_input = input + input_bit_offset / 8; - const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); const svint16_t unpacked = - sve2_unpack_epi16_9to16(svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + sve2_unpack_epi16_9to16 + (svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); sve2_store_dequantized_i16_f32(unpacked, scale_v, output + processed_elements, count); processed_elements += count; - input_bit_offset += count * BIT_WIDTH; + input_bit_offset += count * BIT_WIDTH; } return 0; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const auto lanes = static_cast(svcntw()); @@ -246,15 +261,16 @@ __always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ inp sve2_store_dequantized_i32_f32(unpacked, scale_v, output + processed_elements, count); processed_elements += count; - input_bit_offset += count * BIT_WIDTH; + input_bit_offset += count * BIT_WIDTH; } return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const auto lanes = static_cast(svcntb()); @@ -278,20 +294,23 @@ __always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input const uint8_t* chunk_input = input + input_bit_offset / 8; const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); - const svint8_t unpacked = sve2_unpack_epi8_1to8(source, permute, shift, spill_permute, spill_shift); + const svint8_t unpacked = sve2_unpack_epi8_1to8 + (source, permute, shift, spill_permute, spill_shift); sve2_store_dequantized_i8_f64(unpacked, scale, output + processed_elements, count); processed_elements += count; - input_bit_offset += count * BIT_WIDTH; + input_bit_offset += count * BIT_WIDTH; } return 0; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const auto lanes = static_cast(svcnth()); @@ -314,22 +333,25 @@ __always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ inpu const uint32_t bytes = packed_bytes(count); const uint8_t* chunk_input = input + input_bit_offset / 8; - const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); const svint16_t unpacked = - sve2_unpack_epi16_9to16(svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + sve2_unpack_epi16_9to16 + (svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); sve2_store_dequantized_i16_f64(unpacked, scale, output + processed_elements, count); processed_elements += count; - input_bit_offset += count * BIT_WIDTH; + input_bit_offset += count * BIT_WIDTH; } return 0; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const auto lanes = static_cast(svcntw()); @@ -354,84 +376,77 @@ __always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ inp sve2_store_dequantized_i32_f64(unpacked, scale, output + processed_elements, count); processed_elements += count; - input_bit_offset += count * BIT_WIDTH; + input_bit_offset += count * BIT_WIDTH; } return 0; } -} // namespace internal +} // namespace internal -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +int sve2_decompress_block(const void* __restrict__ input, const float_t scale, void* __restrict__ output) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::sve2_decompress_block_1to8(input, scale, output); + return internal::sve2_decompress_block_1to8(typed_input, scale, typed_output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::sve2_decompress_block_9to16(input, scale, output); + return internal::sve2_decompress_block_9to16(typed_input, scale, typed_output); } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::sve2_decompress_block_17to24(input, scale, output); + return internal::sve2_decompress_block_17to24(typed_input, scale, typed_output); } } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +int sve2_decompress_block(const void* __restrict__ input, const double_t scale, + void* __restrict__ output) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::sve2_decompress_block_1to8(input, scale, output); + return internal::sve2_decompress_block_1to8(typed_input, scale, typed_output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::sve2_decompress_block_9to16(input, scale, output); + return internal::sve2_decompress_block_9to16(typed_input, scale, typed_output); } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::sve2_decompress_block_17to24(input, scale, output); + return internal::sve2_decompress_block_17to24(typed_input, scale, typed_output); } } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; +int sve2_decompress_blocks(const void* __restrict__ input, const float_t scale, void* __restrict__ output, + const uint32_t blocks) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + const uint8_t* block_input = typed_input; + float_t* block_output = typed_output; for (uint32_t block = 0; block < blocks; ++block) { sve2_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; +int sve2_decompress_blocks(const void* __restrict__ input, const double_t scale, void* __restrict__ output, + const uint32_t blocks) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + const uint8_t* block_input = typed_input; + double_t* block_output = typed_output; for (uint32_t block = 0; block < blocks; ++block) { sve2_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; } - -#ifdef __cplusplus -extern "C" { -#endif - -int sve2_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -int sve2_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -int sve2_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -int sve2_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -#endif -} // namespace pernix::arm64::sve2 +} // namespace pernix::arm64::sve2 #endif // PERNIX_ARM64_SVE2_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve2/packing.h b/src/internal/pernix/arm64/sve2/packing.h similarity index 74% rename from include/pernix/arm64/sve2/packing.h rename to src/internal/pernix/arm64/sve2/packing.h index 789b4d7..5cf2355 100644 --- a/include/pernix/arm64/sve2/packing.h +++ b/src/internal/pernix/arm64/sve2/packing.h @@ -4,8 +4,8 @@ #include namespace pernix::arm64::sve2::internal { -template -inline constexpr bool packing_unimplemented_v = false; + template + inline constexpr bool packing_unimplemented_v = false; } // namespace pernix::arm64::sve2::internal #endif // PERNIX_ARM64_SVE2_PACKING_H diff --git a/src/internal/pernix/arm64/sve2/tables.h b/src/internal/pernix/arm64/sve2/tables.h new file mode 100644 index 0000000..8ef61d8 --- /dev/null +++ b/src/internal/pernix/arm64/sve2/tables.h @@ -0,0 +1,119 @@ +#ifndef PERNIX_ARM64_SVE2_TABLES_H +#define PERNIX_ARM64_SVE2_TABLES_H + +#include + +#include + +namespace pernix::arm64::sve2::internal { + template + struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svbool_t pg_b8() { return svptrue_b8(); } + + static svbool_t pg_b16() { return svptrue_b16(); } + + static svbool_t pg_b32() { return svptrue_b32(); } + }; + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) + struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + return svlsr_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 3); + } + + static svuint8_t spill_permute() { + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 1); + } + + static svuint8_t shift() { + const svbool_t pg = svptrue_b8(); + return svand_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 7); + } + + static svuint8_t spill_shift() { + const svbool_t pg = svptrue_b8(); + return svsub_u8_x(pg, svdup_n_u8(8), shift()); + } + }; + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) + struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 1); + const svuint8_t byte = svand_n_u8_x(pg, lane, 1); + + svuint8_t first; + if constexpr (BIT_WIDTH == 16) { + first = svlsl_n_u8_x(pg, elem, 1); + } else { + constexpr uint8_t extra_bits = BIT_WIDTH - 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low = svlsr_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), 3); + first = svadd_u8_x(pg, elem, svadd_u8_x(pg, high, low)); + } + + return svadd_u8_x(pg, first, byte); + } + + static svuint8_t spill_permute() { + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 2); + } + + static svuint16_t shift() { + const svbool_t pg = svptrue_b16(); + return svand_n_u16_x(pg, svmul_n_u16_x(pg, svindex_u16(0, 1), BIT_WIDTH), 7); + } + + static svuint16_t spill_shift() { + const svbool_t pg = svptrue_b16(); + const svuint16_t bit_shift = shift(); + const svuint16_t spill = svsub_u16_x(pg, svdup_n_u16(16), bit_shift); + return svsel_u16(svcmpgt_n_u16(pg, bit_shift, 16u - BIT_WIDTH), spill, svdup_n_u16(16)); + } + }; + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && START_BIT_OFFSET < 8) + struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 2); + const svuint8_t byte = svand_n_u8_x(pg, lane, 3); + + svuint8_t first = svmul_n_u8_x(pg, elem, BIT_WIDTH / 8u); + if constexpr (BIT_WIDTH % 8u != 0) { + constexpr uint8_t extra_bits = BIT_WIDTH % 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low_bits = + svadd_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), START_BIT_OFFSET); + first = svadd_u8_x(pg, first, svadd_u8_x(pg, high, svlsr_n_u8_x(pg, low_bits, 3))); + } + + return svadd_u8_x(pg, first, byte); + } + + static svuint32_t shift() { + const svbool_t pg = svptrue_b32(); + return svand_n_u32_x(pg, svadd_n_u32_x(pg, svmul_n_u32_x(pg, svindex_u32(0, 1), BIT_WIDTH), + START_BIT_OFFSET), 7); + } + }; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_TABLES_H diff --git a/src/internal/pernix/arm64/sve2/unpacking.h b/src/internal/pernix/arm64/sve2/unpacking.h new file mode 100644 index 0000000..8bb7e3c --- /dev/null +++ b/src/internal/pernix/arm64/sve2/unpacking.h @@ -0,0 +1,94 @@ +#ifndef PERNIX_ARM64_SVE2_UNPACKING_H +#define PERNIX_ARM64_SVE2_UNPACKING_H + +#include + +#include "tables.h" + +namespace pernix::arm64::sve2::internal { + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline svint8_t sve2_unpack_epi8_1to8(const svuint8_t input, const svuint8_t permute, const svuint8_t shift, + const svuint8_t spill_permute, const svuint8_t spill_shift) { + if constexpr (BIT_WIDTH == 8) { + return svreinterpret_s8(input); + } else { + const svbool_t pg = svptrue_b8(); + + const svuint8_t permuted = svtbl_u8(input, permute); + svuint8_t unpacked = svlsr_u8_x(pg, permuted, shift); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const svuint8_t spill_permuted_values = svtbl_u8(input, spill_permute); + const svuint8_t spill_shifted = svlsl_u8_x(pg, spill_permuted_values, spill_shift); + unpacked = svorr_u8_x(pg, unpacked, spill_shifted); + } + + if constexpr (BIT_WIDTH == 1) { + unpacked = svand_n_u8_x(pg, unpacked, 1); + return svreinterpret_s8(unpacked); + } else { + constexpr int sign_shift = 8 - BIT_WIDTH; + + unpacked = svlsl_n_u8_x(pg, unpacked, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s8_x(pg, svreinterpret_s8_u8(unpacked), sign_shift); + } else { + return svreinterpret_s8_u8(svlsr_n_u8_x(pg, unpacked, sign_shift)); + } + } + } + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline svint16_t sve2_unpack_epi16_9to16(const svuint16_t input, const svuint8_t permute, + const svuint16_t shift, + const svuint8_t spill_permute, const svuint16_t spill_shift) { + if constexpr (BIT_WIDTH == 16) { + return svreinterpret_s16(input); + } else { + const svbool_t pg = svptrue_b16(); + + const svuint8_t permuted = svtbl_u8(svreinterpret_u8_u16(input), permute); + svuint16_t shifted = svlsr_u16_x(pg, svreinterpret_u16_u8(permuted), shift); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const svuint8_t spill_permuted_values = svtbl_u8(svreinterpret_u8_u16(input), spill_permute); + const svuint16_t spill_shifted = svlsl_u16_x(pg, svreinterpret_u16_u8(spill_permuted_values), + spill_shift); + shifted = svorr_u16_x(pg, shifted, spill_shifted); + } + + constexpr int sign_shift = 16 - BIT_WIDTH; + shifted = svlsl_n_u16_x(pg, shifted, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s16_x(pg, svreinterpret_s16_u16(shifted), sign_shift); + } else { + return svreinterpret_s16_u16(svlsr_n_u16_x(pg, shifted, sign_shift)); + } + } + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline svint32_t sve2_unpack_epi32_17to24(const svuint8_t input) { + using table = table_unpacking; + + const svbool_t pg = svptrue_b32(); + const svuint8_t permuted = svtbl_u8(input, table::permute()); + const svuint32_t unpacked = svlsr_u32_x(pg, svreinterpret_u32_u8(permuted), table::shift()); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return svasr_n_s32_x(pg, svreinterpret_s32_u32(svlsl_n_u32_x(pg, unpacked, sign_shift)), sign_shift); + } else { + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + return svreinterpret_s32_u32(svand_n_u32_x(pg, unpacked, mask)); + } + } +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/src/internal/pernix/dispatch/cpu_features.h b/src/internal/pernix/dispatch/cpu_features.h new file mode 100644 index 0000000..bcf3e7d --- /dev/null +++ b/src/internal/pernix/dispatch/cpu_features.h @@ -0,0 +1,26 @@ +#ifndef PERNIX_CPU_FEATURES_H +#define PERNIX_CPU_FEATURES_H + +namespace pernix::internal { +struct CpuFeatures { + bool avx2 = false; + bool bmi2 = false; + bool avx512f = false; + bool avx512dq = false; + bool avx512bw = false; + bool avx512vl = false; + bool avx512vbmi = false; + bool neon = false; + bool sve = false; + bool sve2 = false; +}; + +CpuFeatures detect_cpu_features(); + +inline CpuFeatures get_cached_cpu_features() { + static const CpuFeatures features = detect_cpu_features(); + return features; +} +} + +#endif //PERNIX_CPU_FEATURES_H diff --git a/src/internal/pernix/dispatch/kernel.h b/src/internal/pernix/dispatch/kernel.h new file mode 100644 index 0000000..1d51c64 --- /dev/null +++ b/src/internal/pernix/dispatch/kernel.h @@ -0,0 +1,27 @@ +#ifndef PERNIX_KERNEL_H +#define PERNIX_KERNEL_H + +#include +#include + +namespace pernix::internal { +using KernelBlockF32Func = int (*)(const void*, float, void*); +using KernelBlocksF32Func = int (*)(const void*, float, void*, unsigned int); +using KernelBlockF64Func = int (*)(const void*, double, void*); +using KernelBlocksF64Func = int (*)(const void*, double, void*, unsigned int); + +template +struct Kernel { + std::string_view name; + FuncType func; + + explicit operator bool() const noexcept { + return func != nullptr; + } + + Kernel(const std::string_view name, FuncType func) : name(name), func(func) { + } +}; +} + +#endif //PERNIX_KERNEL_H diff --git a/src/internal/pernix/dispatch/select.h b/src/internal/pernix/dispatch/select.h new file mode 100644 index 0000000..152117a --- /dev/null +++ b/src/internal/pernix/dispatch/select.h @@ -0,0 +1,159 @@ +#ifndef PERNIX_SELECT_H +#define PERNIX_SELECT_H + +#include +#include + +namespace pernix::internal { +Kernel select_compress_block_f32(Backend backend, uint8_t bit_width, uint32_t block_size); + +Kernel select_compress_blocks_f32(Backend backend, uint8_t bit_width, uint32_t block_size); + +Kernel select_compress_block_f64(Backend backend, uint8_t bit_width, uint32_t block_size); + +Kernel select_compress_blocks_f64(Backend backend, uint8_t bit_width, uint32_t block_size); + +Kernel select_decompress_block_f32(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_decompress_blocks_f32(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_decompress_block_f64(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_decompress_blocks_f64(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values); + + +Kernel select_auto_compress_block_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_auto_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_auto_compress_block_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_auto_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_auto_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_auto_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_auto_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_auto_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + + +Kernel select_fallback_compress_block_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_fallback_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_fallback_compress_block_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_fallback_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_fallback_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_fallback_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_fallback_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_fallback_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +#if defined(PERNIX_BUILD_X86_AVX2) + +Kernel select_avx2_compress_block_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_avx2_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_avx2_compress_block_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_avx2_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_avx2_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_avx2_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_avx2_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_avx2_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + +Kernel select_bmi2_compress_block_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_bmi2_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_bmi2_compress_block_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_bmi2_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_bmi2_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_bmi2_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_bmi2_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_bmi2_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + +Kernel select_avx512vbmi_compress_block_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_avx512vbmi_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_avx512vbmi_compress_block_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_avx512vbmi_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_avx512vbmi_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_avx512vbmi_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_avx512vbmi_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_avx512vbmi_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + +Kernel select_neon_compress_block_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_neon_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_neon_compress_block_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_neon_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_neon_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_neon_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_neon_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_neon_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + +Kernel select_sve2_compress_block_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_sve2_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + +Kernel select_sve2_compress_block_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_sve2_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + +Kernel select_sve2_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_sve2_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_sve2_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +Kernel select_sve2_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + +#endif +} + +#endif //PERNIX_SELECT_H diff --git a/include/pernix/fallback/compression.h b/src/internal/pernix/fallback/avx2_compression.h similarity index 73% rename from include/pernix/fallback/compression.h rename to src/internal/pernix/fallback/avx2_compression.h index 5b2780b..c7f03bc 100644 --- a/include/pernix/fallback/compression.h +++ b/src/internal/pernix/fallback/avx2_compression.h @@ -145,9 +145,12 @@ void pack_epi32_fallback(const std::vector& input, uint8_t* __restrict * @param output pointer to the output buffer where compressed bytes will be stored. * @return int status code (0 for success). */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_block_fallback(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { +int compress_block_fallback(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; std::memset(output, 0, BLOCK_SIZE); @@ -173,9 +176,12 @@ int compress_block_fallback(const float_t* __restrict__ input, const float_t sca * @param output pointer to the output buffer where compressed bytes will be stored. * @return int status code (0 for success). */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_block_fallback(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { +int compress_block_fallback(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; std::memset(output, 0, BLOCK_SIZE); @@ -203,9 +209,12 @@ int compress_block_fallback(const double_t* __restrict__ input, const double_t s * @param blocks number of 512-bit blocks to compress. * @return int status code (0 for success). */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_blocks_fallback(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { +int compress_blocks_fallback(const void* __restrict__ input_ptr, float scale, void* __restrict__ output_ptr, uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const float_t* block_input = input; uint8_t* block_output = output; @@ -229,10 +238,13 @@ int compress_blocks_fallback(const float_t* __restrict__ input, const float_t sc * @param blocks number of blocks to compress. * @return int status code (0 for success). */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_blocks_fallback(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { +int compress_blocks_fallback(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, + const unsigned int blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const double_t* block_input = input; uint8_t* block_output = output; @@ -244,63 +256,4 @@ int compress_blocks_fallback(const double_t* __restrict__ input, const double_t return 0; } } // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Compress a single 512-bit block using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - */ -int compress_block_fallback(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single 512-bit block using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - */ -int compress_block_fallback_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress multiple 512-bit blocks using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - */ -int compress_blocks_fallback(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Compress multiple 512-bit blocks using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - */ -int compress_blocks_fallback_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - #endif // PERNIX_FALLBACK_COMPRESSION_H diff --git a/src/internal/pernix/fallback/avx2_decompression.h b/src/internal/pernix/fallback/avx2_decompression.h new file mode 100644 index 0000000..f8f4680 --- /dev/null +++ b/src/internal/pernix/fallback/avx2_decompression.h @@ -0,0 +1,248 @@ +#ifndef PERNIX_FALLBACK_DECOMPRESSION_H +#define PERNIX_FALLBACK_DECOMPRESSION_H + +#include + +#include +#include +#include +#include + +namespace pernix { +namespace internal { +/** +* @brief Dequantize a single int32_t value to float using the provided scale. +* +* @param input input int32_t value to be dequantized. +* @param scale scaling factor used during quantization. +* @return float dequantized float value. +*/ +__always_inline float dequantize_epi32(const int32_t input, const float scale) { + return static_cast(input) * scale; +} + +/** +* @brief Dequantize a single int64_t value to double using the provided scale. +* +* @param input input int64_t value to be dequantized. +* @param scale scaling factor used during quantization. +* @return double_t dequantized double value. +*/ +__always_inline double_t dequantize_epi64(const int64_t input, const double_t scale) { + return static_cast(input) * scale; +} + +/** +* @brief Sign-extend a packed integer value stored in the low bits of a 32-bit word. +* +* @tparam BIT_WIDTH number of significant bits in the encoded value. +* @param value unsigned packed value. +* @return int32_t sign-extended value. +*/ +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) +__always_inline auto sign_extend(const uint32_t value) -> int32_t { + if constexpr (BIT_WIDTH == 1) { + return static_cast(value & 1U); + } + + constexpr uint32_t sign_bit = uint32_t{1} << (BIT_WIDTH - 1); + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1; + const uint32_t masked = value & mask; + return static_cast((static_cast(masked ^ sign_bit)) - static_cast(sign_bit)); +} + +/** +* @brief Unpack bit-packed values from a typed input span into signed 32-bit integers. +* +* @tparam T unsigned integer type used to read the source buffer. +* @tparam BIT_WIDTH bit width per packed value. +* @tparam SIGN_VALUES whether to sign-extend unpacked values. +* @param input pointer to the typed packed input buffer. +* @param bit_offset starting bit offset in the first input word. +* @param elements number of values to unpack. +* @return std::vector unpacked values. +*/ +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) +__always_inline auto unpack_epi32_fallback_inner(const uint8_t* __restrict__ input, const uint8_t bit_offset, + const std::size_t elements) + -> std::vector { + constexpr uint32_t bits_in_type = sizeof(T) * 8; + constexpr uint32_t bitmask = BIT_WIDTH == bits_in_type + ? std::numeric_limits::max() + : (1U << BIT_WIDTH) - 1U; + + std::vector output(elements); + + std::size_t idx = 0; + uint8_t bits_in_buffer = 8 - bit_offset; + uint64_t buffer = static_cast(input[idx++]) >> bit_offset; + +#pragma GCC unroll 64 + for (uint32_t i = 0; i < elements; i++) { + while (BIT_WIDTH > bits_in_buffer) { + const auto next_value = static_cast(input[idx++]) << bits_in_buffer; + buffer |= next_value; + bits_in_buffer += 8; + } + + const uint32_t raw_value = static_cast(buffer & bitmask); + if constexpr (SIGN_VALUES) { + output[i] = sign_extend(raw_value); + } else { + output[i] = static_cast(raw_value); + } + + buffer >>= BIT_WIDTH; + bits_in_buffer -= BIT_WIDTH; + } + + return output; +} + +/** +* @brief Unpack packed int32_t values from the input buffer using fallback scalar implementation. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam SIGN_VALUES whether the values are signed or unsigned. +* @param input pointer to the start of the packed data. +* @param elements number of elements to unpack. +* @return std::vector unpacked int32_t values. +*/ +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) +__always_inline auto unpack_epi32_fallback(const uint8_t* __restrict__ input, + const std::size_t elements) -> std::vector { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return unpack_epi32_fallback_inner(input, 0, elements); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return unpack_epi32_fallback_inner(input, 0, elements); + } else { + return unpack_epi32_fallback_inner(input, 0, elements); + } +} +} // namespace internal + +/** +* @brief Decompress a single 512\-bit block using fallback scalar implementation. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam SIGN_VALUES whether the values are signed or unsigned. +* @param input pointer to the start of the compressed block. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where decompressed float values will be stored. +* @return int status code (0 for success). +*/ +template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) +int decompress_block_fallback(const void* __restrict__ input_ptr, const float_t scale, + void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const std::vector block_values = internal::unpack_epi32_fallback( + input, elements_per_block); + +#pragma GCC unroll 512 + for (uint32_t i = 0; i < elements_per_block; i++) { + output[i] = internal::dequantize_epi32(block_values[i], scale); + } + + return 0; +} + +/** +* @brief Decompress a single block to double values using the fallback scalar implementation. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam SIGN_VALUES whether the values are signed or unsigned. +* @param input pointer to the start of the compressed block. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where decompressed double values will be stored. +* @return int status code (0 for success). +*/ +template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) +int decompress_block_fallback(const void* __restrict__ input_ptr, const double_t scale, + void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const std::vector block_values = internal::unpack_epi32_fallback( + input, elements_per_block); + +#pragma GCC unroll 512 + for (uint32_t i = 0; i < elements_per_block; i++) { + output[i] = internal::dequantize_epi64(block_values[i], scale); + } + + return 0; +} + +/** +* @brief Decompress multiple 512\-bit blocks using fallback scalar implementation. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam SIGN_VALUES whether the values are signed or unsigned. +* @param input pointer to the start of the compressed data. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where decompressed float values will be stored. +* @param blocks number of 512-bit blocks to decompress. +* @return int status code (0 for success). +*/ +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) +int decompress_blocks_fallback(const void* __restrict__ input_ptr, const float_t scale, + void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; block++) { + decompress_block_fallback(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +/** +* @brief Decompress multiple blocks to double values using the fallback scalar implementation. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam SIGN_VALUES whether the values are signed or unsigned. +* @param input pointer to the start of the compressed data. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where decompressed double values will be stored. +* @param blocks number of blocks to decompress. +* @return int status code (0 for success). +*/ +template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) +int decompress_blocks_fallback(const void* __restrict__ input_ptr, const double_t scale, + void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; block++) { + decompress_block_fallback(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} +} // namespace pernix + +#endif // PERNIX_FALLBACK_DECOMPRESSION_H diff --git a/include/pernix/simd_compat.h b/src/internal/pernix/simd_compat.h similarity index 82% rename from include/pernix/simd_compat.h rename to src/internal/pernix/simd_compat.h index 509eb13..b9dac8e 100644 --- a/include/pernix/simd_compat.h +++ b/src/internal/pernix/simd_compat.h @@ -1,6 +1,7 @@ #ifndef PERNIX_SIMD_COMPAT_H #define PERNIX_SIMD_COMPAT_H +#include #include #include @@ -45,14 +46,4 @@ #endif #endif -#ifndef __always_inline -#if defined(__GNUC__) || defined(__clang__) -#define __always_inline inline __attribute__((always_inline)) -#elif defined(_MSC_VER) -#define __always_inline __forceinline -#else -#define __always_inline inline -#endif -#endif - #endif // PERNIX_SIMD_COMPAT_H diff --git a/include/pernix/x86/avx2/compression.h b/src/internal/pernix/x86/avx2/avx2_compression.h similarity index 67% rename from include/pernix/x86/avx2/compression.h rename to src/internal/pernix/x86/avx2/avx2_compression.h index 357947e..3486f02 100644 --- a/include/pernix/x86/avx2/compression.h +++ b/src/internal/pernix/x86/avx2/avx2_compression.h @@ -1,8 +1,8 @@ #ifndef PERNIX_AVX2_COMPRESSION_H #define PERNIX_AVX2_COMPRESSION_H -#include -#include +#include +#include #include #include @@ -17,16 +17,17 @@ template __always_inline __m256i mm256_clamp_signed_epi32(__m256i input) { constexpr int32_t min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); constexpr int32_t max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); - return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), _mm256_set1_epi32(max_value)); + return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), + _mm256_set1_epi32(max_value)); } /** - * @brief Quantize four float values into signed 32-bit integers. - * - * @param input source float lane values. - * @param scale per-lane scale factor. - * @return __m128i quantized values. - */ +* @brief Quantize four float values into signed 32-bit integers. +* +* @param input source float lane values. +* @param scale per-lane scale factor. +* @return __m128i quantized values. +*/ __always_inline __m128i mm_quantize_ps_epi32(const __m128& input, const __m128& scale) { const __m128 scaled = _mm_mul_ps(input, scale); // const __m128 rounded = _mm_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); @@ -34,24 +35,24 @@ __always_inline __m128i mm_quantize_ps_epi32(const __m128& input, const __m128& } /** - * @brief Quantize two double values into a partially filled 128-bit integer register. - * - * @param input source double lane values. - * @param scale per-lane scale factor. - * @return __m128i quantized values in the low lanes. - */ +* @brief Quantize two double values into a partially filled 128-bit integer register. +* +* @param input source double lane values. +* @param scale per-lane scale factor. +* @return __m128i quantized values in the low lanes. +*/ __always_inline __m128i mm_quantize_pd_epi32(const __m128d& input, const __m128d& scale) { const __m128d scaled = _mm_mul_pd(input, scale); return _mm_cvtpd_epi32(scaled); } /** - * @brief Quantize eight float values into signed 32-bit integers. - * - * @param input source float lane values. - * @param scale per-lane scale factor. - * @return __m256i quantized values. - */ +* @brief Quantize eight float values into signed 32-bit integers. +* +* @param input source float lane values. +* @param scale per-lane scale factor. +* @return __m256i quantized values. +*/ __always_inline __m256i mm256_quantize_ps_epi32(const __m256& input, const __m256& scale) { const __m256 scaled = _mm256_mul_ps(input, scale); // const __m256 rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); @@ -59,12 +60,12 @@ __always_inline __m256i mm256_quantize_ps_epi32(const __m256& input, const __m25 } /** - * @brief Quantize four double values into signed 32-bit integers. - * - * @param input source double lane values. - * @param scale per-lane scale factor. - * @return __m128i quantized values in the low lanes. - */ +* @brief Quantize four double values into signed 32-bit integers. +* +* @param input source double lane values. +* @param scale per-lane scale factor. +* @return __m128i quantized values in the low lanes. +*/ __always_inline __m128i mm256_quantize_pd_epi32(const __m256d& input, const __m256d& scale) { const __m256d scaled = _mm256_mul_pd(input, scale); return _mm256_cvtpd_epi32(scaled); @@ -72,12 +73,12 @@ __always_inline __m128i mm256_quantize_pd_epi32(const __m256d& input, const __m2 #ifndef PERNIX_USE_SIMDE /** - * @brief Emulate per-16-bit left shifts on AVX2. - * - * @param a source values. - * @param count per-lane shift amounts. - * @return __m128i shifted values. - */ +* @brief Emulate per-16-bit left shifts on AVX2. +* +* @param a source values. +* @param count per-lane shift amounts. +* @return __m128i shifted values. +*/ __always_inline static __m128i _mm_sllv_epi16(const __m128i a, const __m128i count) { const __m128i mask = _mm_set1_epi32(0xffff0000); const __m128i low_half = _mm_sllv_epi32(a, _mm_andnot_si128(mask, count)); @@ -86,8 +87,8 @@ __always_inline static __m128i _mm_sllv_epi16(const __m128i a, const __m128i cou } /** - * @brief Emulate per-16-bit right shifts on AVX2. - */ +* @brief Emulate per-16-bit right shifts on AVX2. +*/ __always_inline static __m128i _mm_srlv_epi16(const __m128i a, const __m128i count) { const __m128i mask = _mm_set1_epi32(0x0000ffff); const __m128i low_half = _mm_srlv_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); @@ -96,8 +97,8 @@ __always_inline static __m128i _mm_srlv_epi16(const __m128i a, const __m128i cou } /** - * @brief Emulate per-16-bit left shifts on 256-bit AVX2 vectors. - */ +* @brief Emulate per-16-bit left shifts on 256-bit AVX2 vectors. +*/ __always_inline static __m256i _mm256_sllv_epi16(const __m256i a, const __m256i count) { const __m256i mask = _mm256_set1_epi32(0xffff0000); const __m256i low_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count)); @@ -106,8 +107,8 @@ __always_inline static __m256i _mm256_sllv_epi16(const __m256i a, const __m256i } /** - * @brief Emulate per-16-bit right shifts on 256-bit AVX2 vectors. - */ +* @brief Emulate per-16-bit right shifts on 256-bit AVX2 vectors. +*/ __always_inline static __m256i _mm256_srlv_epi16(const __m256i a, const __m256i count) { const __m256i mask = _mm256_set1_epi32(0x0000ffff); const __m256i low_half = _mm256_srlv_epi32(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); @@ -116,22 +117,22 @@ __always_inline static __m256i _mm256_srlv_epi16(const __m256i a, const __m256i } /** - * @brief Blend 8-bit lanes by expanding a scalar mask value. - */ +* @brief Blend 8-bit lanes by expanding a scalar mask value. +*/ __always_inline static __m128i mm_blend_epi8(const __m128i X, const __m128i Y, const int8_t M) { return _mm_blendv_epi8(X, Y, _mm_set1_epi8(M)); } /** - * @brief Blend 8-bit lanes in 256-bit vectors by expanding a scalar mask value. - */ +* @brief Blend 8-bit lanes in 256-bit vectors by expanding a scalar mask value. +*/ __always_inline static __m256i mm256_blend_epi8(const __m256i X, const __m256i Y, const int8_t M) { return _mm256_blendv_epi8(X, Y, _mm256_set1_epi8(M)); } /** - * @brief Emulate per-byte left shifts on 128-bit vectors. - */ +* @brief Emulate per-byte left shifts on 128-bit vectors. +*/ __always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i count) { const __m128i mask = _mm_set1_epi16(0xff00); const __m128i low_half = _mm_sllv_epi16(a, _mm_andnot_si128(mask, count)); @@ -140,8 +141,8 @@ __always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i coun } /** - * @brief Emulate per-byte right shifts on 128-bit vectors. - */ +* @brief Emulate per-byte right shifts on 128-bit vectors. +*/ __always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i count) { const __m128i mask = _mm_set1_epi16(0x00ff); const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); @@ -150,8 +151,8 @@ __always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i coun } /** - * @brief Emulate per-byte left shifts on 256-bit vectors. - */ +* @brief Emulate per-byte left shifts on 256-bit vectors. +*/ __always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i count) { const __m256i mask = _mm256_set1_epi16(0xff00); const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count)); @@ -160,8 +161,8 @@ __always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i c } /** - * @brief Emulate per-byte right shifts on 256-bit vectors. - */ +* @brief Emulate per-byte right shifts on 256-bit vectors. +*/ __always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i count) { const __m256i mask = _mm256_set1_epi16(0x00ff); const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); @@ -171,8 +172,8 @@ __always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i c #endif /** - * @brief Pack four 32-bit values for bit widths 1 through 3. - */ +* @brief Pack four 32-bit values for bit widths 1 through 3. +*/ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3) __always_inline auto mm_pack_epi32_avx2_1to3(__m128i& input) -> __m128i { @@ -182,15 +183,16 @@ __always_inline auto mm_pack_epi32_avx2_1to3(__m128i& input) -> __m128i { alignas(16) uint32_t lanes[4]; _mm_storeu_si128(reinterpret_cast<__m128i*>(lanes), masked); - const uint32_t packed = (lanes[0] & bitmask) | ((lanes[1] & bitmask) << BIT_WIDTH) | ((lanes[2] & bitmask) << (2 * BIT_WIDTH)) | + const uint32_t packed = (lanes[0] & bitmask) | ((lanes[1] & bitmask) << BIT_WIDTH) | ( + (lanes[2] & bitmask) << (2 * BIT_WIDTH)) | ((lanes[3] & bitmask) << (3 * BIT_WIDTH)); return _mm_cvtsi32_si128(static_cast(packed)); } /** - * @brief Pack eight 32-bit values for bit widths 1 through 3. - */ +* @brief Pack eight 32-bit values for bit widths 1 through 3. +*/ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3) __always_inline __m256i mm256_pack_epi32_avx2_1to3(const __m256i& input) { @@ -198,7 +200,8 @@ __always_inline __m256i mm256_pack_epi32_avx2_1to3(const __m256i& input) { const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(static_cast(bitmask))); - const __m256i shifts = _mm256_setr_epi32(0 * BIT_WIDTH, 1 * BIT_WIDTH, 2 * BIT_WIDTH, 3 * BIT_WIDTH, 4 * BIT_WIDTH, 5 * BIT_WIDTH, + const __m256i shifts = _mm256_setr_epi32(0 * BIT_WIDTH, 1 * BIT_WIDTH, 2 * BIT_WIDTH, 3 * BIT_WIDTH, + 4 * BIT_WIDTH, 5 * BIT_WIDTH, 6 * BIT_WIDTH, 7 * BIT_WIDTH); const __m256i shifted = _mm256_sllv_epi32(masked, shifts); @@ -220,15 +223,17 @@ __always_inline __m256i mm256_pack_epi32_avx2_4(const __m256i& input) { const __m256i combined = _mm256_or_si256(packed8, _mm256_srli_epi16(packed8, 4)); - const __m256i shuffled = _mm256_shuffle_epi8(combined, _mm256_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1, 0, 2, - 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1)); + const __m256i shuffled = _mm256_shuffle_epi8(combined, _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1, + 0, 2, + 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1)); return shuffled; } /** - * @brief Pack four 32-bit values for bit widths 9 through 16. - */ +* @brief Pack four 32-bit values for bit widths 9 through 16. +*/ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline auto mm_pack_epi32_avx2_9to16(__m128i& input) -> __m128i { @@ -260,8 +265,8 @@ __always_inline auto mm_pack_epi32_avx2_9to16(__m128i& input) -> __m128i { } /** - * @brief Pack eight 32-bit values for bit widths 9 through 16. - */ +* @brief Pack eight 32-bit values for bit widths 9 through 16. +*/ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline auto mm256_pack_epi32_avx2_9to16(const __m256i& input) -> __m256i { @@ -312,8 +317,8 @@ __always_inline auto mm256_pack_epi32_avx2_5to7(const __m256i& input) -> __m256i } /** - * @brief Pack eight 32-bit values for bit widths 17 through 24. - */ +* @brief Pack eight 32-bit values for bit widths 17 through 24. +*/ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) __always_inline auto mm256_pack_epi32_avx2_17to24(const __m256i& input) -> __m256i { @@ -346,8 +351,8 @@ __always_inline auto mm256_pack_epi32_avx2_17to24(const __m256i& input) -> __m25 } /** - * @brief Pack aligned 8-bit or 16-bit values from four 32-bit lanes. - */ +* @brief Pack aligned 8-bit or 16-bit values from four 32-bit lanes. +*/ template requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) auto mm_pack_aligned_epi32_avx2(__m128i& input) -> __m128i { @@ -359,8 +364,8 @@ auto mm_pack_aligned_epi32_avx2(__m128i& input) -> __m128i { } /** - * @brief Dispatch to the appropriate 128-bit AVX2 packer for the selected bit width. - */ +* @brief Dispatch to the appropriate 128-bit AVX2 packer for the selected bit width. +*/ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 16) auto mm_pack_epi32_avx2(__m128i& input) -> __m128i { @@ -377,23 +382,25 @@ auto mm_pack_epi32_avx2(__m128i& input) -> __m128i { } /** - * @brief Pack aligned 8-bit or 16-bit values from eight 32-bit lanes. - */ +* @brief Pack aligned 8-bit or 16-bit values from eight 32-bit lanes. +*/ template requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) __m256i mm256_pack_aligned_epi32_avx2(const __m256i& input) { if constexpr (BIT_WIDTH == 8) { - const __m128i packed16 = _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1)); - const __m128i packed8 = _mm_packs_epi16(packed16, _mm_setzero_si128()); + const __m128i packed16 = _mm_packs_epi32(_mm256_castsi256_si128(input), + _mm256_extracti128_si256(input, 1)); + const __m128i packed8 = _mm_packs_epi16(packed16, _mm_setzero_si128()); return _mm256_castsi128_si256(packed8); } else { - return _mm256_castsi128_si256(_mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1))); + return _mm256_castsi128_si256( + _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1))); } } /** - * @brief Dispatch to the appropriate 256-bit AVX2 packer for the selected bit width. - */ +* @brief Dispatch to the appropriate 256-bit AVX2 packer for the selected bit width. +*/ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) __m256i mm256_pack_epi32_avx2(const __m256i& input) { @@ -417,21 +424,25 @@ __m256i mm256_pack_epi32_avx2(const __m256i& input) { } // namespace internal /** - * @brief Compress a single block of float using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -template +* @brief Compress a single block of float using AVX2 instructions. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). +* +* @param input pointer to the start of the input float values. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where compressed bytes will be stored. +* @return int status code (0 for success). +* +* @note This function requires AVX2 support. +*/ +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_avx2(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { +int mm256_compress_block_avx2(const void* __restrict__ input_ptr, const float_t scale, + void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; @@ -444,7 +455,7 @@ int mm256_compress_block_avx2(const float_t* __restrict__ input, const float_t s const __m256 source = _mm256_loadu_ps(input); const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v); const __m256i packed_input = internal::mm256_clamp_signed_epi32(quantized); - const __m256i packed = internal::mm256_pack_epi32_avx2(packed_input); + const __m256i packed = internal::mm256_pack_epi32_avx2(packed_input); std::memcpy(output, &packed, BIT_WIDTH); input += 8; @@ -456,7 +467,8 @@ int mm256_compress_block_avx2(const float_t* __restrict__ input, const float_t s #pragma GCC unroll 8 for (uint32_t i = 0; i < remaining; i++) { block_values[i] = - static_cast(internal::clamp_signed_quantized(internal::quantize_ps_epi32(input[i], scale))); + static_cast(internal::clamp_signed_quantized( + internal::quantize_ps_epi32(input[i], scale))); } internal::pack_epi32_fallback(block_values, output); @@ -466,21 +478,25 @@ int mm256_compress_block_avx2(const float_t* __restrict__ input, const float_t s } /** - * @brief Compress a single block of double using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -template +* @brief Compress a single block of double using AVX2 instructions. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). +* +* @param input pointer to the start of the input double values. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where compressed bytes will be stored. +* @return int status code (0 for success). +* +* @note This function requires AVX2 support. +*/ +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_avx2(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { +int mm256_compress_block_avx2(const void* __restrict__ input_ptr, const double_t scale, + void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; @@ -496,7 +512,8 @@ int mm256_compress_block_avx2(const double_t* __restrict__ input, const double_t const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v); __m256i combined = _mm256_castsi128_si256(quantized1); combined = _mm256_inserti128_si256(combined, quantized2, 1); - const __m256i packed = internal::mm256_pack_epi32_avx2(internal::mm256_clamp_signed_epi32(combined)); + const __m256i packed = internal::mm256_pack_epi32_avx2( + internal::mm256_clamp_signed_epi32(combined)); // _mm_storeu_si128(reinterpret_cast<__m128i*>(output), _mm256_castsi256_si128(packed)); std::memcpy(output, &packed, BIT_WIDTH); input += 8; @@ -508,7 +525,8 @@ int mm256_compress_block_avx2(const double_t* __restrict__ input, const double_t #pragma GCC unroll 8 for (uint32_t i = 0; i < remaining; i++) { block_values[i] = - static_cast(internal::clamp_signed_quantized(internal::quantize_pd_epi64(input[i], scale))); + static_cast(internal::clamp_signed_quantized( + internal::quantize_pd_epi64(input[i], scale))); } internal::pack_epi32_fallback(block_values, output); @@ -518,23 +536,27 @@ int mm256_compress_block_avx2(const double_t* __restrict__ input, const double_t } /** - * @brief Compress multiple blocks using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -template +* @brief Compress multiple blocks using AVX2 instructions. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). +* +* @param input pointer to the start of the input float values. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where compressed bytes will be stored. +* @param blocks number of blocks to compress. +* @return int status code (0 for success). +* +* @note This function requires AVX2 support. +*/ +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_avx2(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, +int mm256_compress_blocks_avx2(const void* __restrict__ input_ptr, const float_t scale, + void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const float_t* block_input = input; uint8_t* block_output = output; @@ -548,23 +570,27 @@ int mm256_compress_blocks_avx2(const float_t* __restrict__ input, const float_t } /** - * @brief Compress multiple blocks using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -template +* @brief Compress multiple blocks using AVX2 instructions. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). +* +* @param input pointer to the start of the input double values. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where compressed bytes will be stored. +* @param blocks number of blocks to compress. +* @return int status code (0 for success). +* +* @note This function requires AVX2 support. +*/ +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_avx2(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, +int mm256_compress_blocks_avx2(const void* __restrict__ input_ptr, const double_t scale, + void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const double_t* block_input = input; uint8_t* block_output = output; @@ -578,70 +604,4 @@ int mm256_compress_blocks_avx2(const double_t* __restrict__ input, const double_ } } // namespace pernix -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Compress a single 512-bit block using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_compress_block_avx2(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single 512-bit block using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_compress_block_f64_avx2(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress multiple 512-bit blocks using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_blocks_avx2(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Compress multiple 512-bit blocks using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_blocks_f64_avx2(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - #endif // PERNIX_AVX2_COMPRESSION_H diff --git a/include/pernix/x86/avx2/decompression.h b/src/internal/pernix/x86/avx2/avx2_decompression.h similarity index 79% rename from include/pernix/x86/avx2/decompression.h rename to src/internal/pernix/x86/avx2/avx2_decompression.h index 93cbce5..7c530d8 100644 --- a/include/pernix/x86/avx2/decompression.h +++ b/src/internal/pernix/x86/avx2/avx2_decompression.h @@ -1,8 +1,8 @@ #ifndef PERNIX_AVX2_DECOMPRESSION_H #define PERNIX_AVX2_DECOMPRESSION_H -#include -#include +#include +#include #include #include @@ -190,9 +190,12 @@ __m256i mm256_unpack_epi32_avx2(const uint8_t* __restrict__ input) { * * @note This function requires AVX2 support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_avx2(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +int mm256_decompress_block_avx2(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; @@ -230,9 +233,12 @@ int mm256_decompress_block_avx2(const uint8_t* __restrict__ input, const float_t * * @note This function requires AVX2 support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_avx2(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +int mm256_decompress_block_avx2(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; @@ -276,10 +282,13 @@ int mm256_decompress_block_avx2(const uint8_t* __restrict__ input, const double_ * * @note This function requires AVX2 support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_avx2(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, +int mm256_decompress_blocks_avx2(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const uint8_t* block_input = input; float_t* block_output = output; @@ -306,10 +315,13 @@ int mm256_decompress_blocks_avx2(const uint8_t* __restrict__ input, const float_ * * @note This function requires AVX2 support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_avx2(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, +int mm256_decompress_blocks_avx2(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const uint8_t* block_input = input; double_t* block_output = output; @@ -323,70 +335,4 @@ int mm256_decompress_blocks_avx2(const uint8_t* __restrict__ input, const double } } // namespace pernix -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Decompress a single 512-bit block to float using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_decompress_block_avx2(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single 512-bit block to double using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_decompress_block_f64_avx2(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief Decompress multiple 512-bit blocks using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_decompress_blocks_avx2(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Decompress multiple 512-bit blocks to double using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_decompress_blocks_f64_avx2(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - #endif // PERNIX_AVX2_DECOMPRESSION_H diff --git a/include/pernix/x86/avx2/tables.h b/src/internal/pernix/x86/avx2/avx2_tables.h similarity index 68% rename from include/pernix/x86/avx2/tables.h rename to src/internal/pernix/x86/avx2/avx2_tables.h index f4f374b..e21af27 100644 --- a/include/pernix/x86/avx2/tables.h +++ b/src/internal/pernix/x86/avx2/avx2_tables.h @@ -7,10 +7,10 @@ #include namespace pernix::internal { -template <__uint8_t BIT_WIDTH, typename T> - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && (std::is_same_v || std::is_same_v)) -struct pack_tables_avx2_16 { - alignas(64) inline static constexpr std::array permute1 = [] { + template<__uint8_t BIT_WIDTH, typename T> + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && (std::is_same_v || std::is_same_v)) + struct pack_tables_avx2_16 { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { return std::array{ @@ -98,10 +98,10 @@ struct pack_tables_avx2_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { return std::array{ @@ -189,10 +189,10 @@ struct pack_tables_avx2_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { return std::array{ @@ -244,10 +244,10 @@ struct pack_tables_avx2_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { return std::array{ @@ -286,10 +286,10 @@ struct pack_tables_avx2_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { return std::array{ @@ -328,10 +328,10 @@ struct pack_tables_avx2_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift3 = [] { + alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { return std::array{ @@ -355,71 +355,71 @@ struct pack_tables_avx2_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-attributes" __always_inline static T get_permute1() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute1.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute1.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute1.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute1.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_permute2() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute2.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute2.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute2.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute2.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_permute3() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute3.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute3.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute3.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute3.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_shift1() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift1.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift1.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift1.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift1.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_shift2() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift2.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift2.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift2.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift2.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_shift3() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift3.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift3.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift3.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift3.data())); + } + return T{}; } - return T{}; - } #pragma GCC diagnostic pop -}; + }; -template <__uint8_t BIT_WIDTH, typename T> - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) -struct pack_tables_avx2_24 { - alignas(64) inline static constexpr std::array permute1 = [] { + template<__uint8_t BIT_WIDTH, typename T> + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) + struct pack_tables_avx2_24 { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -455,10 +455,10 @@ struct pack_tables_avx2_24 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -494,10 +494,10 @@ struct pack_tables_avx2_24 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -529,10 +529,10 @@ struct pack_tables_avx2_24 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -568,10 +568,10 @@ struct pack_tables_avx2_24 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -607,10 +607,10 @@ struct pack_tables_avx2_24 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift3 = [] { + alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -642,71 +642,71 @@ struct pack_tables_avx2_24 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-attributes" __always_inline static T get_permute1() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute1.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute1.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute1.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute1.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_permute2() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute2.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute2.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute2.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute2.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_permute3() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute3.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute3.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute3.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute3.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_shift1() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift1.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift1.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift1.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift1.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_shift2() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift2.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift2.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift2.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift2.data())); + } + return T{}; } - return T{}; - } __always_inline static T get_shift3() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift3.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift3.data())); + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift3.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift3.data())); + } + return T{}; } - return T{}; - } #pragma GCC diagnostic pop -}; + }; -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) -struct unpack_tables_avx2 { - alignas(32) inline static constexpr std::array permute = [] { + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) + struct unpack_tables_avx2 { + alignas(32) inline static constexpr std::array permute = [] { // clang-format off if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return std::array{0, -1, -1, -1, 0, 1, -1, -1}; @@ -715,71 +715,72 @@ struct unpack_tables_avx2 { } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { return std::array{0, 1, 2, -1, 2, 3, 4, 5}; } - // clang-format on - }(); - - alignas(32) inline static constexpr std::array shuffle = [] { - std::array shuffles{}; - shuffles.fill(-1); - constexpr std::size_t rebase_second_half = 4 * ((BIT_WIDTH - 1) / 8); - - for (std::size_t lane = 0; lane < 2; ++lane) { - for (std::size_t i = 0; i < 4; ++i) { - const std::size_t value_index = lane * 4 + i; - - const std::size_t bit_start = value_index * BIT_WIDTH; - const std::size_t byte_start = bit_start / 8; - const std::size_t bit_offset = bit_start % 8; - const std::size_t byte_count = (bit_offset + BIT_WIDTH + 7) / 8; - - const std::size_t rebase = (lane == 0) ? 0 : rebase_second_half; - const std::size_t rel_byte_start = byte_start - rebase; - - const std::size_t dst = (lane * 4 + i) * 4; - for (std::size_t k = 0; k < byte_count; ++k) { - shuffles[dst + k] = static_cast(rel_byte_start + k); + // clang-format on + }(); + + alignas(32) inline static constexpr std::array shuffle = [] { + std::array shuffles{}; + shuffles.fill(-1); + constexpr std::size_t rebase_second_half = 4 * ((BIT_WIDTH - 1) / 8); + + for (std::size_t lane = 0; lane < 2; ++lane) { + for (std::size_t i = 0; i < 4; ++i) { + const std::size_t value_index = lane * 4 + i; + + const std::size_t bit_start = value_index * BIT_WIDTH; + const std::size_t byte_start = bit_start / 8; + const std::size_t bit_offset = bit_start % 8; + const std::size_t byte_count = (bit_offset + BIT_WIDTH + 7) / 8; + + const std::size_t rebase = (lane == 0) ? 0 : rebase_second_half; + const std::size_t rel_byte_start = byte_start - rebase; + + const std::size_t dst = (lane * 4 + i) * 4; + for (std::size_t k = 0; k < byte_count; ++k) { + shuffles[dst + k] = static_cast(rel_byte_start + k); + } } } - } - return shuffles; - }(); + return shuffles; + }(); - alignas(64) inline static constexpr std::array shift = [] { - std::array shifts{}; + alignas(64) inline static constexpr std::array shift = [] { + std::array shifts{}; - for (std::size_t lane = 0; lane < 8; ++lane) { - const int bit_offset = lane * BIT_WIDTH; - const int bit_in_byte = bit_offset % 8; - const int left_shift = 32 - BIT_WIDTH - bit_in_byte; - shifts[lane] = left_shift; - } + for (std::size_t lane = 0; lane < 8; ++lane) { + const int bit_offset = lane * BIT_WIDTH; + const int bit_in_byte = bit_offset % 8; + const int left_shift = 32 - BIT_WIDTH - bit_in_byte; + shifts[lane] = left_shift; + } - return shifts; - }(); + return shifts; + }(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-attributes" - __always_inline static __m256i get_permute() { return _mm256_load_si256(reinterpret_cast(permute.data())); } + __always_inline static __m256i get_permute() { + return _mm256_load_si256(reinterpret_cast(permute.data())); + } __always_inline static T get_shuffle() { - if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shuffle.data())); - } else { - return _mm256_load_si256(reinterpret_cast(shuffle.data())); + if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shuffle.data())); + } else { + return _mm256_load_si256(reinterpret_cast(shuffle.data())); + } } - } __always_inline static T get_shift() { - if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift.data())); - } else { - return _mm256_load_si256(reinterpret_cast(shift.data())); + if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift.data())); + } else { + return _mm256_load_si256(reinterpret_cast(shift.data())); + } } - } #pragma GCC diagnostic pop -}; - -} // namespace pernix::internal + }; +} // namespace pernix::internal #endif // PERNIX_AVX2_TABLES_H diff --git a/include/pernix/x86/avx512vbmi/compression.h b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h similarity index 88% rename from include/pernix/x86/avx512vbmi/compression.h rename to src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h index 9621e12..e19051e 100644 --- a/include/pernix/x86/avx512vbmi/compression.h +++ b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h @@ -2,10 +2,10 @@ #define PERNIX_AVX512VBMI_COMPRESSION_H #include -#include -#include -#include +#include #include +#include +#include #include @@ -99,7 +99,7 @@ static __always_inline __m256i make_m256i_from_4x64(const __m128i a, const __m12 return x; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { @@ -181,7 +181,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri return 0; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { @@ -251,7 +251,7 @@ __always_inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restr return 0; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { @@ -299,7 +299,7 @@ __always_inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __rest return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { @@ -414,7 +414,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restr return 0; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { @@ -496,7 +496,7 @@ __always_inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __rest return 0; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { @@ -560,9 +560,12 @@ __always_inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __res * * @note This function requires AVX-512 and AVX-512-VBMI support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_block_avx512vbmi(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { +int mm512_compress_block_avx512vbmi(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + std::memset(output, 0, BLOCK_SIZE); if constexpr (BIT_WIDTH <= 8) { @@ -585,9 +588,12 @@ int mm512_compress_block_avx512vbmi(const float_t* __restrict__ input, const flo * * @note This overload is declared for parity with the float path. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_block_avx512vbmi(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { +int mm512_compress_block_avx512vbmi(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + std::memset(output, 0, BLOCK_SIZE); if constexpr (BIT_WIDTH <= 8) { @@ -612,10 +618,13 @@ int mm512_compress_block_avx512vbmi(const double_t* __restrict__ input, const do * * @note This function requires AVX-512 and AVX-512-VBMI support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_blocks_avx512vbmi(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, +int mm512_compress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const float_t* block_input = input; uint8_t* block_output = output; @@ -638,10 +647,13 @@ int mm512_compress_blocks_avx512vbmi(const float_t* __restrict__ input, const fl * @param blocks number of blocks to compress. * @return int status code. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_blocks_avx512vbmi(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, +int mm512_compress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const double_t* block_input = input; uint8_t* block_output = output; @@ -655,71 +667,4 @@ int mm512_compress_blocks_avx512vbmi(const double_t* __restrict__ input, const d } } // namespace pernix -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_compress_block_avx512vbmi(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_compress_block_f64_avx512vbmi(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, - uint8_t* __restrict__ output); - -/** - * @brief Compress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_compress_blocks_avx512vbmi(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Compress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_compress_blocks_f64_avx512vbmi(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, - uint8_t* __restrict__ output, uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - #endif // PERNIX_AVX512VBMI_COMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/decompression.h b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h similarity index 87% rename from include/pernix/x86/avx512vbmi/decompression.h rename to src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h index 08abc35..d0f98f0 100644 --- a/include/pernix/x86/avx512vbmi/decompression.h +++ b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h @@ -2,7 +2,7 @@ #define PERNIX_AVX512VBMI_DECOMPRESSION_H #include -#include +#include #include #include #include @@ -26,7 +26,7 @@ __always_inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m51 return _mm512_mul_pd(converted, scale); } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -110,7 +110,7 @@ __always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __rest return 0; } -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { @@ -228,7 +228,7 @@ __always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __rest return 0; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -301,7 +301,7 @@ __always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __res return 0; } -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { @@ -385,7 +385,7 @@ __always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __res return 0; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -437,7 +437,7 @@ __always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __re return 0; } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { @@ -511,10 +511,13 @@ __always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __re * * @note This function requires AVX-512 and AVX-512-VBMI support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi(const void* __restrict__ input_ptr, const float_t scale, + void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return internal::mm512_decompress_block_avx512vbmi_1to8(input, scale, output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { @@ -535,10 +538,13 @@ __always_inline int mm512_decompress_block_avx512vbmi(const uint8_t* __restrict_ * @param output pointer to the output buffer where decompressed double values will be stored. * @return int status code. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi(const void* __restrict__ input_ptr, const double_t scale, + void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return internal::mm512_decompress_block_avx512vbmi_1to8(input, scale, output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { @@ -563,10 +569,13 @@ __always_inline int mm512_decompress_block_avx512vbmi(const uint8_t* __restrict_ * * @note This function requires AVX-512 and AVX-512-VBMI support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, +int mm512_decompress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const uint8_t* block_input = input; float_t* block_output = output; @@ -590,10 +599,13 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const * @param blocks number of blocks to decompress. * @return int status code. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, +int mm512_decompress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const uint8_t* block_input = input; double_t* block_output = output; @@ -606,70 +618,4 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const } } // namespace pernix -namespace pernix { -#ifdef __cplusplus -extern "C" { -#endif -/** - * @brief Decompress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_decompress_block_avx512vbmi(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_decompress_block_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output); - -/** - * @brief Decompress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_decompress_blocks_avx512vbmi(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Decompress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_decompress_blocks_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output, uint32_t blocks); - -#ifdef __cplusplus -} -#endif -} // namespace pernix - #endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/src/internal/pernix/x86/avx512vbmi/compat.h b/src/internal/pernix/x86/avx512vbmi/compat.h new file mode 100644 index 0000000..6918d66 --- /dev/null +++ b/src/internal/pernix/x86/avx512vbmi/compat.h @@ -0,0 +1,387 @@ +#ifndef PERNIX_AVX512_COMPAT_H +#define PERNIX_AVX512_COMPAT_H + +#include +#include +#include +#include + +namespace pernix::internal { + static __always_inline __mmask8 element_mask8(const uint32_t e) { + return static_cast<__mmask8>(e >= 8 ? 0xFFu : ((1u << e) - 1u)); + } + + static __always_inline __mmask16 element_mask16(const uint32_t e) { + return static_cast<__mmask16>(e >= 16 ? 0xFFFFu : ((1u << e) - 1u)); + } + + static __always_inline __mmask32 element_mask32(const uint32_t e) { + return e >= 32 ? 0xFFFFFFFFu : (1u << e) - 1u; + } + + static __always_inline __mmask64 element_mask64(const uint32_t e) { + return e >= 64 ? 0xFFFFFFFFFFFFFFFFull : (1ull << e) - 1ull; + } + + static __always_inline __m512i mm512_loadu_elements_epi64(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512i a = _mm512_setzero_si512(); + std::memcpy(&a, mem_addr, e * sizeof(int64_t)); + return a; +#else + return _mm512_maskz_loadu_epi64(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m256i mm256_loadu_elements_epi64(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256i a = _mm256_setzero_si256(); + std::memcpy(&a, mem_addr, e * sizeof(int64_t)); + return a; +#else + return _mm256_maskz_loadu_epi64(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m128i mm_loadu_elements_epi64(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128i a = _mm_setzero_si128(); + std::memcpy(&a, mem_addr, e * sizeof(int64_t)); + return a; +#else + return _mm_maskz_loadu_epi64(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m512i mm512_loadu_elements_epi32(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512i a = _mm512_setzero_si512(); + std::memcpy(&a, mem_addr, e * sizeof(int32_t)); + return a; +#else + return _mm512_maskz_loadu_epi32(element_mask16(e), mem_addr); +#endif + } + + static __always_inline __m256i mm256_loadu_elements_epi32(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256i a = _mm256_setzero_si256(); + std::memcpy(&a, mem_addr, e * sizeof(int32_t)); + return a; +#else + return _mm256_maskz_loadu_epi32(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m128i mm_loadu_elements_epi32(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128i a = _mm_setzero_si128(); + std::memcpy(&a, mem_addr, e * sizeof(int32_t)); + return a; +#else + return _mm_maskz_loadu_epi32(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m512i mm512_loadu_elements_epi16(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512i a = _mm512_setzero_si512(); + std::memcpy(&a, mem_addr, e * sizeof(int16_t)); + return a; +#else + return _mm512_maskz_loadu_epi16(element_mask32(e), mem_addr); +#endif + } + + static __always_inline __m256i mm256_loadu_elements_epi16(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256i a = _mm256_setzero_si256(); + std::memcpy(&a, mem_addr, e * sizeof(int16_t)); + return a; +#else + return _mm256_maskz_loadu_epi16(element_mask16(e), mem_addr); +#endif + } + + static __always_inline __m128i mm_loadu_elements_epi16(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128i a = _mm_setzero_si128(); + std::memcpy(&a, mem_addr, e * sizeof(int16_t)); + return a; +#else + return _mm_maskz_loadu_epi16(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m512i mm512_loadu_elements_epi8(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512i a = _mm512_setzero_si512(); + std::memcpy(&a, mem_addr, e * sizeof(int8_t)); + return a; +#else + return _mm512_maskz_loadu_epi8(element_mask64(e), mem_addr); +#endif + } + + static __always_inline __m256i mm256_loadu_elements_epi8(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256i a = _mm256_setzero_si256(); + std::memcpy(&a, mem_addr, e * sizeof(int8_t)); + return a; +#else + return _mm256_maskz_loadu_epi8(element_mask32(e), mem_addr); +#endif + } + + static __always_inline __m128i mm_loadu_elements_epi8(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128i a = _mm_setzero_si128(); + std::memcpy(&a, mem_addr, e * sizeof(int8_t)); + return a; +#else + return _mm_maskz_loadu_epi8(element_mask16(e), mem_addr); +#endif + } + + static __always_inline void mm512_storeu_elements_epi64(void *mem_addr, const uint32_t e, const __m512i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) uint8_t bytes[64]; + _mm512_storeu_si512(bytes, a); + std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); +#else + _mm512_mask_storeu_epi64(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_epi64(void *mem_addr, const uint32_t e, const __m256i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) uint8_t bytes[32]; + _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); +#else + _mm256_mask_storeu_epi64(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_epi64(void *mem_addr, const uint32_t e, const __m128i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) uint8_t bytes[16]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); +#else + _mm_mask_storeu_epi64(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm512_storeu_elements_epi32(void *mem_addr, const uint32_t e, const __m512i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) uint8_t bytes[64]; + _mm512_storeu_si512(bytes, a); + std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); +#else + _mm512_mask_storeu_epi32(mem_addr, element_mask16(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_epi32(void *mem_addr, const uint32_t e, const __m256i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) uint8_t bytes[32]; + _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); +#else + _mm256_mask_storeu_epi32(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_epi32(void *mem_addr, const uint32_t e, const __m128i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) uint8_t bytes[16]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); +#else + _mm_mask_storeu_epi32(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm512_storeu_elements_epi16(void *mem_addr, const uint32_t e, const __m512i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) uint8_t bytes[64]; + _mm512_storeu_si512(bytes, a); + std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); +#else + _mm512_mask_storeu_epi16(mem_addr, element_mask32(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_epi16(void *mem_addr, const uint32_t e, const __m256i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) uint8_t bytes[32]; + _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); +#else + _mm256_mask_storeu_epi16(mem_addr, element_mask16(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_epi16(void *mem_addr, const uint32_t e, const __m128i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) uint8_t bytes[16]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); +#else + _mm_mask_storeu_epi16(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm512_storeu_elements_epi8(void *mem_addr, const uint32_t e, const __m512i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) uint8_t bytes[64]; + _mm512_storeu_si512(bytes, a); + std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); +#else + _mm512_mask_storeu_epi8(mem_addr, element_mask64(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_epi8(void *mem_addr, const uint32_t e, const __m256i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) uint8_t bytes[32]; + _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); +#else + _mm256_mask_storeu_epi8(mem_addr, element_mask32(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_epi8(void *mem_addr, const uint32_t e, const __m128i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) uint8_t bytes[16]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); +#else + _mm_mask_storeu_epi8(mem_addr, element_mask16(e), a); +#endif + } + + static __always_inline __m512 mm512_loadu_elements_ps(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512 a = _mm512_setzero_ps(); + std::memcpy(&a, mem_addr, e * sizeof(float_t)); + return a; +#else + return _mm512_maskz_loadu_ps(element_mask16(e), mem_addr); +#endif + } + + static __always_inline __m256 mm256_loadu_elements_ps(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256 a = _mm256_setzero_ps(); + std::memcpy(&a, mem_addr, e * sizeof(float_t)); + return a; +#else + return _mm256_maskz_loadu_ps(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m128 mm_loadu_elements_ps(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128 a = _mm_setzero_ps(); + std::memcpy(&a, mem_addr, e * sizeof(float_t)); + return a; +#else + return _mm_maskz_loadu_ps(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m512d mm512_loadu_elements_pd(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512d a = _mm512_setzero_pd(); + std::memcpy(&a, mem_addr, e * sizeof(double_t)); + return a; +#else + return _mm512_maskz_loadu_pd(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m256d mm256_loadu_elements_pd(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256d a = _mm256_setzero_pd(); + std::memcpy(&a, mem_addr, e * sizeof(double_t)); + return a; +#else + return _mm256_maskz_loadu_pd(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m128d mm_loadu_elements_pd(const uint32_t e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128d a = _mm_setzero_pd(); + std::memcpy(&a, mem_addr, e * sizeof(double_t)); + return a; +#else + return _mm_maskz_loadu_pd(element_mask8(e), mem_addr); +#endif + } + + static __always_inline void mm512_storeu_elements_ps(void *mem_addr, const uint32_t e, const __m512 a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) float_t values[16]; + _mm512_storeu_ps(values, a); + std::memcpy(mem_addr, values, e * sizeof(float_t)); +#else + _mm512_mask_storeu_ps(mem_addr, element_mask16(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_ps(void *mem_addr, const uint32_t e, const __m256 a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) float_t values[8]; + _mm256_storeu_ps(values, a); + std::memcpy(mem_addr, values, e * sizeof(float_t)); +#else + _mm256_mask_storeu_ps(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_ps(void *mem_addr, const uint32_t e, const __m128 a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) float_t values[4]; + _mm_storeu_ps(values, a); + std::memcpy(mem_addr, values, e * sizeof(float_t)); +#else + _mm_mask_storeu_ps(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm512_storeu_elements_pd(void *mem_addr, const uint32_t e, const __m512d a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) double_t values[8]; + _mm512_storeu_pd(values, a); + std::memcpy(mem_addr, values, e * sizeof(double_t)); +#else + _mm512_mask_storeu_pd(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_pd(void *mem_addr, const uint32_t e, const __m256d a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) double_t values[4]; + _mm256_storeu_pd(values, a); + std::memcpy(mem_addr, values, e * sizeof(double_t)); +#else + _mm256_mask_storeu_pd(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_pd(void *mem_addr, const uint32_t e, const __m128d a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) double_t values[2]; + _mm_storeu_pd(values, a); + std::memcpy(mem_addr, values, e * sizeof(double_t)); +#else + _mm_mask_storeu_pd(mem_addr, element_mask8(e), a); +#endif + } +} + +#endif //PERNIX_AVX512_COMPAT_H diff --git a/src/internal/pernix/x86/avx512vbmi/packing.h b/src/internal/pernix/x86/avx512vbmi/packing.h new file mode 100644 index 0000000..e51c6cc --- /dev/null +++ b/src/internal/pernix/x86/avx512vbmi/packing.h @@ -0,0 +1,331 @@ +#ifndef PERNIX_AVX512VBMI_PACKING_H +#define PERNIX_AVX512VBMI_PACKING_H + +#include +#include + +namespace pernix::internal { + namespace m128 { + /** + * @brief Pack 8 16-bit values for bit widths 9 through 16 using VBMI. + */ + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = pack_tables_avx512_16; + const __m128i maskv = _mm_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i masked = _mm_and_si128(input, maskv); + + if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m128i permuted1 = _mm_permutexvar_epi16(tables::get_permute1(), masked); + const __m128i permuted2 = _mm_permutexvar_epi16(tables::get_permute2(), masked); + + const __m128i shifted1 = _mm_sllv_epi16(permuted1, tables::get_shift1()); + const __m128i shifted2 = _mm_srlv_epi16(permuted2, tables::get_shift2()); + + return _mm_or_si128(shifted1, shifted2); + } else { + const auto [mask1, mask2, mask3] = tables::get_permute_masks(); + + const __m128i permuted1 = _mm_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); + const __m128i permuted2 = _mm_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); + const __m128i permuted3 = _mm_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); + + const __m128i shifted1 = _mm_sllv_epi16(permuted1, tables::get_shift1()); + const __m128i shifted2 = _mm_sllv_epi16(permuted2, tables::get_shift2()); + const __m128i shifted3 = _mm_srlv_epi16(permuted3, tables::get_shift3()); + + return _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); + } + } + } + + /** + * @brief Pack 16 8-bit values for bit widths 1 through 8 using VBMI. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + const __m128i maskv = _mm_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i masked = _mm_and_si128(input, maskv); + + if constexpr (BIT_WIDTH == 1) { + return _mm_set1_epi16(static_cast(_mm_cmpgt_epi8_mask(masked, _mm_setzero_si128()))); + } else if constexpr (BIT_WIDTH == 2) { + const __m128i shifted = _mm_srli_epi16(masked, 6); + const __m128i combined = _mm_or_si128(masked, shifted); + + const __m128i shifted2 = _mm_srli_epi32(combined, 12); + const __m128i combined2 = _mm_or_si128(shifted2, combined); + + return _mm_cvtepi32_epi8(combined2); + } else if constexpr (BIT_WIDTH == 3) { + const __m128i even = _mm_and_si128(masked, _mm_set1_epi16(0x00FF)); + const __m128i odd = _mm_and_si128(masked, _mm_set1_epi16(0xFF00)); + + const __m128i pair6 = _mm_or_si128(even, _mm_srli_epi16(odd, 5)); + const __m128i packed12 = _mm_or_si128(pair6, _mm_srli_epi32(pair6, 10)); + + return m128::mm_pack_epi16_avx512vbmi_9to16<12>(_mm_cvtepi32_epi16(packed12)); + } else if constexpr (BIT_WIDTH == 4) { + const __m128i shifted = _mm_srli_epi16(masked, 4); + const __m128i combined = _mm_or_si128(masked, shifted); + + return _mm_cvtepi16_epi8(combined); + } else { + const __m128i even = _mm_and_si128(masked, _mm_set1_epi16(0x00FF)); + const __m128i odd = _mm_and_si128(masked, _mm_set1_epi16(0xFF00)); + + const __m128i shifted = _mm_or_si128(even, _mm_srli_epi16(odd, 8 - BIT_WIDTH)); + return mm_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); + } + } + } + + /** + * @brief Pack 4 32-bit values for bit widths 17 through 24 using VBMI. + */ + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i &input) { + using tables = pack_tables_avx512_24; + + const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i masked = _mm_and_si128(input, maskv); + + const __m128 permuted1 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute1()); + const __m128 permuted2 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute2()); + const __m128 permuted3 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute3()); + + const __m128i shifted1 = _mm_sllv_epi32(_mm_castps_si128(permuted1), tables::get_shift1()); + const __m128i shifted2 = _mm_sllv_epi32(_mm_castps_si128(permuted2), tables::get_shift2()); + const __m128i shifted3 = _mm_srlv_epi32(_mm_castps_si128(permuted3), tables::get_shift3()); + + return _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); + } + } // namespace m128 + + namespace m256 { + /** + * @brief Pack 16 16-bit values for bit widths 9 through 16 using VBMI. + */ + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = pack_tables_avx512_16; + const __m256i maskv = _mm256_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i masked = _mm256_and_si256(input, maskv); + + if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m256i permuted1 = _mm256_permutexvar_epi16(tables::get_permute1(), masked); + const __m256i permuted2 = _mm256_permutexvar_epi16(tables::get_permute2(), masked); + + const __m256i shifted1 = _mm256_sllv_epi16(permuted1, tables::get_shift1()); + const __m256i shifted2 = _mm256_srlv_epi16(permuted2, tables::get_shift2()); + + return _mm256_or_si256(shifted1, shifted2); + } else { + const auto [mask1, mask2, mask3] = tables::get_permute_masks(); + + const __m256i permuted1 = _mm256_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); + const __m256i permuted2 = _mm256_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); + const __m256i permuted3 = _mm256_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); + + const __m256i shifted1 = _mm256_sllv_epi16(permuted1, tables::get_shift1()); + const __m256i shifted2 = _mm256_sllv_epi16(permuted2, tables::get_shift2()); + const __m256i shifted3 = _mm256_srlv_epi16(permuted3, tables::get_shift3()); + + return _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); + } + } + } + + /** + * @brief Pack 32 8-bit values for bit widths 1 through 8 using VBMI. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + const __m256i maskv = _mm256_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i masked = _mm256_and_si256(input, maskv); + + if constexpr (BIT_WIDTH == 1) { + return _mm256_set1_epi32( + static_cast(_mm256_cmpgt_epi8_mask(masked, _mm256_setzero_si256()))); + } else if constexpr (BIT_WIDTH == 2) { + const __m256i shifted = _mm256_srli_epi16(masked, 6); + const __m256i combined = _mm256_or_si256(masked, shifted); + + const __m256i shifted2 = _mm256_srli_epi32(combined, 12); + const __m256i combined2 = _mm256_or_si256(shifted2, combined); + + return _mm256_castsi128_si256(_mm256_cvtepi32_epi8(combined2)); + } else if constexpr (BIT_WIDTH == 3) { + const __m256i even = _mm256_and_si256(masked, _mm256_set1_epi16(0x00FF)); + const __m256i odd = _mm256_and_si256(masked, _mm256_set1_epi16(0xFF00)); + + const __m256i pair6 = _mm256_or_si256(even, _mm256_srli_epi16(odd, 5)); + const __m256i packed12 = _mm256_or_si256(pair6, _mm256_srli_epi32(pair6, 10)); + + return m256::mm256_pack_epi16_avx512vbmi_9to16<12>( + _mm256_castsi128_si256(_mm256_cvtepi32_epi16(packed12))); + } else if constexpr (BIT_WIDTH == 4) { + const __m256i shifted = _mm256_srli_epi16(masked, 4); + const __m256i combined = _mm256_or_si256(masked, shifted); + + return _mm256_castsi128_si256(_mm256_cvtepi16_epi8(combined)); + } else { + const __m256i even = _mm256_and_si256(masked, _mm256_set1_epi16(0x00FF)); + const __m256i odd = _mm256_and_si256(masked, _mm256_set1_epi16(0xFF00)); + + const __m256i shifted = _mm256_or_si256(even, _mm256_srli_epi16(odd, 8 - BIT_WIDTH)); + return mm256_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); + } + } + } + + /** + * @brief Pack 8 32-bit values for bit widths 17 through 24 using VBMI. + */ + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i &input) { + using tables = pack_tables_avx512_24; + + const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i masked = _mm256_and_si256(input, maskv); + + const __m256i permuted1 = _mm256_permutexvar_epi32(tables::get_permute1(), masked); + const __m256i permuted2 = _mm256_permutexvar_epi32(tables::get_permute2(), masked); + const __m256i permuted3 = _mm256_permutexvar_epi32(tables::get_permute3(), masked); + + const __m256i shifted1 = _mm256_sllv_epi32(permuted1, tables::get_shift1()); + const __m256i shifted2 = _mm256_sllv_epi32(permuted2, tables::get_shift2()); + const __m256i shifted3 = _mm256_srlv_epi32(permuted3, tables::get_shift3()); + + return _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); + } + } // namespace m256 + + namespace m512 { + /** + * @brief Pack 32 16-bit values for bit widths 9 through 16 using VBMI. + */ + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = pack_tables_avx512_16; + const __m512i maskv = _mm512_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i masked = _mm512_and_si512(input, maskv); + + if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m512i permuted1 = _mm512_permutexvar_epi16(tables::get_permute1(), masked); + const __m512i permuted2 = _mm512_permutexvar_epi16(tables::get_permute2(), masked); + + const __m512i shifted1 = _mm512_sllv_epi16(permuted1, tables::get_shift1()); + const __m512i shifted2 = _mm512_srlv_epi16(permuted2, tables::get_shift2()); + + return _mm512_or_si512(shifted1, shifted2); + } else { + const auto [mask1, mask2, mask3] = tables::get_permute_masks(); + + const __m512i permuted1 = _mm512_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); + const __m512i permuted2 = _mm512_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); + const __m512i permuted3 = _mm512_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); + + const __m512i shifted1 = _mm512_sllv_epi16(permuted1, tables::get_shift1()); + const __m512i shifted2 = _mm512_sllv_epi16(permuted2, tables::get_shift2()); + const __m512i shifted3 = _mm512_srlv_epi16(permuted3, tables::get_shift3()); + + return _mm512_or_si512(_mm512_or_si512(shifted1, shifted2), shifted3); + } + } + } + + /** + * @brief Pack 64 8-bit values for bit widths 1 through 8 using VBMI. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + const __m512i maskv = _mm512_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i masked = _mm512_and_si512(input, maskv); + + if constexpr (BIT_WIDTH == 1) { + return _mm512_set1_epi64( + static_cast(_mm512_cmpgt_epi8_mask(masked, _mm512_setzero_si512()))); + } else if constexpr (BIT_WIDTH == 2) { + const __m512i shifted = _mm512_srli_epi16(masked, 6); + const __m512i combined = _mm512_or_si512(masked, shifted); + + const __m512i shifted2 = _mm512_srli_epi32(combined, 12); + const __m512i combined2 = _mm512_or_si512(shifted2, combined); + + return _mm512_castsi128_si512(_mm512_cvtepi32_epi8(combined2)); + } else if constexpr (BIT_WIDTH == 3) { + const __m512i even = _mm512_and_si512(masked, _mm512_set1_epi16(0x00FF)); + const __m512i odd = _mm512_and_si512(masked, _mm512_set1_epi16(0xFF00)); + + const __m512i pair6 = _mm512_or_si512(even, _mm512_srli_epi16(odd, 5)); + const __m512i packed12 = _mm512_or_si512(pair6, _mm512_srli_epi32(pair6, 10)); + + return _mm512_castsi256_si512( + m256::mm256_pack_epi16_avx512vbmi_9to16<12>(_mm512_cvtepi32_epi16(packed12))); + } else if constexpr (BIT_WIDTH == 4) { + const __m512i shifted = _mm512_srli_epi16(masked, 4); + const __m512i combined = _mm512_or_si512(masked, shifted); + + return _mm512_castsi256_si512(_mm512_cvtepi16_epi8(combined)); + } else { + const __m512i even = _mm512_and_si512(masked, _mm512_set1_epi16(0x00FF)); + const __m512i odd = _mm512_and_si512(masked, _mm512_set1_epi16(0xFF00)); + + const __m512i shifted = _mm512_or_si512(even, _mm512_srli_epi16(odd, 8 - BIT_WIDTH)); + return mm512_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); + } + } + } + + /** + * @brief Pack 16 32-bit values for bit widths 17 through 24 using VBMI. + */ + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i &input) { + using tables = pack_tables_avx512_24; + + const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i masked = _mm512_and_si512(input, maskv); + + const __m512i permuted1 = _mm512_permutexvar_epi32(tables::get_permute1(), masked); + const __m512i permuted2 = _mm512_permutexvar_epi32(tables::get_permute2(), masked); + const __m512i permuted3 = _mm512_permutexvar_epi32(tables::get_permute3(), masked); + + const __m512i shifted1 = _mm512_sllv_epi32(permuted1, tables::get_shift1()); + const __m512i shifted2 = _mm512_sllv_epi32(permuted2, tables::get_shift2()); + const __m512i shifted3 = _mm512_srlv_epi32(permuted3, tables::get_shift3()); + + return _mm512_or_si512(_mm512_or_si512(shifted1, shifted2), shifted3); + } + } // namespace m512 +} // namespace pernix::internal + +#endif // PERNIX_AVX512VBMI_PACKING_H diff --git a/include/pernix/x86/avx512vbmi/tables.h b/src/internal/pernix/x86/avx512vbmi/tables.h similarity index 51% rename from include/pernix/x86/avx512vbmi/tables.h rename to src/internal/pernix/x86/avx512vbmi/tables.h index 4d66727..ab3c665 100644 --- a/include/pernix/x86/avx512vbmi/tables.h +++ b/src/internal/pernix/x86/avx512vbmi/tables.h @@ -9,22 +9,22 @@ #include namespace pernix::internal { -template -static __always_inline Vec load_table(const std::array& table) { - static_assert(sizeof(table) >= sizeof(Vec), "table is smaller than requested SIMD vector"); - if constexpr (std::is_same_v) { - return _mm512_load_si512(static_cast(table.data())); - } else if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(table.data())); - } else { - return _mm_load_si128(reinterpret_cast(table.data())); + template + static __always_inline Vec load_table(const std::array &table) { + static_assert(sizeof(table) >= sizeof(Vec), "table is smaller than requested SIMD vector"); + if constexpr (std::is_same_v) { + return _mm512_load_si512(static_cast(table.data())); + } else if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(table.data())); + } else { + return _mm_load_si128(reinterpret_cast(table.data())); + } } -} -template - requires(N >= 9 && N <= 15) -struct pack_tables_avx512_16 { - alignas(64) inline static constexpr std::array permute1 = [] { + template + requires(N >= 9 && N <= 15) + struct pack_tables_avx512_16 { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off if constexpr (N == 9) { return std::array{ @@ -112,10 +112,10 @@ struct pack_tables_avx512_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (N == 9) { return std::array{ @@ -202,10 +202,10 @@ struct pack_tables_avx512_16 { 29, 30, -1, -1 }; } - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (N == 9) { return std::array{ @@ -257,10 +257,10 @@ struct pack_tables_avx512_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (N == 9) { return std::array{ @@ -348,10 +348,10 @@ struct pack_tables_avx512_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (N == 9) { return std::array{ @@ -439,10 +439,10 @@ struct pack_tables_avx512_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift3 = [] { + alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (N == 9) { return std::array{ @@ -494,10 +494,10 @@ struct pack_tables_avx512_16 { }; } return std::array{}; - // clang-format on - }(); + // clang-format on + }(); - inline static constexpr std::tuple<__mmask32, __mmask32, __mmask32> get_permute_masks() { + inline static constexpr std::tuple<__mmask32, __mmask32, __mmask32> get_permute_masks() { // clang-format off if constexpr (N == 9) { return { @@ -525,296 +525,305 @@ struct pack_tables_avx512_16 { }; } return {0, 0, 0}; - // clang-format on - } - - static __always_inline Vec get_permute1() { return load_table(permute1); } - static __always_inline Vec get_permute2() { return load_table(permute2); } - static __always_inline Vec get_permute3() { return load_table(permute3); } - - static __always_inline Vec get_shift1() { return load_table(shift1); } - static __always_inline Vec get_shift2() { return load_table(shift2); } - static __always_inline Vec get_shift3() { return load_table(shift3); } -}; - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && - (std::is_same_v || std::is_same_v || std::is_same_v)) -struct pack_tables_avx512_24 { -private: - struct word_plan { - int32_t left_index1 = -1; - int32_t left_index2 = -1; - int32_t right_index = -1; - uint32_t left_shift1 = 32; - uint32_t left_shift2 = 32; - uint32_t right_shift = 32; - }; + // clang-format on + } - static constexpr word_plan create_plan(const uint32_t idx) { - word_plan plan{}; + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - const uint32_t word_start = idx * 32u; - const uint32_t word_end = word_start + 32u; + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } + }; - uint32_t left_slot = 0; - for (uint32_t input_lane = 0; input_lane < 16; ++input_lane) { - const uint32_t input_start = input_lane * BIT_WIDTH; - const uint32_t input_end = input_start + BIT_WIDTH; + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && + (std::is_same_v || std::is_same_v || std::is_same_v)) + struct pack_tables_avx512_24 { + private: + struct word_plan { + int32_t left_index1 = -1; + int32_t left_index2 = -1; + int32_t right_index = -1; + uint32_t left_shift1 = 32; + uint32_t left_shift2 = 32; + uint32_t right_shift = 32; + }; + + static constexpr word_plan create_plan(const uint32_t idx) { + word_plan plan{}; + + const uint32_t word_start = idx * 32u; + const uint32_t word_end = word_start + 32u; + + uint32_t left_slot = 0; + for (uint32_t input_lane = 0; input_lane < 16; ++input_lane) { + const uint32_t input_start = input_lane * BIT_WIDTH; + const uint32_t input_end = input_start + BIT_WIDTH; + + const uint32_t overlap_start = std::max(word_start, input_start); + const uint32_t overlap_end = std::min(word_end, input_end); + if (overlap_start >= overlap_end) { + continue; + } - const uint32_t overlap_start = std::max(word_start, input_start); - const uint32_t overlap_end = std::min(word_end, input_end); - if (overlap_start >= overlap_end) { - continue; + const auto output_bit = static_cast(overlap_start - word_start); + const auto input_bit = static_cast(overlap_start - input_start); + const int32_t delta = output_bit - input_bit; + + if (delta >= 0) { + if (left_slot == 0) { + plan.left_index1 = static_cast(input_lane); + plan.left_shift1 = static_cast(delta); + ++left_slot; + } else { + plan.left_index2 = static_cast(input_lane); + plan.left_shift2 = static_cast(delta); + } + } else { + plan.right_index = static_cast(input_lane); + plan.right_shift = static_cast(-delta); + } } - const auto output_bit = static_cast(overlap_start - word_start); - const auto input_bit = static_cast(overlap_start - input_start); - const int32_t delta = output_bit - input_bit; + return plan; + } - if (delta >= 0) { - if (left_slot == 0) { - plan.left_index1 = static_cast(input_lane); - plan.left_shift1 = static_cast(delta); - ++left_slot; - } else { - plan.left_index2 = static_cast(input_lane); - plan.left_shift2 = static_cast(delta); - } - } else { - plan.right_index = static_cast(input_lane); - plan.right_shift = static_cast(-delta); + static constexpr std::array word_plans = [] { + std::array plans{}; + for (uint32_t i = 0; i < 16; ++i) { + plans[i] = create_plan(i); + } + return plans; + }(); + + template + static __always_inline constexpr std::array make_table(Getter getter) { + std::array values{}; + for (uint32_t i = 0; i < 16; ++i) { + values[i] = getter(word_plans[i]); } + return values; } - return plan; - } + alignas(64) static constexpr auto permute1 = make_table([](const word_plan &p) { + return p.left_index1; + }); - static constexpr std::array word_plans = [] { - std::array plans{}; - for (uint32_t i = 0; i < 16; ++i) { - plans[i] = create_plan(i); - } - return plans; - }(); - - template - static __always_inline constexpr std::array make_table(Getter getter) { - std::array values{}; - for (uint32_t i = 0; i < 16; ++i) { - values[i] = getter(word_plans[i]); - } - return values; - } + alignas(64) static constexpr auto permute2 = make_table([](const word_plan &p) { + return p.left_index2; + }); - alignas(64) static constexpr auto permute1 = make_table([](const word_plan& p) { return p.left_index1; }); + alignas(64) static constexpr auto permute3 = make_table([](const word_plan &p) { + return p.right_index; + }); - alignas(64) static constexpr auto permute2 = make_table([](const word_plan& p) { return p.left_index2; }); + alignas(64) static constexpr auto shift1 = make_table( + [](const word_plan &p) { return p.left_shift1; }); - alignas(64) static constexpr auto permute3 = make_table([](const word_plan& p) { return p.right_index; }); + alignas(64) static constexpr auto shift2 = make_table( + [](const word_plan &p) { return p.left_shift2; }); - alignas(64) static constexpr auto shift1 = make_table([](const word_plan& p) { return p.left_shift1; }); + alignas(64) static constexpr auto shift3 = make_table( + [](const word_plan &p) { return p.right_shift; }); - alignas(64) static constexpr auto shift2 = make_table([](const word_plan& p) { return p.left_shift2; }); + public: + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - alignas(64) static constexpr auto shift3 = make_table([](const word_plan& p) { return p.right_shift; }); + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } + }; -public: - static __always_inline Vec get_permute1() { return load_table(permute1); } - static __always_inline Vec get_permute2() { return load_table(permute2); } - static __always_inline Vec get_permute3() { return load_table(permute3); } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && + (std::is_same_v || std::is_same_v || std::is_same_v)) + struct unpack_tables_avx512_8 { + private: + alignas(64) inline static constexpr std::array permute1 = [] { + std::array table{}; + std::ranges::fill(table, -1); + for (size_t entry = 0; entry < 64; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t first_byte = bit_start / 8; + + table[entry] = static_cast(first_byte); + } - static __always_inline Vec get_shift1() { return load_table(shift1); } - static __always_inline Vec get_shift2() { return load_table(shift2); } - static __always_inline Vec get_shift3() { return load_table(shift3); } -}; + return table; + }(); -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && - (std::is_same_v || std::is_same_v || std::is_same_v)) -struct unpack_tables_avx512_8 { -private: - alignas(64) inline static constexpr std::array permute1 = [] { - std::array table{}; - std::ranges::fill(table, -1); - for (size_t entry = 0; entry < 64; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t first_byte = bit_start / 8; + alignas(64) inline static constexpr std::array permute2 = [] { + std::array table{}; + std::ranges::fill(table, -1); - table[entry] = static_cast(first_byte); - } + for (size_t entry = 0; entry < 64; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t first_byte = bit_start / 8; + const size_t bit_offset = bit_start % 8; - return table; - }(); + if (bit_offset + BIT_WIDTH > 8) { + table[entry] = static_cast(first_byte + 1); + } + } - alignas(64) inline static constexpr std::array permute2 = [] { - std::array table{}; - std::ranges::fill(table, -1); + return table; + }(); - for (size_t entry = 0; entry < 64; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t first_byte = bit_start / 8; - const size_t bit_offset = bit_start % 8; + alignas(64) inline static constexpr std::array shift1 = [] { + std::array table{}; - if (bit_offset + BIT_WIDTH > 8) { - table[entry] = static_cast(first_byte + 1); - } - } + for (size_t entry = 0; entry < 64; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_offset = bit_start % 8; - return table; - }(); + table[entry] = static_cast(bit_offset); + } - alignas(64) inline static constexpr std::array shift1 = [] { - std::array table{}; + return table; + }(); - for (size_t entry = 0; entry < 64; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_offset = bit_start % 8; + alignas(64) inline static constexpr std::array shift2 = [] { + std::array table{}; - table[entry] = static_cast(bit_offset); - } + for (size_t entry = 0; entry < 64; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_offset = bit_start % 8u; + const size_t spill_bits = (bit_offset + BIT_WIDTH > 8u) ? (bit_offset + BIT_WIDTH - 8u) : 0u; - return table; - }(); + table[entry] = spill_bits ? static_cast(8 - bit_offset) : 0; + } - alignas(64) inline static constexpr std::array shift2 = [] { - std::array table{}; + return table; + }(); - for (size_t entry = 0; entry < 64; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_offset = bit_start % 8u; - const size_t spill_bits = (bit_offset + BIT_WIDTH > 8u) ? (bit_offset + BIT_WIDTH - 8u) : 0u; + public: + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - table[entry] = spill_bits ? static_cast(8 - bit_offset) : 0; - } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + }; - return table; - }(); - -public: - static __always_inline Vec get_permute1() { return load_table(permute1); } - static __always_inline Vec get_permute2() { return load_table(permute2); } - - static __always_inline Vec get_shift1() { return load_table(shift1); } - static __always_inline Vec get_shift2() { return load_table(shift2); } -}; - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && - (std::is_same_v || std::is_same_v || std::is_same_v)) -struct unpack_tables_avx512_16 { -private: - alignas(64) inline static constexpr std::array permute1 = [] { - std::array table{}; - std::ranges::fill(table, -1); - - for (size_t entry = 0; entry < 32; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t first_byte = bit_start / 8; - const size_t base = entry * 2; - - table[base] = static_cast(first_byte); - table[base + 1] = static_cast(first_byte + 1); - } + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && + (std::is_same_v || std::is_same_v || std::is_same_v)) + struct unpack_tables_avx512_16 { + private: + alignas(64) inline static constexpr std::array permute1 = [] { + std::array table{}; + std::ranges::fill(table, -1); + + for (size_t entry = 0; entry < 32; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t first_byte = bit_start / 8; + const size_t base = entry * 2; + + table[base] = static_cast(first_byte); + table[base + 1] = static_cast(first_byte + 1); + } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array permute2 = [] { - std::array table{}; - std::ranges::fill(table, -1); + alignas(64) inline static constexpr std::array permute2 = [] { + std::array table{}; + std::ranges::fill(table, -1); - for (size_t entry = 0; entry < 32; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t first_byte = bit_start / 8; - const size_t bit_offset = bit_start % 8; - const size_t base = entry * 2; + for (size_t entry = 0; entry < 32; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t first_byte = bit_start / 8; + const size_t bit_offset = bit_start % 8; + const size_t base = entry * 2; - if (bit_offset + BIT_WIDTH > 16) { - table[base] = static_cast(first_byte + 2); + if (bit_offset + BIT_WIDTH > 16) { + table[base] = static_cast(first_byte + 2); + } } - } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array shift1 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift1 = [] { + std::array table{}; - for (size_t entry = 0; entry < 32; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_offset = bit_start % 8u; + for (size_t entry = 0; entry < 32; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_offset = bit_start % 8u; - // Right-shift the 16-bit chunk so the value starts at bit 0. - table[entry] = static_cast(bit_offset); - } + // Right-shift the 16-bit chunk so the value starts at bit 0. + table[entry] = static_cast(bit_offset); + } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array shift2 = [] { - std::array table{}; - for (size_t entry = 0; entry < 32; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_offset = bit_start % 8u; - const size_t spill_bits = (bit_offset + BIT_WIDTH > 16u) ? (bit_offset + BIT_WIDTH - 16u) : 0u; + alignas(64) inline static constexpr std::array shift2 = [] { + std::array table{}; + for (size_t entry = 0; entry < 32; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_offset = bit_start % 8u; + const size_t spill_bits = (bit_offset + BIT_WIDTH > 16u) ? (bit_offset + BIT_WIDTH - 16u) : 0u; - // Move spill bits from byte3 to their final bit positions before merge. - table[entry] = spill_bits ? static_cast(16u - bit_offset) : 0; - } + // Move spill bits from byte3 to their final bit positions before merge. + table[entry] = spill_bits ? static_cast(16u - bit_offset) : 0; + } - return table; - }(); + return table; + }(); -public: - static __always_inline Vec get_permute1() { return load_table(permute1); } - static __always_inline Vec get_permute2() { return load_table(permute2); } + public: + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - static __always_inline Vec get_shift1() { return load_table(shift1); } - static __always_inline Vec get_shift2() { return load_table(shift2); } -}; + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + }; -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && - (std::is_same_v || std::is_same_v || std::is_same_v)) -struct unpack_tables_avx512_24 { -private: - alignas(64) inline static constexpr std::array permute = [] { - std::array table{}; - std::ranges::fill(table, -1); + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && + (std::is_same_v || std::is_same_v || std::is_same_v)) + struct unpack_tables_avx512_24 { + private: + alignas(64) inline static constexpr std::array permute = [] { + std::array table{}; + std::ranges::fill(table, -1); - for (size_t entry = 0; entry < 16; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_end = bit_start + BIT_WIDTH - 1; + for (size_t entry = 0; entry < 16; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_end = bit_start + BIT_WIDTH - 1; - const size_t first_byte = bit_start / 8; - const size_t last_byte = bit_end / 8; + const size_t first_byte = bit_start / 8; + const size_t last_byte = bit_end / 8; - const size_t base = entry * 4; + const size_t base = entry * 4; - for (size_t byte = first_byte; byte <= last_byte; ++byte) { - table[base + (byte - first_byte)] = static_cast(byte); + for (size_t byte = first_byte; byte <= last_byte; ++byte) { + table[base + (byte - first_byte)] = static_cast(byte); + } } - } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array shift = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift = [] { + std::array table{}; - for (size_t entry = 0; entry < 16; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - table[entry] = static_cast(32u - BIT_WIDTH - (bit_start % 8u)); - } + for (size_t entry = 0; entry < 16; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + table[entry] = static_cast(32u - BIT_WIDTH - (bit_start % 8u)); + } - return table; - }(); + return table; + }(); -public: - static __always_inline Vec get_permute() { return load_table(permute); } - static __always_inline Vec get_shift() { return load_table(shift); } -}; + public: + static __always_inline Vec get_permute() { return load_table(permute); } + static __always_inline Vec get_shift() { return load_table(shift); } + }; } // namespace pernix::internal #endif // PERNIX_AVX512VBMI_TABLES_H diff --git a/src/internal/pernix/x86/avx512vbmi/unpacking.h b/src/internal/pernix/x86/avx512vbmi/unpacking.h new file mode 100644 index 0000000..6799cd2 --- /dev/null +++ b/src/internal/pernix/x86/avx512vbmi/unpacking.h @@ -0,0 +1,500 @@ +#ifndef PERNIX_AVX512VBMI_UNPACKING_H +#define PERNIX_AVX512VBMI_UNPACKING_H + +#include +#include + +namespace pernix::internal { + namespace m128 { + constexpr __mmask16 kAlternateByteMask16 = 0xAAAAULL; + +__always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i count) { + const __m128i mask = _mm_set1_epi16(0x00ff); + const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); + const __m128i high_half = _mm_srlv_epi16(a, _mm_srli_epi16(count, 8)); + return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); + } + +__always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i count) { + const __m128i mask = _mm_set1_epi16(0xff00); + const __m128i low_half = _mm_sllv_epi16(a, _mm_andnot_si128(mask, count)); + const __m128i high_half = _mm_sllv_epi16(_mm_and_si128(mask, a), _mm_srli_epi16(count, 8)); + return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); + } + +__always_inline static __m128i _mm_slli_epi8(const __m128i a, const int8_t imm8) { + return _mm_sllv_epi8(a, _mm_set1_epi8(imm8)); + } + +__always_inline static __m128i _mm_srli_epi8(const __m128i a, const int imm8) { + const __m128i lo_mask = _mm_set1_epi16(0x00ff); + const __m128i hi_mask = _mm_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m128i lo = _mm_srl_epi16(_mm_and_si128(a, lo_mask), shift); + const __m128i hi = _mm_and_si128(_mm_srl_epi16(a, shift), hi_mask); + + return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); + } + +__always_inline static __m128i _mm_srai_epi8(const __m128i a, const int8_t imm8) { + const __m128i lo_mask = _mm_set1_epi16(0x00ff); + const __m128i hi_mask = _mm_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m128i hi = _mm_and_si128(_mm_sra_epi16(a, shift), hi_mask); + + const __m128i lo_as_hi = _mm_slli_epi16(_mm_and_si128(a, lo_mask), 8); + const __m128i lo = _mm_and_si128(_mm_srli_epi16(_mm_sra_epi16(lo_as_hi, shift), 8), lo_mask); + + return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m128i mm_unpack_epi8_avx512vbmi_1to8(const __m128i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + if constexpr (BIT_WIDTH == 1) { + const auto value = static_cast<__mmask16>(_mm_cvtsi128_si64(input)); + const __m128i source = _mm_movm_epi8(value); + const __m128i unpacked = _mm_abs_epi8(source); + return unpacked; + } else if constexpr (BIT_WIDTH == 2) { + __m128i values_shift0 = input; + __m128i values_shift2 = _mm_srli_epi16(values_shift0, 2); + const __m128i values_shift4 = _mm_srli_epi16(values_shift0, 4); + const __m128i values_shift6 = _mm_srli_epi16(values_shift0, 6); + + __m128i interleave_tmp = _mm_unpacklo_epi8(values_shift0, values_shift2); + values_shift0 = _mm_unpackhi_epi8(values_shift0, values_shift2); + values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); + + interleave_tmp = _mm_unpacklo_epi8(values_shift4, values_shift6); + values_shift2 = _mm_unpackhi_epi8(values_shift4, values_shift6); + values_shift2 = _mm_unpacklo_epi64(interleave_tmp, values_shift2); + + interleave_tmp = _mm_unpacklo_epi16(values_shift0, values_shift2); + values_shift0 = _mm_unpackhi_epi16(values_shift0, values_shift2); + values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); + values_shift0 = _mm_shuffle_epi32(values_shift0, 0xD8); + + values_shift0 = _mm_and_si128(values_shift0, _mm_set1_epi16(0x0303)); + + return values_shift0; + } else if constexpr (BIT_WIDTH == 4) { + __m128i values_shift0 = input; + const __m128i values_shift4 = _mm_srli_epi16(values_shift0, 4); + + const __m128i interleave_tmp = _mm_unpacklo_epi8(values_shift0, values_shift4); + values_shift0 = _mm_unpackhi_epi8(values_shift0, values_shift4); + values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); + values_shift0 = _mm_shuffle_epi32(values_shift0, 0xD8); + + values_shift0 = _mm_and_si128(values_shift0, _mm_set1_epi16(0x0F0F)); + + return values_shift0; + } else { + using tables = unpack_tables_avx512_8; + + const __m128i permuted1 = _mm_permutexvar_epi8(tables::get_permute1(), input); + const __m128i permuted2 = _mm_permutexvar_epi8(tables::get_permute2(), input); + + const __m128i shifted1 = _mm_srlv_epi8(permuted1, tables::get_shift1()); + const __m128i shifted2 = _mm_sllv_epi8(permuted2, tables::get_shift2()); + + const __mmask16 spill_mask = _mm_cmpneq_epi8_mask(tables::get_shift2(), _mm_setzero_si128()); + __m128i combined = _mm_or_si128(shifted1, _mm_maskz_mov_epi8(spill_mask, shifted2)); + + constexpr uint32_t shift = 8 - BIT_WIDTH; + combined = _mm_slli_epi8(combined, shift); + if (SIGN_VALUES) { + combined = _mm_srai_epi8(combined, shift); + } else { + combined = _mm_srli_epi8(combined, shift); + } + + return combined; + } + } + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m128i mm_unpack_epi16_avx512vbmi_9to16(const __m128i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = unpack_tables_avx512_16; + + const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute1(), input); + + __m128i shifted = _mm_srlv_epi16(permuted, tables::get_shift1()); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m128i permuted2 = _mm_permutexvar_epi8(tables::get_permute2(), input); + const __m128i shifted2 = _mm_sllv_epi16(permuted2, tables::get_shift2()); + shifted = _mm_or_si128(shifted, shifted2); + } + + constexpr uint32_t shift = 16 - BIT_WIDTH; + shifted = _mm_slli_epi16(shifted, shift); + if (SIGN_VALUES) { + shifted = _mm_srai_epi16(shifted, shift); + } else { + shifted = _mm_srli_epi16(shifted, shift); + } + + return shifted; + } + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m128i mm_unpack_epi32_avx512vbmi_17to24(const __m128i &input) { + using tables = unpack_tables_avx512_24; + + const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute(), input); + + constexpr uint32_t shift = 32 - BIT_WIDTH; + __m128i shifted = _mm_sllv_epi32(permuted, tables::get_shift()); + if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { + shifted = _mm_srai_epi32(shifted, shift); + } else { + shifted = _mm_srli_epi32(shifted, shift); + } + + return shifted; + } + } // namespace m128 + + namespace m256 { + constexpr __mmask32 kAlternateByteMask32 = 0xAAAAAAAAULL; + +__always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i count) { + const __m256i mask = _mm256_set1_epi16(0x00ff); + const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); + const __m256i high_half = _mm256_srlv_epi16(a, _mm256_srli_epi16(count, 8)); + return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); + } + +__always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i count) { + const __m256i mask = _mm256_set1_epi16(0xff00); + const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count)); + const __m256i high_half = _mm256_sllv_epi16(_mm256_and_si256(mask, a), _mm256_srli_epi16(count, 8)); + return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); + } + +__always_inline static __m256i _mm256_slli_epi8(const __m256i a, const int8_t imm8) { + return _mm256_sllv_epi8(a, _mm256_set1_epi8(imm8)); + } + +__always_inline static __m256i _mm256_srli_epi8(const __m256i a, const int8_t imm8) { + const __m256i lo_mask = _mm256_set1_epi16(0x00ff); + const __m256i hi_mask = _mm256_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m256i lo = _mm256_srl_epi16(_mm256_and_si256(a, lo_mask), shift); + const __m256i hi = _mm256_and_si256(_mm256_srl_epi16(a, shift), hi_mask); + + return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); + } + +__always_inline static __m256i _mm256_srai_epi8(const __m256i a, const int8_t imm8) { + const __m256i lo_mask = _mm256_set1_epi16(0x00ff); + const __m256i hi_mask = _mm256_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m256i hi = _mm256_and_si256(_mm256_sra_epi16(a, shift), hi_mask); + + const __m256i lo_as_hi = _mm256_slli_epi16(_mm256_and_si256(a, lo_mask), 8); + const __m256i lo = _mm256_and_si256(_mm256_srli_epi16(_mm256_sra_epi16(lo_as_hi, shift), 8), lo_mask); + + return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m256i mm256_unpack_epi8_avx512vbmi_1to8(const __m256i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + if constexpr (BIT_WIDTH == 1) { + const auto value = static_cast<__mmask32>(_mm_cvtsi128_si64(_mm256_castsi256_si128(input))); + const __m256i source = _mm256_movm_epi8(value); + const __m256i unpacked = _mm256_abs_epi8(source); + return unpacked; + } else if constexpr (BIT_WIDTH == 2) { + __m256i values_shift0 = input; + __m256i values_shift2 = _mm256_srli_epi16(values_shift0, 2); + const __m256i values_shift4 = _mm256_srli_epi16(values_shift0, 4); + const __m256i values_shift6 = _mm256_srli_epi16(values_shift0, 6); + + __m256i interleave_tmp = _mm256_unpacklo_epi8(values_shift0, values_shift2); + values_shift0 = _mm256_unpackhi_epi8(values_shift0, values_shift2); + values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00000000); + + interleave_tmp = _mm256_unpacklo_epi8(values_shift4, values_shift6); + values_shift2 = _mm256_unpackhi_epi8(values_shift4, values_shift6); + values_shift2 = _mm256_shuffle_i32x4(interleave_tmp, values_shift2, 0b00000000); + + interleave_tmp = _mm256_unpacklo_epi16(values_shift0, values_shift2); + values_shift0 = _mm256_unpackhi_epi16(values_shift0, values_shift2); + values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00); + values_shift0 = _mm256_shuffle_i32x4(values_shift0, values_shift0, 0b00); + + values_shift0 = _mm256_and_si256(values_shift0, _mm256_set1_epi16(0x0303)); + + return values_shift0; + } else if constexpr (BIT_WIDTH == 4) { + __m256i values_shift0 = input; + const __m256i values_shift4 = _mm256_srli_epi16(values_shift0, 4); + + __m256i interleave_tmp = _mm256_unpacklo_epi8(values_shift0, values_shift4); + values_shift0 = _mm256_unpackhi_epi8(values_shift0, values_shift4); + values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00); + values_shift0 = _mm256_shuffle_i32x4(values_shift0, values_shift0, 0b00); + + values_shift0 = _mm256_and_si256(values_shift0, _mm256_set1_epi16(0x0F0F)); + + return values_shift0; + } else { + using tables = unpack_tables_avx512_8; + + const __m256i permuted1 = _mm256_permutexvar_epi8(tables::get_permute1(), input); + const __m256i permuted2 = _mm256_permutexvar_epi8(tables::get_permute2(), input); + + const __m256i shifted1 = _mm256_srlv_epi8(permuted1, tables::get_shift1()); + const __m256i shifted2 = _mm256_sllv_epi8(permuted2, tables::get_shift2()); + + const __mmask32 spill_mask = _mm256_cmpneq_epi8_mask(tables::get_shift2(), _mm256_setzero_si256()); + __m256i combined = _mm256_or_si256(shifted1, _mm256_maskz_mov_epi8(spill_mask, shifted2)); + + constexpr uint32_t shift = 8 - BIT_WIDTH; + combined = _mm256_slli_epi8(combined, shift); + if (SIGN_VALUES) { + combined = _mm256_srai_epi8(combined, shift); + } else { + combined = _mm256_srli_epi8(combined, shift); + } + + return combined; + } + } + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m256i mm256_unpack_epi16_avx512vbmi_9to16(const __m256i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = unpack_tables_avx512_16; + + const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute1(), input); + + __m256i shifted = _mm256_srlv_epi16(permuted, tables::get_shift1()); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m256i permuted2 = _mm256_permutexvar_epi8(tables::get_permute2(), input); + const __m256i shifted2 = _mm256_sllv_epi16(permuted2, tables::get_shift2()); + shifted = _mm256_or_si256(shifted, shifted2); + } + + constexpr uint32_t shift = 16 - BIT_WIDTH; + shifted = _mm256_slli_epi16(shifted, shift); + if (SIGN_VALUES) { + shifted = _mm256_srai_epi16(shifted, shift); + } else { + shifted = _mm256_srli_epi16(shifted, shift); + } + + return shifted; + } + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m256i mm256_unpack_epi32_avx512vbmi_17to24(const __m256i &input) { + using tables = unpack_tables_avx512_24; + + const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute(), input); + + constexpr uint32_t shift = 32 - BIT_WIDTH; + __m256i shifted = _mm256_sllv_epi32(permuted, tables::get_shift()); + if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { + shifted = _mm256_srai_epi32(shifted, shift); + } else { + shifted = _mm256_srli_epi32(shifted, shift); + } + + return shifted; + } + } // namespace m256 + + namespace m512 { + constexpr __mmask64 kAlternateByteMask64 = 0xAAAAAAAAAAAAAAAAULL; + +__always_inline static __m512i _mm512_srlv_epi8(const __m512i a, const __m512i count) { + const __m512i mask = _mm512_set1_epi16(0x00ff); + const __m512i low_half = _mm512_srlv_epi16(_mm512_and_si512(mask, a), _mm512_and_si512(mask, count)); + const __m512i high_half = _mm512_srlv_epi16(a, _mm512_srli_epi16(count, 8)); + return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half); + } + +__always_inline static __m512i _mm512_sllv_epi8(const __m512i a, const __m512i count) { + const __m512i mask = _mm512_set1_epi16(0xff00); + const __m512i low_half = _mm512_sllv_epi16(a, _mm512_andnot_si512(mask, count)); + const __m512i high_half = _mm512_sllv_epi16(_mm512_and_si512(mask, a), _mm512_srli_epi16(count, 8)); + return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half); + } + +__always_inline static __m512i _mm512_slli_epi8(const __m512i a, const int8_t imm8) { + return _mm512_sllv_epi8(a, _mm512_set1_epi8(imm8)); + } + +__always_inline static __m512i _mm512_srli_epi8(const __m512i a, const int8_t imm8) { + const __m512i lo_mask = _mm512_set1_epi16(0x00ff); + const __m512i hi_mask = _mm512_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m512i lo = _mm512_srl_epi16(_mm512_and_si512(a, lo_mask), shift); + const __m512i hi = _mm512_and_si512(_mm512_srl_epi16(a, shift), hi_mask); + + return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi); + } + +__always_inline static __m512i _mm512_srai_epi8(const __m512i a, const int8_t imm8) { + const __m512i lo_mask = _mm512_set1_epi16(0x00ff); + const __m512i hi_mask = _mm512_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m512i hi = _mm512_and_si512(_mm512_sra_epi16(a, shift), hi_mask); + + const __m512i lo_as_hi = _mm512_slli_epi16(_mm512_and_si512(a, lo_mask), 8); + const __m512i lo = _mm512_and_si512(_mm512_srli_epi16(_mm512_sra_epi16(lo_as_hi, shift), 8), lo_mask); + + return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi); + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m512i mm512_unpack_epi8_avx512vbmi_1to8(const __m512i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + if constexpr (BIT_WIDTH == 1) { + const auto value = static_cast<__mmask64>(_mm_cvtsi128_si64(_mm512_castsi512_si128(input))); + const __m512i source = _mm512_movm_epi8(value); + const __m512i unpacked = _mm512_abs_epi8(source); + return unpacked; + } else if constexpr (BIT_WIDTH == 2) { + __m512i values_shift0 = input; + __m512i values_shift2 = _mm512_srli_epi16(values_shift0, 2); + const __m512i values_shift4 = _mm512_srli_epi16(values_shift0, 4); + const __m512i values_shift6 = _mm512_srli_epi16(values_shift0, 6); + + __m512i interleave_tmp = _mm512_unpacklo_epi8(values_shift0, values_shift2); + values_shift0 = _mm512_unpackhi_epi8(values_shift0, values_shift2); + values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0b00000000); + + interleave_tmp = _mm512_unpacklo_epi8(values_shift4, values_shift6); + values_shift2 = _mm512_unpackhi_epi8(values_shift4, values_shift6); + values_shift2 = _mm512_shuffle_i32x4(interleave_tmp, values_shift2, 0b00000000); + + interleave_tmp = _mm512_unpacklo_epi16(values_shift0, values_shift2); + values_shift0 = _mm512_unpackhi_epi16(values_shift0, values_shift2); + values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0x88); + values_shift0 = _mm512_shuffle_i32x4(values_shift0, values_shift0, 0xD8); + + values_shift0 = _mm512_and_si512(values_shift0, _mm512_set1_epi16(0x0303)); + + return values_shift0; + } else if constexpr (BIT_WIDTH == 4) { + __m512i values_shift0 = input; + const __m512i values_shift4 = _mm512_srli_epi16(values_shift0, 4); + + __m512i interleave_tmp = _mm512_unpacklo_epi8(values_shift0, values_shift4); + values_shift0 = _mm512_unpackhi_epi8(values_shift0, values_shift4); + values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0x44); + values_shift0 = _mm512_shuffle_i32x4(values_shift0, values_shift0, 0xD8); + + values_shift0 = _mm512_and_si512(values_shift0, _mm512_set1_epi16(0x0F0F)); + + return values_shift0; + } else { + using tables = unpack_tables_avx512_8; + + const __m512i permuted1 = _mm512_permutexvar_epi8(tables::get_permute1(), input); + const __m512i permuted2 = _mm512_permutexvar_epi8(tables::get_permute2(), input); + + const __m512i shifted1 = _mm512_srlv_epi8(permuted1, tables::get_shift1()); + const __m512i shifted2 = _mm512_sllv_epi8(permuted2, tables::get_shift2()); + + const __mmask64 spill_mask = _mm512_cmpneq_epi8_mask(tables::get_shift2(), _mm512_setzero_si512()); + __m512i combined = _mm512_or_si512(shifted1, _mm512_maskz_mov_epi8(spill_mask, shifted2)); + + constexpr uint32_t shift = 8 - BIT_WIDTH; + combined = _mm512_slli_epi8(combined, shift); + if (SIGN_VALUES) { + combined = _mm512_srai_epi8(combined, shift); + } else { + combined = _mm512_srli_epi8(combined, shift); + } + + return combined; + } + } + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m512i mm512_unpack_epi16_avx512vbmi_9to16(const __m512i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = unpack_tables_avx512_16; + + const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute1(), input); + __m512i shifted = _mm512_srlv_epi16(permuted, tables::get_shift1()); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m512i permuted2 = _mm512_permutexvar_epi8(tables::get_permute2(), input); + const __m512i shifted2 = _mm512_sllv_epi16(permuted2, tables::get_shift2()); + shifted = _mm512_or_si512(shifted, shifted2); + } + + constexpr uint32_t shift = 16 - BIT_WIDTH; + shifted = _mm512_slli_epi16(shifted, shift); + if (SIGN_VALUES) { + shifted = _mm512_srai_epi16(shifted, shift); + } else { + shifted = _mm512_srli_epi16(shifted, shift); + } + + return shifted; + } + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m512i mm512_unpack_epi32_avx512vbmi_17to24(const __m512i &input) { + using tables = unpack_tables_avx512_24; + + const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute(), input); + __m512i shifted = _mm512_sllv_epi32(permuted, tables::get_shift()); + + constexpr uint32_t shift = 32 - BIT_WIDTH; + if constexpr (SIGN_VALUES) { + shifted = _mm512_srai_epi32(shifted, shift); + } else { + shifted = _mm512_srli_epi32(shifted, shift); + } + + return shifted; + } + } // namespace m512 +} // namespace pernix::internal + +#endif // PERNIX_AVX512VBMI_UNPACKING_H diff --git a/include/pernix/x86/bmi2/compression.h b/src/internal/pernix/x86/bmi2/bmi2_compression.h similarity index 56% rename from include/pernix/x86/bmi2/compression.h rename to src/internal/pernix/x86/bmi2/bmi2_compression.h index ca4b976..2f2d362 100644 --- a/include/pernix/x86/bmi2/compression.h +++ b/src/internal/pernix/x86/bmi2/bmi2_compression.h @@ -1,8 +1,8 @@ #ifndef PERNIX_BMI2_COMPRESSION_H #define PERNIX_BMI2_COMPRESSION_H -#include -#include +#include +#include #include #include @@ -12,11 +12,11 @@ namespace pernix { namespace internal { /** - * @brief Build the masks and shift constants used by the BMI2 packers. - * - * @tparam BIT_WIDTH bit width per packed value. - * @return std::tuple mask tuple used by the BMI2 helpers. - */ +* @brief Build the masks and shift constants used by the BMI2 packers. +* +* @tparam BIT_WIDTH bit width per packed value. +* @return std::tuple mask tuple used by the BMI2 helpers. +*/ template requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32) static constexpr std::tuple pack_avx2_bmi2_constants() { @@ -42,12 +42,12 @@ static constexpr std::tuple pack_avx2_bm } /** - * @brief Pack four 32-bit values with BMI2 extract instructions. - * - * @tparam BIT_WIDTH bit width per packed value. - * @param input SIMD register containing four quantized values. - * @return __m128i packed bitstream in the low bytes of the result. - */ +* @brief Pack four 32-bit values with BMI2 extract instructions. +* +* @tparam BIT_WIDTH bit width per packed value. +* @param input SIMD register containing four quantized values. +* @return __m128i packed bitstream in the low bytes of the result. +*/ template requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32) static inline auto mm_pack_epi32_bmi2(const __m128i& input) -> __m128i { @@ -73,12 +73,12 @@ static inline auto mm_pack_epi32_bmi2(const __m128i& input) -> __m128i { } /** - * @brief Pack eight 32-bit values with BMI2 extract instructions. - * - * @tparam BIT_WIDTH bit width per packed value. - * @param input SIMD register containing eight quantized values. - * @return __m256i packed bitstream in the low bytes of the result. - */ +* @brief Pack eight 32-bit values with BMI2 extract instructions. +* +* @tparam BIT_WIDTH bit width per packed value. +* @param input SIMD register containing eight quantized values. +* @return __m256i packed bitstream in the low bytes of the result. +*/ template requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) static inline auto mm256_pack_epi32_bmi2(const __m256i& input) -> __m256i { @@ -144,26 +144,31 @@ static inline auto mm256_pack_epi32_bmi2(const __m256i& input) -> __m256i { append_bits(x2, 2 * chunk_bits); append_bits(x3, 3 * chunk_bits); - return _mm256_setr_epi64x(static_cast(out0), static_cast(out1), static_cast(out2), 0); + return _mm256_setr_epi64x(static_cast(out0), static_cast(out1), + static_cast(out2), 0); } } } // namespace internal /** - * @brief Compress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -template +* @brief Compress a single 512-bit block using AVX2 and BMI2 instructions. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* +* @param input pointer to the start of the input float values. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where compressed bytes will be stored. +* @return int status code (0 for success). +* +* @note This function requires AVX2 and BMI2 support. +*/ +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_bmi2(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { +int mm256_compress_block_bmi2(const void* __restrict__ input_ptr, const float_t scale, + void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; @@ -176,7 +181,7 @@ int mm256_compress_block_bmi2(const float_t* __restrict__ input, const float_t s const __m256 source = _mm256_loadu_ps(input); const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v); const __m256i packed_input = internal::mm256_clamp_signed_epi32(quantized); - const __m256i packed = internal::mm256_pack_epi32_bmi2(packed_input); + const __m256i packed = internal::mm256_pack_epi32_bmi2(packed_input); std::memcpy(output, &packed, BIT_WIDTH); input += 8; output += BIT_WIDTH; @@ -187,7 +192,8 @@ int mm256_compress_block_bmi2(const float_t* __restrict__ input, const float_t s #pragma GCC unroll 8 for (uint32_t i = 0; i < remaining; i++) { block_values[i] = - static_cast(internal::clamp_signed_quantized(internal::quantize_ps_epi32(input[i], scale))); + static_cast(internal::clamp_signed_quantized( + internal::quantize_ps_epi32(input[i], scale))); } internal::pack_epi32_fallback(block_values, output); @@ -197,19 +203,23 @@ int mm256_compress_block_bmi2(const float_t* __restrict__ input, const float_t s } /** - * @brief Compress a single block of double values using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -template +* @brief Compress a single block of double values using AVX2 and BMI2 instructions. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @param input pointer to the start of the input double values. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where compressed bytes will be stored. +* @return int status code (0 for success). +* +* @note This function requires AVX2 and BMI2 support. +*/ +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_bmi2(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { +int mm256_compress_block_bmi2(const void* __restrict__ input_ptr, const double_t scale, + void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; @@ -225,7 +235,8 @@ int mm256_compress_block_bmi2(const double_t* __restrict__ input, const double_t const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v); __m256i combined = _mm256_castsi128_si256(quantized1); combined = _mm256_inserti128_si256(combined, quantized2, 1); - const __m256i packed = internal::mm256_pack_epi32_bmi2(internal::mm256_clamp_signed_epi32(combined)); + const __m256i packed = internal::mm256_pack_epi32_bmi2( + internal::mm256_clamp_signed_epi32(combined)); std::memcpy(output, &packed, BIT_WIDTH); input += 8; output += BIT_WIDTH; @@ -236,7 +247,8 @@ int mm256_compress_block_bmi2(const double_t* __restrict__ input, const double_t #pragma GCC unroll 8 for (uint32_t i = 0; i < remaining; i++) { block_values[i] = - static_cast(internal::clamp_signed_quantized(internal::quantize_pd_epi64(input[i], scale))); + static_cast(internal::clamp_signed_quantized( + internal::quantize_pd_epi64(input[i], scale))); } internal::pack_epi32_fallback(block_values, output); @@ -245,22 +257,25 @@ int mm256_compress_block_bmi2(const double_t* __restrict__ input, const double_t } /** - * @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -template +* @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* +* @param input pointer to the start of the input float values. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where compressed bytes will be stored. +* @param blocks number of 512-bit blocks to compress. +* @return int status code (0 for success). +* +* @note This function requires AVX2 and BMI2 support. +*/ +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_bmi2(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { +int mm256_compress_blocks_bmi2(const void* __restrict__ input_ptr, const float_t scale, + void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const float_t* block_input = input; uint8_t* block_output = output; @@ -274,21 +289,24 @@ int mm256_compress_blocks_bmi2(const float_t* __restrict__ input, const float_t } /** - * @brief Compress multiple blocks of double values using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -template +* @brief Compress multiple blocks of double values using AVX2 and BMI2 instructions. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @param input pointer to the start of the input double values. +* @param scale scaling factor used during quantization. +* @param output pointer to the output buffer where compressed bytes will be stored. +* @param blocks number of blocks to compress. +* @return int status code (0 for success). +* +* @note This function requires AVX2 and BMI2 support. +*/ +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_bmi2(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { +int mm256_compress_blocks_bmi2(const void* __restrict__ input_ptr, const double_t scale, + void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const double_t* block_input = input; uint8_t* block_output = output; @@ -302,70 +320,4 @@ int mm256_compress_blocks_bmi2(const double_t* __restrict__ input, const double_ } } // namespace pernix -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Compress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_block_bmi2(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_block_f64_bmi2(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_blocks_bmi2(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_blocks_f64_bmi2(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - #endif // PERNIX_BMI2_COMPRESSION_H diff --git a/include/pernix/x86/bmi2/decompression.h b/src/internal/pernix/x86/bmi2/bmi2_decompression.h similarity index 80% rename from include/pernix/x86/bmi2/decompression.h rename to src/internal/pernix/x86/bmi2/bmi2_decompression.h index 443d673..414d6e0 100644 --- a/include/pernix/x86/bmi2/decompression.h +++ b/src/internal/pernix/x86/bmi2/bmi2_decompression.h @@ -1,7 +1,7 @@ #ifndef PERNIX_BMI2_DECOMPRESSION_H #define PERNIX_BMI2_DECOMPRESSION_H -#include +#include #include #include @@ -205,9 +205,12 @@ __m256i mm256_unpack_epi32_bmi2(const uint8_t* __restrict__ input) { * * @note This function requires AVX2 and BMI2 support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_bmi2(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { +int mm256_decompress_block_bmi2(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; @@ -246,9 +249,12 @@ int mm256_decompress_block_bmi2(const uint8_t* __restrict__ input, const float_t * * @note This function requires AVX2 and BMI2 support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_bmi2(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { +int mm256_decompress_block_bmi2(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; @@ -294,10 +300,13 @@ int mm256_decompress_block_bmi2(const uint8_t* __restrict__ input, const double_ * * @note This function requires AVX2 and BMI2 support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_bmi2(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, +int mm256_decompress_blocks_bmi2(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const uint8_t* block_input = input; float_t* block_output = output; @@ -325,10 +334,13 @@ int mm256_decompress_blocks_bmi2(const uint8_t* __restrict__ input, const float_ * * @note This function requires AVX2 and BMI2 support. */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_bmi2(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, +int mm256_decompress_blocks_bmi2(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, const uint32_t blocks) { + const auto* input = static_cast(input_ptr); + auto* output = static_cast(output_ptr); + const uint8_t* block_input = input; double_t* block_output = output; @@ -342,70 +354,4 @@ int mm256_decompress_blocks_bmi2(const uint8_t* __restrict__ input, const double } } // namespace pernix -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Decompress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_decompress_block_bmi2(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_decompress_block_f64_bmi2(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief Decompress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_decompress_blocks_bmi2(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Decompress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_decompress_blocks_f64_bmi2(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - #endif // PERNIX_BMI2_DECOMPRESSION_H diff --git a/include/pernix/x86/utils.h b/src/internal/pernix/x86/utils.h similarity index 100% rename from include/pernix/x86/utils.h rename to src/internal/pernix/x86/utils.h diff --git a/src/pernix.cpp b/src/pernix.cpp index 94d9d14..8b806a6 100644 --- a/src/pernix.cpp +++ b/src/pernix.cpp @@ -1,232 +1,165 @@ #include +#include -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -// Use the best available implementation based on detected CPU features at compile time -#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) -#ifdef PERNIX_AVX512_VBMI_ENABLED -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return mm512_compress_block_avx512vbmi(bit_width, input, scale, output); +namespace { +bool is_valid_block_size(uint32_t block_size) { + return block_size == 64 || block_size == 128 || block_size == 256 || block_size == 512 || block_size == 1024; } - -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return mm512_compress_block_f64_avx512vbmi(bit_width, input, scale, output); } -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return mm512_compress_blocks_avx512vbmi(bit_width, input, scale, output, blocks); -} +extern "C" { +pernix_status pernix_compress_block_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + float scale, void* output) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return mm512_compress_blocks_f64_avx512vbmi(bit_width, input, scale, output, blocks); -} + const auto kernel = pernix::internal::select_compress_block_f32(static_cast(backend), bit_width, block_size); -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return mm512_decompress_block_avx512vbmi(bit_width, input, scale, output); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return mm512_decompress_block_f64_avx512vbmi(bit_width, input, scale, output); + return static_cast(kernel.func(input, scale, output)); } -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - return mm512_decompress_blocks_avx512vbmi(bit_width, input, scale, output, blocks); -} +pernix_status pernix_compress_blocks_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + float scale, void* output, uint32_t blocks) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return mm512_decompress_blocks_f64_avx512vbmi(bit_width, input, scale, output, blocks); -} -#else -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return mm256_compress_block_avx2(bit_width, input, scale, output); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return mm256_compress_block_f64_avx2(bit_width, input, scale, output); -} + const auto kernel = pernix::internal::select_compress_blocks_f32(static_cast(backend), bit_width, block_size); -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return mm256_compress_blocks_avx2(bit_width, input, scale, output, blocks); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return mm256_compress_blocks_f64_avx2(bit_width, input, scale, output, blocks); + return static_cast(kernel.func(input, scale, output, blocks)); } -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return mm256_decompress_block_avx2(bit_width, input, scale, output); -} +pernix_status pernix_decompress_block_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + float scale, void* output, bool sign_values) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return mm256_decompress_block_f64_avx2(bit_width, input, scale, output); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - return mm256_decompress_blocks_avx2(bit_width, input, scale, output, blocks); -} + const auto kernel = pernix::internal::select_decompress_block_f32(static_cast(backend), bit_width, block_size, + sign_values); -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return mm256_decompress_blocks_f64_avx2(bit_width, input, scale, output, blocks); -} -#endif -#elif defined(PERNIX_BACKEND_ARM64_NEON) -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return neon_compress_block(bit_width, input, scale, output); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return neon_compress_block_f64(bit_width, input, scale, output); + return static_cast(kernel.func(input, scale, output)); } -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return neon_compress_blocks(bit_width, input, scale, output, blocks); -} +pernix_status pernix_decompress_blocks_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + float scale, void* output, uint32_t blocks, bool sign_values) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return neon_compress_blocks_f64(bit_width, input, scale, output, blocks); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return neon_decompress_block(bit_width, input, scale, output); -} + const auto kernel = pernix::internal::select_decompress_blocks_f32(static_cast(backend), bit_width, block_size, + sign_values); -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return neon_decompress_block_f64(bit_width, input, scale, output); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - return neon_decompress_blocks(bit_width, input, scale, output, blocks); + return static_cast(kernel.func(input, scale, output, blocks)); } -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return neon_decompress_blocks_f64(bit_width, input, scale, output, blocks); -} -#elif defined(PERNIX_BACKEND_ARM64_SVE) -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return sve_compress_block(bit_width, input, scale, output); -} +pernix_status pernix_compress_block_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + double scale, void* output) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return sve_compress_block_f64(bit_width, input, scale, output); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return sve_compress_blocks(bit_width, input, scale, output, blocks); -} + const auto kernel = pernix::internal::select_compress_block_f64(static_cast(backend), bit_width, block_size); -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return sve_compress_blocks_f64(bit_width, input, scale, output, blocks); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return sve_decompress_block(bit_width, input, scale, output); + return static_cast(kernel.func(input, scale, output)); } -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return sve_decompress_block_f64(bit_width, input, scale, output); -} +pernix_status pernix_compress_blocks_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + double scale, void* output, uint32_t blocks) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - return sve_decompress_blocks(bit_width, input, scale, output, blocks); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return sve_decompress_blocks_f64(bit_width, input, scale, output, blocks); -} -#elif defined(PERNIX_BACKEND_ARM64_SVE2) -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return sve2_compress_block(bit_width, input, scale, output); -} - -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return sve2_compress_block_f64(bit_width, input, scale, output); -} - -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return sve2_compress_blocks(bit_width, input, scale, output, blocks); -} + const auto kernel = pernix::internal::select_compress_blocks_f64(static_cast(backend), bit_width, block_size); -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return sve2_compress_blocks_f64(bit_width, input, scale, output, blocks); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return sve2_decompress_block(bit_width, input, scale, output); + return static_cast(kernel.func(input, scale, output, blocks)); } -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return sve2_decompress_block_f64(bit_width, input, scale, output); -} +pernix_status pernix_decompress_block_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + double scale, void* output, bool sign_values) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - return sve2_decompress_blocks(bit_width, input, scale, output, blocks); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return sve2_decompress_blocks_f64(bit_width, input, scale, output, blocks); -} -#else -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return compress_block_fallback(bit_width, input, scale, output); -} + const auto kernel = pernix::internal::select_decompress_block_f64(static_cast(backend), bit_width, block_size, + sign_values); -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return compress_block_fallback_f64(bit_width, input, scale, output); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return compress_blocks_fallback(bit_width, input, scale, output, blocks); + return static_cast(kernel.func(input, scale, output)); } -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return compress_blocks_fallback_f64(bit_width, input, scale, output, blocks); -} +pernix_status pernix_decompress_blocks_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, + double scale, void* output, uint32_t blocks, bool sign_values) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return decompress_block_fallback(bit_width, input, scale, output); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return decompress_block_fallback_f64(bit_width, input, scale, output); -} + const auto kernel = pernix::internal::select_decompress_blocks_f64(static_cast(backend), bit_width, block_size, + sign_values); -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - return decompress_blocks_fallback(bit_width, input, scale, output, blocks); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return decompress_blocks_fallback_f64(bit_width, input, scale, output, blocks); + return static_cast(kernel.func(input, scale, output, blocks)); } -#endif - -#ifdef __cplusplus } -} // namespace pernix -#endif // __cplusplus diff --git a/src/x86/avx2/avx2_compression.cpp b/src/x86/avx2/avx2_compression.cpp new file mode 100644 index 0000000..1a7f2c0 --- /dev/null +++ b/src/x86/avx2/avx2_compression.cpp @@ -0,0 +1,193 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_COMPRESS_BLOCK_32(N, BS) \ +case N: return Kernel("avx2", &mm256_compress_block_avx2) + +#define PERNIX_CASE_COMPRESS_BLOCKS_32(N, BS) \ +case N: return Kernel("avx2", &mm256_compress_blocks_avx2) + +#define PERNIX_CASE_COMPRESS_BLOCK_64(N, BS) \ +case N: return Kernel("avx2", &mm256_compress_block_avx2) + +#define PERNIX_CASE_COMPRESS_BLOCKS_64(N, BS) \ +case N: return Kernel("avx2", &mm256_compress_blocks_avx2) + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(24, BS); \ + default: return {"avx2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(24, BS); \ + default: return {"avx2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(24, BS); \ + default: return {"avx2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(24, BS); \ + default: return {"avx2", nullptr}; \ + } + +Kernel select_avx2_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"avx2", nullptr}; + } +} + +Kernel select_avx2_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"avx2", nullptr}; + } +} + +Kernel select_avx2_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"avx2", nullptr}; + } +} + +Kernel select_avx2_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"avx2", nullptr}; + } +} + +#undef PERNIX_CASE_COMPRESS_BLOCK_32 +#undef PERNIX_CASE_COMPRESS_BLOCKS_32 +#undef PERNIX_CASE_COMPRESS_BLOCK_64 +#undef PERNIX_CASE_COMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64 +} diff --git a/src/x86/avx2/avx2_decompression.cpp b/src/x86/avx2/avx2_decompression.cpp new file mode 100644 index 0000000..43558ae --- /dev/null +++ b/src/x86/avx2/avx2_decompression.cpp @@ -0,0 +1,201 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx2", &mm256_decompress_block_avx2); \ + return Kernel("avx2", &mm256_decompress_block_avx2) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx2", &mm256_decompress_blocks_avx2); \ + return Kernel("avx2", &mm256_decompress_blocks_avx2) + + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx2", &mm256_decompress_block_avx2); \ + return Kernel("avx2", &mm256_decompress_block_avx2) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx2", &mm256_decompress_blocks_avx2); \ + return Kernel("avx2", &mm256_decompress_blocks_avx2) +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"avx2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"avx2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"avx2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"avx2", nullptr}; \ + } \ + break + +Kernel select_avx2_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"avx2", nullptr}; + } +} + +Kernel select_avx2_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"avx2", nullptr}; + } +} + +Kernel select_avx2_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"avx2", nullptr}; + } +} + +Kernel select_avx2_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"avx2", nullptr}; + } +} + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/x86/avx2/compression.cpp b/src/x86/avx2/compression.cpp deleted file mode 100644 index 231fc75..0000000 --- a/src/x86/avx2/compression.cpp +++ /dev/null @@ -1,154 +0,0 @@ -#include -#include - -#ifdef PERNIX_AVX2_ENABLED - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_COMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm256_compress_block_avx2(input, scale, output); - -#define PERNIX_COMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm256_compress_blocks_avx2(input, scale, output, blocks); - -int mm256_compress_block_avx2(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_compress_block_f64_avx2(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_compress_blocks_avx2(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm256_compress_blocks_f64_avx2(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED diff --git a/src/x86/avx2/decompression.cpp b/src/x86/avx2/decompression.cpp deleted file mode 100644 index 47c473a..0000000 --- a/src/x86/avx2/decompression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#ifdef PERNIX_AVX2_ENABLED -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm256_decompress_block_avx2(input, scale, output); - -#define PERNIX_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm256_decompress_blocks_avx2(input, scale, output, blocks); - -int mm256_decompress_block_avx2(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_block_f64_avx2(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_blocks_avx2(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_blocks_f64_avx2(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED \ No newline at end of file diff --git a/src/x86/avx512vbmi/avx512vbmi_compression.cpp b/src/x86/avx512vbmi/avx512vbmi_compression.cpp new file mode 100644 index 0000000..f1c36a1 --- /dev/null +++ b/src/x86/avx512vbmi/avx512vbmi_compression.cpp @@ -0,0 +1,193 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_COMPRESS_BLOCK_32(N, BS) \ +case N: return Kernel("avx512vbmi", &mm512_compress_block_avx512vbmi) + +#define PERNIX_CASE_COMPRESS_BLOCKS_32(N, BS) \ +case N: return Kernel("avx512vbmi", &mm512_compress_blocks_avx512vbmi) + +#define PERNIX_CASE_COMPRESS_BLOCK_64(N, BS) \ +case N: return Kernel("avx512vbmi", &mm512_compress_block_avx512vbmi) + +#define PERNIX_CASE_COMPRESS_BLOCKS_64(N, BS) \ +case N: return Kernel("avx512vbmi", &mm512_compress_blocks_avx512vbmi) + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } + +Kernel select_avx512vbmi_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"avx512vbmi", nullptr}; + } +} + +Kernel select_avx512vbmi_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"avx512vbmi", nullptr}; + } +} + +Kernel select_avx512vbmi_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"avx512vbmi", nullptr}; + } +} + +Kernel select_avx512vbmi_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"avx512vbmi", nullptr}; + } +} + +#undef PERNIX_CASE_COMPRESS_BLOCK_32 +#undef PERNIX_CASE_COMPRESS_BLOCKS_32 +#undef PERNIX_CASE_COMPRESS_BLOCK_64 +#undef PERNIX_CASE_COMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64 +} diff --git a/src/x86/avx512vbmi/avx512vbmi_decompression.cpp b/src/x86/avx512vbmi/avx512vbmi_decompression.cpp new file mode 100644 index 0000000..43f45fc --- /dev/null +++ b/src/x86/avx512vbmi/avx512vbmi_decompression.cpp @@ -0,0 +1,201 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx512vbmi", &mm512_decompress_block_avx512vbmi); \ + return Kernel("avx512vbmi", &mm512_decompress_block_avx512vbmi) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx512vbmi", &mm512_decompress_blocks_avx512vbmi); \ + return Kernel("avx512vbmi", &mm512_decompress_blocks_avx512vbmi) + + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx512vbmi", &mm512_decompress_block_avx512vbmi); \ + return Kernel("avx512vbmi", &mm512_decompress_block_avx512vbmi) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx512vbmi", &mm512_decompress_blocks_avx512vbmi); \ + return Kernel("avx512vbmi", &mm512_decompress_blocks_avx512vbmi) +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } \ + break + +Kernel select_avx512vbmi_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"avx512vbmi", nullptr}; + } +} + +Kernel select_avx512vbmi_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"avx512vbmi", nullptr}; + } +} + +Kernel select_avx512vbmi_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"avx512vbmi", nullptr}; + } +} + +Kernel select_avx512vbmi_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"avx512vbmi", nullptr}; + } +} + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/x86/avx512vbmi/compression.cpp b/src/x86/avx512vbmi/compression.cpp deleted file mode 100644 index bdc6635..0000000 --- a/src/x86/avx512vbmi/compression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#if defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_AVX512_VBMI_ENABLED) -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_COMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm512_compress_block_avx512vbmi(input, scale, output); - -#define PERNIX_COMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm512_compress_blocks_avx512vbmi(input, scale, output, blocks); - -int mm512_compress_block_avx512vbmi(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm512_compress_block_f64_avx512vbmi(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm512_compress_blocks_avx512vbmi(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm512_compress_blocks_f64_avx512vbmi(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_AVX512_VBMI_ENABLED) \ No newline at end of file diff --git a/src/x86/avx512vbmi/decompression.cpp b/src/x86/avx512vbmi/decompression.cpp deleted file mode 100644 index 1273f6a..0000000 --- a/src/x86/avx512vbmi/decompression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#if defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_AVX512_VBMI_ENABLED) -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm512_decompress_block_avx512vbmi(input, scale, output); - -#define PERNIX_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm512_decompress_blocks_avx512vbmi(input, scale, output, blocks); - -int mm512_decompress_block_avx512vbmi(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm512_decompress_block_f64_avx512vbmi(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm512_decompress_blocks_avx512vbmi(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm512_decompress_blocks_f64_avx512vbmi(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED && PERNIX_AVX512_VBMI_ENABLED \ No newline at end of file diff --git a/src/x86/bmi2/bmi2_compression.cpp b/src/x86/bmi2/bmi2_compression.cpp new file mode 100644 index 0000000..ea01e04 --- /dev/null +++ b/src/x86/bmi2/bmi2_compression.cpp @@ -0,0 +1,193 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_COMPRESS_BLOCK_32(N, BS) \ +case N: return Kernel("bmi2", &mm256_compress_block_bmi2) + +#define PERNIX_CASE_COMPRESS_BLOCKS_32(N, BS) \ +case N: return Kernel("bmi2", &mm256_compress_blocks_bmi2) + +#define PERNIX_CASE_COMPRESS_BLOCK_64(N, BS) \ +case N: return Kernel("bmi2", &mm256_compress_block_bmi2) + +#define PERNIX_CASE_COMPRESS_BLOCKS_64(N, BS) \ +case N: return Kernel("bmi2", &mm256_compress_blocks_bmi2) + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(24, BS); \ + default: return {"bmi2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(24, BS); \ + default: return {"bmi2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(24, BS); \ + default: return {"bmi2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(24, BS); \ + default: return {"bmi2", nullptr}; \ + } + +Kernel select_bmi2_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"bmi2", nullptr}; + } +} + +Kernel select_bmi2_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"bmi2", nullptr}; + } +} + +Kernel select_bmi2_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"bmi2", nullptr}; + } +} + +Kernel select_bmi2_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"bmi2", nullptr}; + } +} + +#undef PERNIX_CASE_COMPRESS_BLOCK_32 +#undef PERNIX_CASE_COMPRESS_BLOCKS_32 +#undef PERNIX_CASE_COMPRESS_BLOCK_64 +#undef PERNIX_CASE_COMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64 +} diff --git a/src/x86/bmi2/bmi2_decompression.cpp b/src/x86/bmi2/bmi2_decompression.cpp new file mode 100644 index 0000000..e829c24 --- /dev/null +++ b/src/x86/bmi2/bmi2_decompression.cpp @@ -0,0 +1,201 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("bmi2", &mm256_decompress_block_bmi2); \ + return Kernel("bmi2", &mm256_decompress_block_bmi2) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("bmi2", &mm256_decompress_blocks_bmi2); \ + return Kernel("bmi2", &mm256_decompress_blocks_bmi2) + + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("bmi2", &mm256_decompress_block_bmi2); \ + return Kernel("bmi2", &mm256_decompress_block_bmi2) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("bmi2", &mm256_decompress_blocks_bmi2); \ + return Kernel("bmi2", &mm256_decompress_blocks_bmi2) +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"bmi2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"bmi2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"bmi2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"bmi2", nullptr}; \ + } \ + break + +Kernel select_bmi2_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"bmi2", nullptr}; + } +} + +Kernel select_bmi2_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"bmi2", nullptr}; + } +} + +Kernel select_bmi2_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"bmi2", nullptr}; + } +} + +Kernel select_bmi2_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"bmi2", nullptr}; + } +} + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/x86/bmi2/compression.cpp b/src/x86/bmi2/compression.cpp deleted file mode 100644 index 79c43d8..0000000 --- a/src/x86/bmi2/compression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#if defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_BMI2_ENABLED) -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_COMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm256_compress_block_bmi2(input, scale, output); - -#define PERNIX_COMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm256_compress_blocks_bmi2(input, scale, output, blocks); - -int mm256_compress_block_bmi2(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_compress_block_f64_bmi2(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_compress_blocks_bmi2(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm256_compress_blocks_f64_bmi2(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED && PERNIX_BMI2_ENABLED diff --git a/src/x86/bmi2/decompression.cpp b/src/x86/bmi2/decompression.cpp deleted file mode 100644 index c9cbad7..0000000 --- a/src/x86/bmi2/decompression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#if defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_BMI2_ENABLED) -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm256_decompress_block_bmi2(input, scale, output); - -#define PERNIX_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm256_decompress_blocks_bmi2(input, scale, output, blocks); - -int mm256_decompress_block_bmi2(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_block_f64_bmi2(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_blocks_bmi2(const uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_blocks_f64_bmi2(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED && PERNIX_BMI2_ENABLED \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b5ffe57..62bb90e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,56 +1,7 @@ find_package(PkgConfig) pkg_search_module(GTEST REQUIRED gtest) -include(CheckCXXCompilerFlag) -file(GLOB - PERNIX_ROOT_TEST_SOURCES - CONFIGURE_DEPENDS - ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp -) - -file(GLOB_RECURSE - PERNIX_FALLBACK_TEST_SOURCES - CONFIGURE_DEPENDS - ${CMAKE_CURRENT_SOURCE_DIR}/fallback/*.cpp -) - -set(SOURCE_FILES ${PERNIX_ROOT_TEST_SOURCES} ${PERNIX_FALLBACK_TEST_SOURCES}) - -if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") - file(GLOB_RECURSE - PERNIX_X86_TEST_SOURCES - CONFIGURE_DEPENDS - ${CMAKE_CURRENT_SOURCE_DIR}/x86/*.cpp - ) - list(APPEND SOURCE_FILES ${PERNIX_X86_TEST_SOURCES}) -elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") - file(GLOB_RECURSE - PERNIX_ARM64_NEON_TEST_SOURCES - CONFIGURE_DEPENDS - ${CMAKE_CURRENT_SOURCE_DIR}/arm64/neon/*.cpp - ) - list(APPEND SOURCE_FILES ${PERNIX_ARM64_NEON_TEST_SOURCES}) -elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") - file(GLOB_RECURSE - PERNIX_ARM64_SVE_TEST_SOURCES - CONFIGURE_DEPENDS - ${CMAKE_CURRENT_SOURCE_DIR}/arm64/sve/*.cpp - ) - list(APPEND SOURCE_FILES ${PERNIX_ARM64_SVE_TEST_SOURCES}) -elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") - file(GLOB_RECURSE - PERNIX_ARM64_SVE2_TEST_SOURCES - CONFIGURE_DEPENDS - ${CMAKE_CURRENT_SOURCE_DIR}/arm64/sve2/*.cpp - ) - list(APPEND SOURCE_FILES ${PERNIX_ARM64_SVE2_TEST_SOURCES}) -endif () - -file(GLOB_RECURSE - HEADER_FILES - CONFIGURE_DEPENDS - ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h -) +file(GLOB PERNIX_TEST_SOURCE_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) set(PERNIX_TEST_BLOCK_SIZES 64 128 256 512 CACHE STRING "Block sizes to build as separate test executables") set(PERNIX_TEST_TARGETS) @@ -60,7 +11,7 @@ foreach (BLOCK_SIZE IN LISTS PERNIX_TEST_BLOCK_SIZES) list(APPEND PERNIX_TEST_TARGETS ${TEST_TARGET}) add_executable(${TEST_TARGET}) - target_sources(${TEST_TARGET} PRIVATE ${SOURCE_FILES} ${HEADER_FILES}) + target_sources(${TEST_TARGET} PRIVATE ${PERNIX_TEST_SOURCE_FILES}) target_link_libraries(${TEST_TARGET} PRIVATE pernix ${GTEST_LDFLAGS}) target_compile_features(${TEST_TARGET} PRIVATE cxx_std_20) target_compile_options(${TEST_TARGET} PRIVATE ${GTEST_CFLAGS}) @@ -70,11 +21,6 @@ endforeach () add_custom_target(pernix_tests DEPENDS ${PERNIX_TEST_TARGETS}) -# check_cxx_compiler_flag(-O3 HAS_OPTIMIZE3_FLAG) -# if (HAS_OPTIMIZE3_FLAG) -# target_compile_options(pernix_tests PRIVATE -O3) -# endif () - include(GoogleTest) foreach (BLOCK_SIZE IN LISTS PERNIX_TEST_BLOCK_SIZES) gtest_discover_tests(pernix_tests_${BLOCK_SIZE} @@ -88,13 +34,11 @@ endforeach () if (PERNIX_ENABLE_CODE_COVERAGE) message(STATUS "Code coverage enabled for tests") - # set compile and link options for code coverage foreach (TEST_TARGET IN LISTS PERNIX_TEST_TARGETS) target_compile_options(${TEST_TARGET} PRIVATE -g -O0 --coverage) target_link_options(${TEST_TARGET} PRIVATE --coverage) endforeach () - # find lcov and genhtml find_program(LCOV_PROGRAM lcov) find_program(GENHTML_PROGRAM genhtml) if (NOT LCOV_PROGRAM OR NOT GENHTML_PROGRAM) diff --git a/tests/arm64/neon/decompression_tests.cpp b/tests/arm64/neon/decompression_tests.cpp deleted file mode 100644 index 69229be..0000000 --- a/tests/arm64/neon/decompression_tests.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include -#include - -#ifdef PERNIX_BACKEND_ARM64_NEON - -using namespace pernix::arm64::neon; - -TYPED_TEST(DecompressionTest, NeonDecompressBlock) { - std::vector > decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - neon_decompress_block( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, NeonDecompressBlock) { - std::vector > decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - neon_decompress_block( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -#endif \ No newline at end of file diff --git a/tests/arm64/sve/.gitkeep b/tests/arm64/sve/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/tests/arm64/sve2/decompression_tests.cpp b/tests/arm64/sve2/decompression_tests.cpp deleted file mode 100644 index 82cb11f..0000000 --- a/tests/arm64/sve2/decompression_tests.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include -#include - -#ifdef PERNIX_BACKEND_ARM64_SVE2 - -using namespace pernix::arm64::sve2; - -TYPED_TEST(DecompressionTest, SVE2DecompressBlock) { - std::vector > decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - sve2_decompress_block( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, SVE2DecompressBlock) { - std::vector > decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - sve2_decompress_block( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -#endif \ No newline at end of file diff --git a/tests/fallback/compression_tests.cpp b/tests/fallback/compression_tests.cpp deleted file mode 100644 index 78249d7..0000000 --- a/tests/fallback/compression_tests.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include -#include - -TYPED_TEST(CompressionTest, FallbackCompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::compress_block_fallback( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectCompressedBlockEqualsReference(*this, compressedData[block], block); - } -} - -TYPED_TEST(CompressionTest64, FallbackCompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::compress_block_fallback( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectCompressedBlockEqualsReference(*this, compressedData[block], block); - } -} diff --git a/tests/fallback/decompression_tests.cpp b/tests/fallback/decompression_tests.cpp deleted file mode 100644 index 08c26c4..0000000 --- a/tests/fallback/decompression_tests.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include -#include - -TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::decompress_block_fallback( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, FallbackDecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::decompress_block_fallback( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} diff --git a/tests/fallback/edge_tests.cpp b/tests/fallback/edge_tests.cpp deleted file mode 100644 index 2bdee5a..0000000 --- a/tests/fallback/edge_tests.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include -#include - -#include - -#include -#include -#include -#include - -TEST(FallbackDecompressionEdgeTest, SignExtensionIsWellDefinedForNegativeValues) { - const std::array input{0x08}; - std::array output{}; - - ASSERT_EQ(pernix::decompress_block_fallback<4>(input.data(), 1.0F, output.data()), 0); - - EXPECT_EQ(output[0], -8.0F); -} - -TEST(FallbackCompressionEdgeTest, ClearsUnusedPaddingBytes) { - std::array input{}; - std::array output{}; - output.fill(0xA5); - - ASSERT_EQ(pernix::compress_block_fallback<24>(input.data(), 1.0F, output.data()), 0); - - EXPECT_EQ(output[63], 0); -} - -TEST(FallbackCompressionEdgeTest, ClampsNonFiniteAndOutOfRangeBeforeNarrowing) { - std::array input{}; - input[0] = std::numeric_limits::infinity(); - input[1] = -std::numeric_limits::infinity(); - input[2] = std::numeric_limits::quiet_NaN(); - std::array compressed{}; - std::array restored{}; - - ASSERT_EQ(pernix::compress_block_fallback<4>(input.data(), 1.0F, compressed.data()), 0); - ASSERT_EQ(pernix::decompress_block_fallback<4>(compressed.data(), 1.0F, restored.data()), 0); - - EXPECT_EQ(restored[0], 7.0F); - EXPECT_EQ(restored[1], -8.0F); - EXPECT_EQ(restored[2], 0.0F); -} diff --git a/tests/fallback_tests.cpp b/tests/fallback_tests.cpp new file mode 100644 index 0000000..5e2b4ef --- /dev/null +++ b/tests/fallback_tests.cpp @@ -0,0 +1,316 @@ +#include + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Fallback compress: verify byte-exact match against the reference +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, FallbackCompressBlock) { + std::vector> compressed(this->testSet.numberOfBlocks); + + for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + compressed[b].resize(TestFixture::BlockSize); + + const auto status = pernix_compress_block_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + this->testSet.getDecompressedData()[b].data(), 1.0f / this->testSet.getScales()[b], + compressed[b].data()); + ASSERT_EQ(status, PERNIX_STATUS_OK); + } + + for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + expectCompressedBlockEqualsReference(*this, compressed[b], b); + } +} + +TYPED_TEST(CompressionTest64, FallbackCompressBlock) { + std::vector> compressed(this->testSet.numberOfBlocks); + + for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + compressed[b].resize(TestFixture::BlockSize); + + const auto status = pernix_compress_block_f64( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + this->testSet.getDecompressedData()[b].data(), 1.0 / this->testSet.getScales()[b], + compressed[b].data()); + ASSERT_EQ(status, PERNIX_STATUS_OK); + } + + for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + expectCompressedBlockEqualsReference(*this, compressed[b], b); + } +} + +// --------------------------------------------------------------------------- +// Fallback decompress: verify near-source match +// --------------------------------------------------------------------------- + +TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { + std::vector> decompressed(this->testSet.numberOfBlocks); + + for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + decompressed[b].resize(this->testSet.elementsPerBlock); + + const auto status = pernix_decompress_block_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + this->testSet.getCompressedData()[b].data(), this->testSet.getScales()[b], + decompressed[b].data(), true); + ASSERT_EQ(status, PERNIX_STATUS_OK); + } + + for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + expectDecompressedBlockNearSource(*this, decompressed[b], b); + } +} + +TYPED_TEST(DecompressionTest64, FallbackDecompressBlock) { + std::vector> decompressed(this->testSet.numberOfBlocks); + + for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + decompressed[b].resize(this->testSet.elementsPerBlock); + + const auto status = pernix_decompress_block_f64( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + this->testSet.getCompressedData()[b].data(), this->testSet.getScales()[b], + decompressed[b].data(), true); + ASSERT_EQ(status, PERNIX_STATUS_OK); + } + + for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + expectDecompressedBlockNearSource(*this, decompressed[b], b); + } +} + +// --------------------------------------------------------------------------- +// Multi-block roundtrip via compress_blocks / decompress_blocks (fallback) +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, FallbackCompressBlocksRoundtrip) { + const uint32_t nb = this->testSet.numberOfBlocks; + const uint32_t epb = this->testSet.elementsPerBlock; + const uint32_t total = nb * epb; + + std::vector flat(total); + for (uint32_t b = 0; b < nb; b++) { + std::copy_n(this->testSet.getDecompressedData()[b].data(), epb, + flat.data() + b * epb); + } + + // Compute a single scale that covers all blocks + float max_abs = 0.0f; + for (uint32_t i = 0; i < total; i++) { + max_abs = std::max(max_abs, std::abs(flat[i])); + } + const float q = static_cast(decltype(this->testSet)::quantization_levels); + const float scale = (max_abs > 0.0f && q > 0.0f) ? (max_abs / q) : std::numeric_limits::epsilon(); + const float scale_inv = 1.0f / scale; + + std::vector compressed(nb * TestFixture::BlockSize); + auto status = pernix_compress_blocks_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + flat.data(), scale_inv, compressed.data(), nb); + ASSERT_EQ(status, PERNIX_STATUS_OK); + + std::vector restored(total); + status = pernix_decompress_blocks_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + compressed.data(), scale, restored.data(), nb, true); + ASSERT_EQ(status, PERNIX_STATUS_OK); + + const float tol = (std::abs(scale) * 0.5f) + (std::numeric_limits::epsilon() * 16.0f); + for (uint32_t i = 0; i < total; i++) { + EXPECT_NEAR(restored[i], flat[i], tol); + } +} + +TYPED_TEST(CompressionTest64, FallbackCompressBlocksRoundtrip) { + const uint32_t nb = this->testSet.numberOfBlocks; + const uint32_t epb = this->testSet.elementsPerBlock; + const uint32_t total = nb * epb; + + std::vector flat(total); + for (uint32_t b = 0; b < nb; b++) { + std::copy_n(this->testSet.getDecompressedData()[b].data(), epb, + flat.data() + b * epb); + } + + // Compute a single scale that covers all blocks + double max_abs = 0.0; + for (uint32_t i = 0; i < total; i++) { + max_abs = std::max(max_abs, std::abs(flat[i])); + } + const double q = static_cast(decltype(this->testSet)::quantization_levels); + const double scale = (max_abs > 0.0 && q > 0.0) ? (max_abs / q) : std::numeric_limits::epsilon(); + const double scale_inv = 1.0 / scale; + + std::vector compressed(nb * TestFixture::BlockSize); + auto status = pernix_compress_blocks_f64( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + flat.data(), scale_inv, compressed.data(), nb); + ASSERT_EQ(status, PERNIX_STATUS_OK); + + std::vector restored(total); + status = pernix_decompress_blocks_f64( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + compressed.data(), scale, restored.data(), nb, true); + ASSERT_EQ(status, PERNIX_STATUS_OK); + + const double tol = (std::abs(scale) * 0.5) + (std::numeric_limits::epsilon() * 16.0); + for (uint32_t i = 0; i < total; i++) { + EXPECT_NEAR(restored[i], flat[i], tol); + } +} + +// --------------------------------------------------------------------------- +// blocks API with a single block should match the block API exactly +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, SingleBlockCompressBlocksMatchesBlock) { + const auto& src = this->testSet.getDecompressedData()[0]; + const float scale_inv = 1.0f / this->testSet.getScales()[0]; + + std::vector blockOut(TestFixture::BlockSize); + std::vector blocksOut(TestFixture::BlockSize); + + auto s1 = pernix_compress_block_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + src.data(), scale_inv, blockOut.data()); + ASSERT_EQ(s1, PERNIX_STATUS_OK); + + auto s2 = pernix_compress_blocks_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + src.data(), scale_inv, blocksOut.data(), 1); + ASSERT_EQ(s2, PERNIX_STATUS_OK); + + for (uint32_t i = 0; i < TestFixture::BlockSize; i++) { + EXPECT_EQ(blockOut[i], blocksOut[i]) << "byte " << i; + } +} + +TYPED_TEST(DecompressionTest, SingleBlockDecompressBlocksMatchesBlock) { + const auto& compressed = this->testSet.getCompressedData()[0]; + const float scale = this->testSet.getScales()[0]; + const uint32_t epb = this->testSet.elementsPerBlock; + + std::vector blockOut(epb); + std::vector blocksOut(epb); + + auto s1 = pernix_decompress_block_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + compressed.data(), scale, blockOut.data(), true); + ASSERT_EQ(s1, PERNIX_STATUS_OK); + + auto s2 = pernix_decompress_blocks_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + compressed.data(), scale, blocksOut.data(), 1, true); + ASSERT_EQ(s2, PERNIX_STATUS_OK); + + for (uint32_t i = 0; i < epb; i++) { + EXPECT_FLOAT_EQ(blockOut[i], blocksOut[i]) << "element " << i; + } +} + +// --------------------------------------------------------------------------- +// Edge-case behavioural tests (fallback, block_size=64) +// --------------------------------------------------------------------------- + +TEST(FallbackEdgeTest, SignExtensionIsWellDefinedForNegativeValues) { + constexpr uint32_t BS = 64; + const std::array input{0x08}; + + pernix_status st; + std::array output{}; + + st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, 4, BS, input.data(), 1.0f, output.data(), true); + ASSERT_EQ(st, PERNIX_STATUS_OK); + EXPECT_EQ(output[0], -8.0f); +} + +TEST(FallbackEdgeTest, ClearsUnusedPaddingBytes) { + constexpr uint32_t BS = 64; + constexpr uint32_t BW = 24; + constexpr uint32_t EPB = (BS * 8) / BW; + + std::array input{}; + std::array output{}; + output.fill(0xA5); + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BW, BS, input.data(), 1.0f, output.data()); + ASSERT_EQ(st, PERNIX_STATUS_OK); + EXPECT_EQ(output[BS - 1], 0); +} + +TEST(FallbackEdgeTest, ClampsNonFiniteAndOutOfRangeBeforeNarrowing) { + constexpr uint32_t BS = 64; + constexpr uint32_t BW = 4; + constexpr uint32_t EPB = (BS * 8) / BW; + + std::array input{}; + input[0] = std::numeric_limits::infinity(); + input[1] = -std::numeric_limits::infinity(); + input[2] = std::numeric_limits::quiet_NaN(); + + std::array compressed{}; + std::array restored{}; + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BW, BS, input.data(), 1.0f, compressed.data()); + ASSERT_EQ(st, PERNIX_STATUS_OK); + + st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, BW, BS, compressed.data(), 1.0f, restored.data(), true); + ASSERT_EQ(st, PERNIX_STATUS_OK); + + EXPECT_EQ(restored[0], 7.0f); + EXPECT_EQ(restored[1], -8.0f); + EXPECT_EQ(restored[2], 0.0f); +} + +// --------------------------------------------------------------------------- +// Error-code contract tests +// --------------------------------------------------------------------------- + +TEST(ErrorCodeTest, UnsupportedBlockSizeReturnsError) { + constexpr uint32_t BS = 32; + float_t src[32] = {}; + uint8_t dst[32] = {}; + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 8, BS, src, 1.0f, dst); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); + + st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, 8, BS, dst, 1.0f, src, true); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); + + st = pernix_compress_blocks_f32(PERNIX_BACKEND_FALLBACK, 8, BS, src, 1.0f, dst, 1); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); + + st = pernix_decompress_blocks_f32(PERNIX_BACKEND_FALLBACK, 8, BS, dst, 1.0f, src, 1, true); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); +} + +TEST(ErrorCodeTest, UnsupportedBitWidthReturnsError) { + constexpr uint32_t BS = 64; + float_t src[256] = {}; + uint8_t dst[64] = {}; + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 0, BS, src, 1.0f, dst); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH); + + st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 25, BS, src, 1.0f, dst); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH); +} + +TEST(ErrorCodeTest, NullPointerReturnsError) { + float_t src[64] = {}; + uint8_t dst[64] = {}; + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 8, 64, nullptr, 1.0f, dst); + EXPECT_EQ(st, PERNIX_STATUS_INVALID_ARGUMENT); + + st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 8, 64, src, 1.0f, nullptr); + EXPECT_EQ(st, PERNIX_STATUS_INVALID_ARGUMENT); +} diff --git a/tests/include/testset.h b/tests/include/testset.h index 5ef4fa1..81acb32 100644 --- a/tests/include/testset.h +++ b/tests/include/testset.h @@ -21,22 +21,14 @@ static_assert(PERNIX_TEST_BLOCK_SIZE % 32 == 0, "PERNIX_TEST_BLOCK_SIZE must be dividable by 32 bytes"); -/** - * A test set for compression and decompression tests. - * It generates random float data, compresses it, and verifies the decompression using the fallback implementation. - * - * @tparam BIT_WIDTH The bit width used for compression (1 to 24). - * @tparam SIGN_VALUES Indicates whether the values are signed or unsigned. - */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && std::is_floating_point_v class TestSet { - // using ValueType = std::conditional_t; using ValueType = uint8_t; using SeedType = std::mt19937::result_type; alignas(64) std::vector > compressedData; - alignas(64) std::vector > decompressedData; + alignas(64) std::vector > sourceData; alignas(64) std::vector scalesData; SeedType seed; @@ -55,14 +47,13 @@ class TestSet { [[nodiscard]] constexpr uint32_t totalElements() const { return numberOfBlocks * elementsPerBlock; } [[nodiscard]] T blockTolerance(const uint32_t block) const { - // Half-step quantization bound + tiny FP slack for rounding edge cases. return (std::abs(scalesData[block]) * static_cast(0.5)) + (std::numeric_limits::epsilon() * static_cast(16)); } explicit TestSet(const uint32_t number_of_blocks, const SeedType initial_seed = testSeed()) : seed(initial_seed), gen(seed), numberOfBlocks(number_of_blocks) { compressedData.resize(numberOfBlocks); - decompressedData.resize(number_of_blocks); + sourceData.resize(number_of_blocks); scalesData.resize(numberOfBlocks); generateData(); @@ -72,7 +63,7 @@ class TestSet { [[nodiscard]] const std::vector >& getCompressedData() const { return compressedData; } - [[nodiscard]] const std::vector >& getDecompressedData() const { return decompressedData; } + [[nodiscard]] const std::vector >& getDecompressedData() const { return sourceData; } [[nodiscard]] SeedType getSeed() const { return seed; } @@ -88,26 +79,29 @@ class TestSet { } private: - // Generate deterministic source data and its fallback-compressed reference. void generateData() { for (uint32_t i = 0; i < numberOfBlocks; i++) { compressedData[i].resize(BLOCK_SIZE); - decompressedData[i].resize(elementsPerBlock); + sourceData[i].resize(elementsPerBlock); for (uint32_t j = 0; j < elementsPerBlock; j++) { - decompressedData[i][j] = dis(gen); + sourceData[i][j] = dis(gen); } - const T b_max = *std::ranges::max_element(decompressedData[i]); - const T b_min = *std::ranges::min_element(decompressedData[i]); + const T b_max = *std::ranges::max_element(sourceData[i]); + const T b_min = *std::ranges::min_element(sourceData[i]); const T b_abs = std::max(std::abs(b_max), std::abs(b_min)); scalesData[i] = (b_abs > static_cast(0) && quantization_levels > static_cast(0)) ? (b_abs / quantization_levels) : std::numeric_limits::epsilon(); - // Compress the data using the fallback implementation - pernix::compress_block_fallback(decompressedData[i].data(), 1 / scalesData[i], - reinterpret_cast(compressedData[i].data())); + if constexpr (std::is_same_v) { + pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BIT_WIDTH, BLOCK_SIZE, + sourceData[i].data(), 1.0f / scalesData[i], compressedData[i].data()); + } else { + pernix_compress_block_f64(PERNIX_BACKEND_FALLBACK, BIT_WIDTH, BLOCK_SIZE, + sourceData[i].data(), 1.0 / scalesData[i], compressedData[i].data()); + } } } }; diff --git a/tests/simd_tests.cpp b/tests/simd_tests.cpp new file mode 100644 index 0000000..6f9bc50 --- /dev/null +++ b/tests/simd_tests.cpp @@ -0,0 +1,188 @@ +#include + +#include +#include + +// --------------------------------------------------------------------------- +// SIMD compress: compress via backend, decompress via fallback, compare source +// --------------------------------------------------------------------------- + +template +void testBackendCompressBlock(FixtureT& fixture, pernix_backend backend) { + using T = std::remove_cvref_t; + + { + std::vector probe(FixtureT::BlockSize); + pernix_status st; + if constexpr (std::is_same_v) { + st = pernix_compress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getDecompressedData()[0].data(), + 1.0f / fixture.testSet.getScales()[0], probe.data()); + } else { + st = pernix_compress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getDecompressedData()[0].data(), + 1.0 / fixture.testSet.getScales()[0], probe.data()); + } + if (st != PERNIX_STATUS_OK) { + SUCCEED(); + return; + } + } + + for (uint32_t b = 0; b < fixture.testSet.numberOfBlocks; b++) { + std::vector compressed(FixtureT::BlockSize); + + pernix_status st; + if constexpr (std::is_same_v) { + st = pernix_compress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getDecompressedData()[b].data(), + 1.0f / fixture.testSet.getScales()[b], compressed.data()); + } else { + st = pernix_compress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getDecompressedData()[b].data(), + 1.0 / fixture.testSet.getScales()[b], compressed.data()); + } + ASSERT_EQ(st, PERNIX_STATUS_OK); + + std::vector restored(fixture.testSet.elementsPerBlock); + if constexpr (std::is_same_v) { + st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, FixtureT::BitWidth, FixtureT::BlockSize, + compressed.data(), fixture.testSet.getScales()[b], restored.data(), true); + } else { + st = pernix_decompress_block_f64(PERNIX_BACKEND_FALLBACK, FixtureT::BitWidth, FixtureT::BlockSize, + compressed.data(), fixture.testSet.getScales()[b], restored.data(), true); + } + ASSERT_EQ(st, PERNIX_STATUS_OK); + + expectDecompressedBlockNearSource(fixture, restored, b); + } +} + +// --------------------------------------------------------------------------- +// SIMD decompress: decompress fallback-compressed data via backend, compare source +// --------------------------------------------------------------------------- + +template +void testBackendDecompressBlock(FixtureT& fixture, pernix_backend backend) { + using T = std::remove_cvref_t; + + { + std::vector probe(fixture.testSet.elementsPerBlock); + pernix_status st; + if constexpr (std::is_same_v) { + st = pernix_decompress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getCompressedData()[0].data(), + fixture.testSet.getScales()[0], probe.data(), true); + } else { + st = pernix_decompress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getCompressedData()[0].data(), + fixture.testSet.getScales()[0], probe.data(), true); + } + if (st != PERNIX_STATUS_OK) { + SUCCEED(); + return; + } + } + + for (uint32_t b = 0; b < fixture.testSet.numberOfBlocks; b++) { + std::vector decompressed(fixture.testSet.elementsPerBlock); + + pernix_status st; + if constexpr (std::is_same_v) { + st = pernix_decompress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getCompressedData()[b].data(), + fixture.testSet.getScales()[b], decompressed.data(), true); + } else { + st = pernix_decompress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getCompressedData()[b].data(), + fixture.testSet.getScales()[b], decompressed.data(), true); + } + ASSERT_EQ(st, PERNIX_STATUS_OK); + + expectDecompressedBlockNearSource(fixture, decompressed, b); + } +} + +// --------------------------------------------------------------------------- +// x86: AVX2 +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, AVX2CompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_AVX2); +} + +TYPED_TEST(DecompressionTest, AVX2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_AVX2); +} + +TYPED_TEST(CompressionTest64, AVX2CompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_AVX2); +} + +TYPED_TEST(DecompressionTest64, AVX2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_AVX2); +} + +// --------------------------------------------------------------------------- +// x86: BMI2 +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, BMI2CompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_BMI2); +} + +TYPED_TEST(DecompressionTest, BMI2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_BMI2); +} + +TYPED_TEST(CompressionTest64, BMI2CompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_BMI2); +} + +TYPED_TEST(DecompressionTest64, BMI2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_BMI2); +} + +// --------------------------------------------------------------------------- +// x86: AVX512-VBMI +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, AVX512VBMICompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_AVX512_VBMI); +} + +TYPED_TEST(DecompressionTest, AVX512VBMIDecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_AVX512_VBMI); +} + +TYPED_TEST(CompressionTest64, AVX512VBMICompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_AVX512_VBMI); +} + +TYPED_TEST(DecompressionTest64, AVX512VBMIDecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_AVX512_VBMI); +} + +// --------------------------------------------------------------------------- +// ARM64: NEON (decompress only — no compress implementation) +// --------------------------------------------------------------------------- + +TYPED_TEST(DecompressionTest, NeonDecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_ARM64_NEON); +} + +TYPED_TEST(DecompressionTest64, NeonDecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_ARM64_NEON); +} + +// --------------------------------------------------------------------------- +// ARM64: SVE2 (decompress only — no compress implementation) +// --------------------------------------------------------------------------- + +TYPED_TEST(DecompressionTest, SVE2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_ARM64_SVE); +} + +TYPED_TEST(DecompressionTest64, SVE2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_ARM64_SVE); +} diff --git a/tests/x86/avx2/compression_tests.cpp b/tests/x86/avx2/compression_tests.cpp deleted file mode 100644 index bd7f683..0000000 --- a/tests/x86/avx2/compression_tests.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include -#include - -#ifdef PERNIX_AVX2_ENABLED - -TYPED_TEST(CompressionTest, AVX2CompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm256_compress_block_avx2( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -TYPED_TEST(CompressionTest64, AVX2CompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm256_compress_block_avx2( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -#endif // PERNIX_AVX2_ENABLED diff --git a/tests/x86/avx2/decompression_tests.cpp b/tests/x86/avx2/decompression_tests.cpp deleted file mode 100644 index a6fc2c5..0000000 --- a/tests/x86/avx2/decompression_tests.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include -#include - -#ifdef PERNIX_AVX2_ENABLED - -TYPED_TEST(DecompressionTest, AVX2DecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm256_decompress_block_avx2( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, AVX2DecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm256_decompress_block_avx2( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -#endif // PERNIX_AVX2_ENABLED diff --git a/tests/x86/avx512vbmi/compression_tests.cpp b/tests/x86/avx512vbmi/compression_tests.cpp deleted file mode 100644 index a6cb71d..0000000 --- a/tests/x86/avx512vbmi/compression_tests.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include -#include - -#ifdef PERNIX_AVX512_VBMI_ENABLED - -TYPED_TEST(CompressionTest, AVX512VBMICompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm512_compress_block_avx512vbmi( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -TYPED_TEST(CompressionTest64, AVX512VBMICompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm512_compress_block_avx512vbmi( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -#endif // PERNIX_AVX512_VBMI_ENABLED diff --git a/tests/x86/avx512vbmi/decompression_tests.cpp b/tests/x86/avx512vbmi/decompression_tests.cpp deleted file mode 100644 index f44dd8d..0000000 --- a/tests/x86/avx512vbmi/decompression_tests.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include -#include - -#ifdef PERNIX_AVX512_VBMI_ENABLED - -TYPED_TEST(DecompressionTest, AVX512VBMIDecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm512_decompress_block_avx512vbmi( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, AVX512VBMIDecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm512_decompress_block_avx512vbmi( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -#endif // PERNIX_AVX512_VBMI_ENABLED diff --git a/tests/x86/bmi2/compression_tests.cpp b/tests/x86/bmi2/compression_tests.cpp deleted file mode 100644 index b7fc2fd..0000000 --- a/tests/x86/bmi2/compression_tests.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include -#include - -#ifdef PERNIX_BMI2_ENABLED - -TYPED_TEST(CompressionTest, BMI2CompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm256_compress_block_bmi2( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -TYPED_TEST(CompressionTest64, BMI2CompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm256_compress_block_bmi2( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -#endif // PERNIX_BMI2_ENABLED diff --git a/tests/x86/bmi2/decompression_tests.cpp b/tests/x86/bmi2/decompression_tests.cpp deleted file mode 100644 index dd7efc1..0000000 --- a/tests/x86/bmi2/decompression_tests.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include -#include - -#ifdef PERNIX_BMI2_ENABLED - -TYPED_TEST(DecompressionTest, BMI2DecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm256_decompress_block_bmi2( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, BMI2DecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm256_decompress_block_bmi2( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -#endif // PERNIX_BMI2_ENABLED From 8ad45acbbe599183a2f0b44179ee49281b0a40f5 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Mon, 15 Jun 2026 15:29:06 +0200 Subject: [PATCH 14/14] replace number types with pre defined types --- include/pernix/compat.h | 21 + include/pernix/pernix.h | 34 +- include/pernix/pernix.hpp | 127 +- src/arm64/neon/compression.cpp | 40 +- src/arm64/neon/decompression.cpp | 39 +- src/arm64/sve2/compression.cpp | 40 +- src/arm64/sve2/decompression.cpp | 95 +- src/dispatch/cpu_features_arm.cpp | 21 +- src/dispatch/cpu_features_x86.cpp | 8 +- src/dispatch/select.cpp | 820 ++++++------ src/fallback/fallback_compression.cpp | 80 +- src/fallback/fallback_decompression.cpp | 76 +- src/internal/pernix/arm64/neon/common.h | 74 +- src/internal/pernix/arm64/neon/compression.h | 204 +-- src/internal/pernix/arm64/neon/tables.h | 358 ++--- src/internal/pernix/arm64/neon/unpacking.h | 8 +- src/internal/pernix/arm64/sve2/compression.h | 72 +- src/internal/pernix/arm64/sve2/packing.h | 2 +- src/internal/pernix/arm64/sve2/tables.h | 208 ++- src/internal/pernix/arm64/sve2/unpacking.h | 126 +- src/internal/pernix/dispatch/kernel.h | 10 +- src/internal/pernix/dispatch/select.h | 140 +- .../pernix/fallback/avx2_compression.h | 313 ++--- .../pernix/fallback/avx2_decompression.h | 292 ++-- src/internal/pernix/simd_compat.h | 13 - .../pernix/x86/avx2/avx2_compression.h | 794 +++++------ .../pernix/x86/avx2/avx2_decompression.h | 451 ++++--- src/internal/pernix/x86/avx2/avx2_tables.h | 480 +++---- .../x86/avx512vbmi/avx512vbmi_compression.h | 1177 +++++++++-------- .../x86/avx512vbmi/avx512vbmi_decompression.h | 948 ++++++------- src/internal/pernix/x86/avx512vbmi/compat.h | 188 +-- src/internal/pernix/x86/avx512vbmi/packing.h | 42 +- src/internal/pernix/x86/avx512vbmi/tables.h | 640 ++++----- .../pernix/x86/avx512vbmi/unpacking.h | 52 +- .../pernix/x86/bmi2/bmi2_compression.h | 438 +++--- .../pernix/x86/bmi2/bmi2_decompression.h | 492 +++---- src/internal/pernix/x86/utils.h | 14 +- src/pernix.cpp | 70 +- src/x86/avx2/avx2_compression.cpp | 80 +- src/x86/avx2/avx2_decompression.cpp | 76 +- src/x86/avx512vbmi/avx512vbmi_compression.cpp | 80 +- .../avx512vbmi/avx512vbmi_decompression.cpp | 76 +- src/x86/bmi2/bmi2_compression.cpp | 80 +- src/x86/bmi2/bmi2_decompression.cpp | 76 +- tests/fallback_tests.cpp | 136 +- tests/include/testset.h | 146 +- tests/simd_tests.cpp | 56 +- 47 files changed, 5046 insertions(+), 4767 deletions(-) diff --git a/include/pernix/compat.h b/include/pernix/compat.h index 42544ee..498c83a 100644 --- a/include/pernix/compat.h +++ b/include/pernix/compat.h @@ -1,6 +1,10 @@ #ifndef PERNIX_COMPAT_H #define PERNIX_COMPAT_H +#include +#include +#include + #ifndef __always_inline #if defined(__GNUC__) || defined(__clang__) #define __always_inline inline __attribute__((always_inline)) @@ -21,4 +25,21 @@ #define PERNIX_API #endif +// Convenient type declarations +typedef uint8_t u8; +typedef uint16_t u16; +typedef uint32_t u32; +typedef uint64_t u64; +typedef uintptr_t uptr; + +typedef int8_t i8; +typedef int16_t i16; +typedef int32_t i32; +typedef int64_t i64; + +typedef float_t f32; +typedef double_t f64; + +typedef size_t usize; + #endif //PERNIX_COMPAT_H diff --git a/include/pernix/pernix.h b/include/pernix/pernix.h index 7a64af6..74d586b 100644 --- a/include/pernix/pernix.h +++ b/include/pernix/pernix.h @@ -2,8 +2,6 @@ #define PERNIX_H #include -#include -#include #if defined(__cplusplus) extern "C" { @@ -27,29 +25,37 @@ typedef enum pernix_backend { PERNIX_BACKEND_ARM64_SVE = 6 } pernix_backend; -PERNIX_API pernix_status pernix_compress_block_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, +PERNIX_API pernix_status pernix_compress_block_f32(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, float scale, void* output); -PERNIX_API pernix_status pernix_compress_blocks_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - float scale, void* output, uint32_t blocks); +PERNIX_API pernix_status pernix_compress_blocks_f32(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + float scale, void* output, u32 blocks); -PERNIX_API pernix_status pernix_decompress_block_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, +PERNIX_API pernix_status pernix_decompress_block_f32(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, float scale, void* output, bool sign_values); -PERNIX_API pernix_status pernix_decompress_blocks_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - float scale, void* output, uint32_t blocks, bool sign_values); +PERNIX_API pernix_status pernix_decompress_blocks_f32(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + float scale, void* output, u32 blocks, bool sign_values); -PERNIX_API pernix_status pernix_compress_block_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, +PERNIX_API pernix_status pernix_compress_block_f64(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, double scale, void* output); -PERNIX_API pernix_status pernix_compress_blocks_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - double scale, void* output, uint32_t blocks); +PERNIX_API pernix_status pernix_compress_blocks_f64(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + double scale, void* output, u32 blocks); -PERNIX_API pernix_status pernix_decompress_block_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, +PERNIX_API pernix_status pernix_decompress_block_f64(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, double scale, void* output, bool sign_values); -PERNIX_API pernix_status pernix_decompress_blocks_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - double scale, void* output, uint32_t blocks, bool sign_values); +PERNIX_API pernix_status pernix_decompress_blocks_f64(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + double scale, void* output, u32 blocks, bool sign_values); #if defined(__cplusplus) } diff --git a/include/pernix/pernix.hpp b/include/pernix/pernix.hpp index 67301bb..d197f72 100644 --- a/include/pernix/pernix.hpp +++ b/include/pernix/pernix.hpp @@ -1,5 +1,7 @@ #ifndef PERNIX_HPP #define PERNIX_HPP + +#include #include namespace pernix { @@ -13,139 +15,150 @@ enum class Backend { Arm64Sve = PERNIX_BACKEND_ARM64_SVE }; -__always_inline int compress_block(Backend backend, const uint8_t bit_width, const uint32_t block_size, - const std::span input, const float scale, std::span output) { - return pernix_compress_block_f32(static_cast(backend), bit_width, block_size, input.data(), scale, output.data()); +__always_inline int compress_block(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const float scale, std::span output) { + return pernix_compress_block_f32(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data()); } -__always_inline int compress_block(Backend backend, const uint8_t bit_width, const uint32_t block_size, - const std::span input, const double scale, std::span output) { - return pernix_compress_block_f64(static_cast(backend), bit_width, block_size, input.data(), scale, output.data()); +__always_inline int compress_block(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const double scale, std::span output) { + return pernix_compress_block_f64(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data()); } -__always_inline int decompress_block(Backend backend, const uint8_t bit_width, const uint32_t block_size, - const std::span input, const float scale, std::span output, +__always_inline int decompress_block(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const float scale, std::span output, const bool sign_values = true) { - return pernix_decompress_block_f32(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), + return pernix_decompress_block_f32(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), sign_values); } -__always_inline int decompress_block(Backend backend, const uint8_t bit_width, const uint32_t block_size, - const std::span input, const double scale, std::span output, +__always_inline int decompress_block(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const double scale, std::span output, const bool sign_values = true) { - return pernix_decompress_block_f64(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), + return pernix_decompress_block_f64(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), sign_values); } -__always_inline int compress_blocks(Backend backend, const uint8_t bit_width, const uint32_t block_size, - const std::span input, const float scale, std::span output, - const uint32_t blocks) { - return pernix_compress_blocks_f32(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), +__always_inline int compress_blocks(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const float scale, std::span output, + const u32 blocks) { + return pernix_compress_blocks_f32(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), blocks); } -__always_inline int compress_blocks(Backend backend, const uint8_t bit_width, const uint32_t block_size, - const std::span input, const double scale, std::span output, - const uint32_t blocks) { - return pernix_compress_blocks_f64(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), +__always_inline int compress_blocks(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const double scale, std::span output, + const u32 blocks) { + return pernix_compress_blocks_f64(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), blocks); } -__always_inline int decompress_blocks(Backend backend, const uint8_t bit_width, const uint32_t block_size, - const std::span input, const float scale, std::span output, - const uint32_t blocks, const bool sign_values = true) { - return pernix_decompress_blocks_f32(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), +__always_inline int decompress_blocks(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const float scale, std::span output, + const u32 blocks, const bool sign_values = true) { + return pernix_decompress_blocks_f32(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), blocks, sign_values); } -__always_inline int decompress_blocks(Backend backend, const uint8_t bit_width, const uint32_t block_size, - const std::span input, const double scale, std::span output, - const uint32_t blocks, const bool sign_values = true) { - return pernix_decompress_blocks_f64(static_cast(backend), bit_width, block_size, input.data(), scale, output.data(), +__always_inline int decompress_blocks(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const double scale, std::span output, + const u32 blocks, const bool sign_values = true) { + return pernix_decompress_blocks_f64(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), blocks, sign_values); } // convenience overloads without backend (defaults to Auto) -__always_inline int compress_block(const uint8_t bit_width, const uint32_t block_size, const std::span input, - const float scale, const std::span output) { +__always_inline int compress_block(const u8 bit_width, const u32 block_size, const std::span input, + const float scale, const std::span output) { return compress_block(Backend::Auto, bit_width, block_size, input, scale, output); } -__always_inline int compress_block(const uint8_t bit_width, const uint32_t block_size, const std::span input, - const double scale, const std::span output) { +__always_inline int compress_block(const u8 bit_width, const u32 block_size, const std::span input, + const double scale, const std::span output) { return compress_block(Backend::Auto, bit_width, block_size, input, scale, output); } -__always_inline int decompress_block(const uint8_t bit_width, const uint32_t block_size, const std::span input, +__always_inline int decompress_block(const u8 bit_width, const u32 block_size, const std::span input, const float scale, const std::span output, const bool sign_values = true) { return decompress_block(Backend::Auto, bit_width, block_size, input, scale, output, sign_values); } -__always_inline int decompress_block(const uint8_t bit_width, const uint32_t block_size, const std::span input, - const double scale, const std::span output, const bool sign_values = true) { +__always_inline int decompress_block(const u8 bit_width, const u32 block_size, const std::span input, + const double scale, const std::span output, + const bool sign_values = true) { return decompress_block(Backend::Auto, bit_width, block_size, input, scale, output, sign_values); } -__always_inline int compress_blocks(const uint8_t bit_width, const uint32_t block_size, const std::span input, - const float scale, const std::span output, const uint32_t blocks) { +__always_inline int compress_blocks(const u8 bit_width, const u32 block_size, const std::span input, + const float scale, const std::span output, const u32 blocks) { return compress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks); } -__always_inline int compress_blocks(const uint8_t bit_width, const uint32_t block_size, const std::span input, - const double scale, const std::span output, const uint32_t blocks) { +__always_inline int compress_blocks(const u8 bit_width, const u32 block_size, const std::span input, + const double scale, const std::span output, const u32 blocks) { return compress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks); } -__always_inline int decompress_blocks(const uint8_t bit_width, const uint32_t block_size, const std::span input, - const float scale, const std::span output, const uint32_t blocks, +__always_inline int decompress_blocks(const u8 bit_width, const u32 block_size, const std::span input, + const float scale, const std::span output, const u32 blocks, const bool sign_values = true) { return decompress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks, sign_values); } -__always_inline int decompress_blocks(const uint8_t bit_width, const uint32_t block_size, const std::span input, - const double scale, const std::span output, const uint32_t blocks, +__always_inline int decompress_blocks(const u8 bit_width, const u32 block_size, const std::span input, + const double scale, const std::span output, const u32 blocks, const bool sign_values = true) { return decompress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks, sign_values); } // convenience overloads without backend and without block_size (defaults to 64) -__always_inline int compress_block(const uint8_t bit_width, const std::span input, const float scale, - const std::span output) { +__always_inline int compress_block(const u8 bit_width, const std::span input, const float scale, + const std::span output) { return compress_block(Backend::Auto, bit_width, 64, input, scale, output); } -__always_inline int compress_block(const uint8_t bit_width, const std::span input, const double scale, - const std::span output) { +__always_inline int compress_block(const u8 bit_width, const std::span input, const double scale, + const std::span output) { return compress_block(Backend::Auto, bit_width, 64, input, scale, output); } -__always_inline int decompress_block(const uint8_t bit_width, const std::span input, const float scale, +__always_inline int decompress_block(const u8 bit_width, const std::span input, const float scale, const std::span output, const bool sign_values = true) { return decompress_block(Backend::Auto, bit_width, 64, input, scale, output, sign_values); } -__always_inline int decompress_block(const uint8_t bit_width, const std::span input, const double scale, +__always_inline int decompress_block(const u8 bit_width, const std::span input, const double scale, const std::span output, const bool sign_values = true) { return decompress_block(Backend::Auto, bit_width, 64, input, scale, output, sign_values); } -__always_inline int compress_blocks(const uint8_t bit_width, const std::span input, const float scale, - const std::span output, const uint32_t blocks) { +__always_inline int compress_blocks(const u8 bit_width, const std::span input, const float scale, + const std::span output, const u32 blocks) { return compress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks); } -__always_inline int compress_blocks(const uint8_t bit_width, const std::span input, const double scale, - const std::span output, const uint32_t blocks) { +__always_inline int compress_blocks(const u8 bit_width, const std::span input, const double scale, + const std::span output, const u32 blocks) { return compress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks); } -__always_inline int decompress_blocks(const uint8_t bit_width, const std::span input, const float scale, - const std::span output, const uint32_t blocks, const bool sign_values = true) { +__always_inline int decompress_blocks(const u8 bit_width, const std::span input, const float scale, + const std::span output, const u32 blocks, + const bool sign_values = true) { return decompress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks, sign_values); } -__always_inline int decompress_blocks(const uint8_t bit_width, const std::span input, const double scale, - const std::span output, const uint32_t blocks, const bool sign_values = true) { +__always_inline int decompress_blocks(const u8 bit_width, const std::span input, const double scale, + const std::span output, const u32 blocks, + const bool sign_values = true) { return decompress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks, sign_values); } } diff --git a/src/arm64/neon/compression.cpp b/src/arm64/neon/compression.cpp index 81c1cc2..3594303 100644 --- a/src/arm64/neon/compression.cpp +++ b/src/arm64/neon/compression.cpp @@ -2,27 +2,27 @@ #include namespace pernix::internal { -Kernel select_neon_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { - (void)bit_width; - (void)block_size; - return {"neon", nullptr}; -} + Kernel select_neon_compress_block_f32(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"neon", nullptr}; + } -Kernel select_neon_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { - (void)bit_width; - (void)block_size; - return {"neon", nullptr}; -} + Kernel select_neon_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"neon", nullptr}; + } -Kernel select_neon_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { - (void)bit_width; - (void)block_size; - return {"neon", nullptr}; -} + Kernel select_neon_compress_block_f64(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"neon", nullptr}; + } -Kernel select_neon_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { - (void)bit_width; - (void)block_size; - return {"neon", nullptr}; -} + Kernel select_neon_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"neon", nullptr}; + } } diff --git a/src/arm64/neon/decompression.cpp b/src/arm64/neon/decompression.cpp index 94da2fb..3083926 100644 --- a/src/arm64/neon/decompression.cpp +++ b/src/arm64/neon/decompression.cpp @@ -1,26 +1,29 @@ #include #include +using pernix::arm64::neon::neon_decompress_block; +using pernix::arm64::neon::neon_decompress_blocks; + namespace pernix::internal { #define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ case N: \ - if (sign_values) return Kernel("neon", &arm64::neon::neon_decompress_block); \ - return Kernel("neon", &arm64::neon::neon_decompress_block) + if (sign_values) return Kernel("neon", &neon_decompress_block); \ + return Kernel("neon", &neon_decompress_block) #define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ case N: \ - if (sign_values) return Kernel("neon", &arm64::neon::neon_decompress_blocks); \ - return Kernel("neon", &arm64::neon::neon_decompress_blocks) + if (sign_values) return Kernel("neon", &neon_decompress_blocks); \ + return Kernel("neon", &neon_decompress_blocks) #define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ case N: \ - if (sign_values) return Kernel("neon", &arm64::neon::neon_decompress_block); \ - return Kernel("neon", &arm64::neon::neon_decompress_block) + if (sign_values) return Kernel("neon", &neon_decompress_block); \ + return Kernel("neon", &neon_decompress_block) #define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ case N: \ - if (sign_values) return Kernel("neon", &arm64::neon::neon_decompress_blocks); \ - return Kernel("neon", &arm64::neon::neon_decompress_blocks) + if (sign_values) return Kernel("neon", &neon_decompress_blocks); \ + return Kernel("neon", &neon_decompress_blocks) #define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ case BS: \ @@ -146,47 +149,51 @@ case N: \ } \ break -Kernel select_neon_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { +Kernel select_neon_decompress_block_f32(const u8 bit_width, const u32 block_size, bool sign_values) { switch (block_size) { PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); - default: return {"neon", nullptr}; + default: + return {"neon", nullptr}; } } -Kernel select_neon_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { +Kernel select_neon_decompress_blocks_f32(const u8 bit_width, const u32 block_size, bool sign_values) { switch (block_size) { PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); - default: return {"neon", nullptr}; + default: + return {"neon", nullptr}; } } -Kernel select_neon_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { +Kernel select_neon_decompress_block_f64(const u8 bit_width, const u32 block_size, bool sign_values) { switch (block_size) { PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); - default: return {"neon", nullptr}; + default: + return {"neon", nullptr}; } } -Kernel select_neon_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { +Kernel select_neon_decompress_blocks_f64(const u8 bit_width, const u32 block_size, bool sign_values) { switch (block_size) { PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); - default: return {"neon", nullptr}; + default: + return {"neon", nullptr}; } } diff --git a/src/arm64/sve2/compression.cpp b/src/arm64/sve2/compression.cpp index c6d8dd0..1839f12 100644 --- a/src/arm64/sve2/compression.cpp +++ b/src/arm64/sve2/compression.cpp @@ -2,27 +2,27 @@ #include namespace pernix::internal { -Kernel select_sve2_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { - (void)bit_width; - (void)block_size; - return {"sve2", nullptr}; -} + Kernel select_sve2_compress_block_f32(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"sve2", nullptr}; + } -Kernel select_sve2_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { - (void)bit_width; - (void)block_size; - return {"sve2", nullptr}; -} + Kernel select_sve2_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"sve2", nullptr}; + } -Kernel select_sve2_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { - (void)bit_width; - (void)block_size; - return {"sve2", nullptr}; -} + Kernel select_sve2_compress_block_f64(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"sve2", nullptr}; + } -Kernel select_sve2_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { - (void)bit_width; - (void)block_size; - return {"sve2", nullptr}; -} + Kernel select_sve2_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"sve2", nullptr}; + } } diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp index ba796e0..8e65f6e 100644 --- a/src/arm64/sve2/decompression.cpp +++ b/src/arm64/sve2/decompression.cpp @@ -1,26 +1,29 @@ #include #include +using pernix::arm64::sve2::sve2_decompress_block; +using pernix::arm64::sve2::sve2_decompress_blocks; + namespace pernix::internal { #define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ case N: \ - if (sign_values) return Kernel("sve2", &arm64::sve2::sve2_decompress_block); \ - return Kernel("sve2", &arm64::sve2::sve2_decompress_block) + if (sign_values) return Kernel("sve2", &sve2_decompress_block); \ + return Kernel("sve2", &sve2_decompress_block) #define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ case N: \ - if (sign_values) return Kernel("sve2", &arm64::sve2::sve2_decompress_blocks); \ - return Kernel("sve2", &arm64::sve2::sve2_decompress_blocks) + if (sign_values) return Kernel("sve2", &sve2_decompress_blocks); \ + return Kernel("sve2", &sve2_decompress_blocks) #define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ case N: \ - if (sign_values) return Kernel("sve2", &arm64::sve2::sve2_decompress_block); \ - return Kernel("sve2", &arm64::sve2::sve2_decompress_block) + if (sign_values) return Kernel("sve2", &sve2_decompress_block); \ + return Kernel("sve2", &sve2_decompress_block) #define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ case N: \ - if (sign_values) return Kernel("sve2", &arm64::sve2::sve2_decompress_blocks); \ - return Kernel("sve2", &arm64::sve2::sve2_decompress_blocks) + if (sign_values) return Kernel("sve2", &sve2_decompress_blocks); \ + return Kernel("sve2", &sve2_decompress_blocks) #define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ case BS: \ @@ -146,49 +149,53 @@ case N: \ } \ break -Kernel select_sve2_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); - default: return {"sve2", nullptr}; + Kernel select_sve2_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"sve2", nullptr}; + } } -} -Kernel select_sve2_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); - default: return {"sve2", nullptr}; + Kernel select_sve2_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"sve2", nullptr}; + } } -} -Kernel select_sve2_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); - default: return {"sve2", nullptr}; + Kernel select_sve2_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"sve2", nullptr}; + } } -} -Kernel select_sve2_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); - default: return {"sve2", nullptr}; + Kernel select_sve2_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"sve2", nullptr}; + } } -} #undef PERNIX_CASE_DECOMPRESS_BLOCK_32 #undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 diff --git a/src/dispatch/cpu_features_arm.cpp b/src/dispatch/cpu_features_arm.cpp index 1b4ac78..525962b 100644 --- a/src/dispatch/cpu_features_arm.cpp +++ b/src/dispatch/cpu_features_arm.cpp @@ -1,5 +1,15 @@ #include +#if defined(__linux__) && (defined(__aarch64__) || defined(_M_ARM64)) +#include +#ifndef HWCAP_SVE +#define HWCAP_SVE (1 << 22) +#endif +#ifndef HWCAP2_SVE2 +#define HWCAP2_SVE2 (1 << 1) +#endif +#endif + namespace pernix::internal { CpuFeatures detect_cpu_features() { CpuFeatures features{}; @@ -11,11 +21,12 @@ CpuFeatures detect_cpu_features() { features.neon = true; #endif - // sve -#if defined(__aarch64__) || defined(_M_ARM64) -#ifdef __ARM_FEATURE_SVE - features.sve = true; -#endif + // sve / sve2 — runtime detection via getauxval on Linux +#if defined(__linux__) && (defined(__aarch64__) || defined(_M_ARM64)) + unsigned long hwcap = getauxval(AT_HWCAP); + unsigned long hwcap2 = getauxval(AT_HWCAP2); + features.sve = (hwcap & HWCAP_SVE) != 0; + features.sve2 = (hwcap2 & HWCAP2_SVE2) != 0; #endif return features; diff --git a/src/dispatch/cpu_features_x86.cpp b/src/dispatch/cpu_features_x86.cpp index f43245c..973df4d 100644 --- a/src/dispatch/cpu_features_x86.cpp +++ b/src/dispatch/cpu_features_x86.cpp @@ -1,6 +1,6 @@ #include -#include +#include #if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) #if defined(_MSC_VER) @@ -19,7 +19,7 @@ void cpuid(int out[4], int leaf, int subleaf) { __cpuidex(out, leaf, subleaf); } -std::uint64_t xgetbv(unsigned int index) { +u64 xgetbv(unsigned int index) { return _xgetbv(index); } @@ -29,7 +29,7 @@ void cpuid(int out[4], int leaf, int subleaf) { __cpuid_count(leaf, subleaf, out[0], out[1], out[2], out[3]); } -std::uint64_t xgetbv(unsigned int index) { +u64 xgetbv(unsigned int index) { return _xgetbv(index); } @@ -54,7 +54,7 @@ CpuFeatures detect_cpu_features() { return features; } - const std::uint64_t xcr0 = xgetbv(0); + const u64 xcr0 = xgetbv(0); const bool xmm_enabled = (xcr0 & 0x2) != 0; const bool ymm_enabled = (xcr0 & 0x4) != 0; diff --git a/src/dispatch/select.cpp b/src/dispatch/select.cpp index b619af1..ca4673e 100644 --- a/src/dispatch/select.cpp +++ b/src/dispatch/select.cpp @@ -2,683 +2,719 @@ #include namespace pernix::internal { -Kernel select_compress_block_f32(Backend backend, uint8_t bit_width, uint32_t block_size) { - switch (backend) { - case Backend::Auto: - return select_auto_compress_block_f32(bit_width, block_size); - case Backend::Fallback: - return select_fallback_compress_block_f32(bit_width, block_size); + Kernel select_compress_block_f32(Backend backend, u8 bit_width, u32 block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_block_f32(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_block_f32(bit_width, block_size); #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - case Backend::X86Avx512Vbmi: - return select_avx512vbmi_compress_block_f32(bit_width, block_size); + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_block_f32(bit_width, block_size); #endif #if defined(PERNIX_BUILD_X86_AVX2) - case Backend::X86Avx2: - return select_avx2_compress_block_f32(bit_width, block_size); + case Backend::X86Avx2: + return select_avx2_compress_block_f32(bit_width, block_size); #endif #if defined(PERNIX_BUILD_X86_BMI2) - case Backend::X86Bmi2: - return select_bmi2_compress_block_f32(bit_width, block_size); + case Backend::X86Bmi2: + return select_bmi2_compress_block_f32(bit_width, block_size); #endif #if defined(PERNIX_BUILD_ARM64_NEON) - case Backend::Arm64Neon: - return select_neon_compress_block_f32(bit_width, block_size); + case Backend::Arm64Neon: + return select_neon_compress_block_f32(bit_width, block_size); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - case Backend::Arm64Sve: - return select_sve2_compress_block_f32(bit_width, block_size); + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_compress_block_f32(bit_width, block_size); + } + return {"invalid_backend", nullptr}; + } #endif - default: - return {"invalid_backend", nullptr}; + default: + return {"invalid_backend", nullptr}; + } } -} -Kernel select_compress_blocks_f32(Backend backend, uint8_t bit_width, uint32_t block_size) { - switch (backend) { - case Backend::Auto: - return select_auto_compress_blocks_f32(bit_width, block_size); - case Backend::Fallback: - return select_fallback_compress_blocks_f32(bit_width, block_size); + Kernel select_compress_blocks_f32(Backend backend, u8 bit_width, u32 block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_blocks_f32(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_blocks_f32(bit_width, block_size); #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - case Backend::X86Avx512Vbmi: - return select_avx512vbmi_compress_blocks_f32(bit_width, block_size); + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_blocks_f32(bit_width, block_size); #endif #if defined(PERNIX_BUILD_X86_AVX2) - case Backend::X86Avx2: - return select_avx2_compress_blocks_f32(bit_width, block_size); + case Backend::X86Avx2: + return select_avx2_compress_blocks_f32(bit_width, block_size); #endif #if defined(PERNIX_BUILD_X86_BMI2) - case Backend::X86Bmi2: - return select_bmi2_compress_blocks_f32(bit_width, block_size); + case Backend::X86Bmi2: + return select_bmi2_compress_blocks_f32(bit_width, block_size); #endif #if defined(PERNIX_BUILD_ARM64_NEON) - case Backend::Arm64Neon: - return select_neon_compress_blocks_f32(bit_width, block_size); + case Backend::Arm64Neon: + return select_neon_compress_blocks_f32(bit_width, block_size); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - case Backend::Arm64Sve: - return select_sve2_compress_blocks_f32(bit_width, block_size); + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_compress_blocks_f32(bit_width, block_size); + } + return {"invalid_backend", nullptr}; + } #endif - default: - return {"invalid_backend", nullptr}; + default: + return {"invalid_backend", nullptr}; + } } -} -Kernel select_compress_block_f64(Backend backend, uint8_t bit_width, uint32_t block_size) { - switch (backend) { - case Backend::Auto: - return select_auto_compress_block_f64(bit_width, block_size); - case Backend::Fallback: - return select_fallback_compress_block_f64(bit_width, block_size); + Kernel select_compress_block_f64(Backend backend, u8 bit_width, u32 block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_block_f64(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_block_f64(bit_width, block_size); #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - case Backend::X86Avx512Vbmi: - return select_avx512vbmi_compress_block_f64(bit_width, block_size); + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_block_f64(bit_width, block_size); #endif #if defined(PERNIX_BUILD_X86_AVX2) - case Backend::X86Avx2: - return select_avx2_compress_block_f64(bit_width, block_size); + case Backend::X86Avx2: + return select_avx2_compress_block_f64(bit_width, block_size); #endif #if defined(PERNIX_BUILD_X86_BMI2) - case Backend::X86Bmi2: - return select_bmi2_compress_block_f64(bit_width, block_size); + case Backend::X86Bmi2: + return select_bmi2_compress_block_f64(bit_width, block_size); #endif #if defined(PERNIX_BUILD_ARM64_NEON) - case Backend::Arm64Neon: - return select_neon_compress_block_f64(bit_width, block_size); + case Backend::Arm64Neon: + return select_neon_compress_block_f64(bit_width, block_size); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - case Backend::Arm64Sve: - return select_sve2_compress_block_f64(bit_width, block_size); + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_compress_block_f64(bit_width, block_size); + } + return {"invalid_backend", nullptr}; + } #endif - default: - return {"invalid_backend", nullptr}; + default: + return {"invalid_backend", nullptr}; + } } -} -Kernel select_compress_blocks_f64(Backend backend, uint8_t bit_width, uint32_t block_size) { - switch (backend) { - case Backend::Auto: - return select_auto_compress_blocks_f64(bit_width, block_size); - case Backend::Fallback: - return select_fallback_compress_blocks_f64(bit_width, block_size); + Kernel select_compress_blocks_f64(Backend backend, u8 bit_width, u32 block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_blocks_f64(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_blocks_f64(bit_width, block_size); #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - case Backend::X86Avx512Vbmi: - return select_avx512vbmi_compress_blocks_f64(bit_width, block_size); + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_blocks_f64(bit_width, block_size); #endif #if defined(PERNIX_BUILD_X86_AVX2) - case Backend::X86Avx2: - return select_avx2_compress_blocks_f64(bit_width, block_size); + case Backend::X86Avx2: + return select_avx2_compress_blocks_f64(bit_width, block_size); #endif #if defined(PERNIX_BUILD_X86_BMI2) - case Backend::X86Bmi2: - return select_bmi2_compress_blocks_f64(bit_width, block_size); + case Backend::X86Bmi2: + return select_bmi2_compress_blocks_f64(bit_width, block_size); #endif #if defined(PERNIX_BUILD_ARM64_NEON) - case Backend::Arm64Neon: - return select_neon_compress_blocks_f64(bit_width, block_size); + case Backend::Arm64Neon: + return select_neon_compress_blocks_f64(bit_width, block_size); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - case Backend::Arm64Sve: - return select_sve2_compress_blocks_f64(bit_width, block_size); + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_compress_blocks_f64(bit_width, block_size); + } + return {"invalid_backend", nullptr}; + } #endif - default: - return {"invalid_backend", nullptr}; + default: + return {"invalid_backend", nullptr}; + } } -} -Kernel select_decompress_block_f32(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values) { - switch (backend) { - case Backend::Auto: - return select_auto_decompress_block_f32(bit_width, block_size, sign_values); - case Backend::Fallback: - return select_fallback_decompress_block_f32(bit_width, block_size, sign_values); + Kernel select_decompress_block_f32(Backend backend, u8 bit_width, u32 block_size, + bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_block_f32(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_block_f32(bit_width, block_size, sign_values); #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - case Backend::X86Avx512Vbmi: - return select_avx512vbmi_decompress_block_f32(bit_width, block_size, sign_values); + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_block_f32(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_X86_AVX2) - case Backend::X86Avx2: - return select_avx2_decompress_block_f32(bit_width, block_size, sign_values); + case Backend::X86Avx2: + return select_avx2_decompress_block_f32(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_X86_BMI2) - case Backend::X86Bmi2: - return select_bmi2_decompress_block_f32(bit_width, block_size, sign_values); + case Backend::X86Bmi2: + return select_bmi2_decompress_block_f32(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_ARM64_NEON) - case Backend::Arm64Neon: - return select_neon_decompress_block_f32(bit_width, block_size, sign_values); + case Backend::Arm64Neon: + return select_neon_decompress_block_f32(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - case Backend::Arm64Sve: - return select_sve2_decompress_block_f32(bit_width, block_size, sign_values); + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_decompress_block_f32(bit_width, block_size, sign_values); + } + return {"invalid_backend", nullptr}; + } #endif - default: - return {"invalid_backend", nullptr}; + default: + return {"invalid_backend", nullptr}; + } } -} -Kernel select_decompress_blocks_f32(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values) { - switch (backend) { - case Backend::Auto: - return select_auto_decompress_blocks_f32(bit_width, block_size, sign_values); - case Backend::Fallback: - return select_fallback_decompress_blocks_f32(bit_width, block_size, sign_values); + Kernel select_decompress_blocks_f32(Backend backend, u8 bit_width, u32 block_size, + bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_blocks_f32(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_blocks_f32(bit_width, block_size, sign_values); #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - case Backend::X86Avx512Vbmi: - return select_avx512vbmi_decompress_blocks_f32(bit_width, block_size, sign_values); + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_blocks_f32(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_X86_AVX2) - case Backend::X86Avx2: - return select_avx2_decompress_blocks_f32(bit_width, block_size, sign_values); + case Backend::X86Avx2: + return select_avx2_decompress_blocks_f32(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_X86_BMI2) - case Backend::X86Bmi2: - return select_bmi2_decompress_blocks_f32(bit_width, block_size, sign_values); + case Backend::X86Bmi2: + return select_bmi2_decompress_blocks_f32(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_ARM64_NEON) - case Backend::Arm64Neon: - return select_neon_decompress_blocks_f32(bit_width, block_size, sign_values); + case Backend::Arm64Neon: + return select_neon_decompress_blocks_f32(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - case Backend::Arm64Sve: - return select_sve2_decompress_blocks_f32(bit_width, block_size, sign_values); + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_decompress_blocks_f32(bit_width, block_size, sign_values); + } + return {"invalid_backend", nullptr}; + } #endif - default: - return {"invalid_backend", nullptr}; + default: + return {"invalid_backend", nullptr}; + } } -} -Kernel select_decompress_block_f64(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values) { - switch (backend) { - case Backend::Auto: - return select_auto_decompress_block_f64(bit_width, block_size, sign_values); - case Backend::Fallback: - return select_fallback_decompress_block_f64(bit_width, block_size, sign_values); + Kernel select_decompress_block_f64(Backend backend, u8 bit_width, u32 block_size, + bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_block_f64(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_block_f64(bit_width, block_size, sign_values); #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - case Backend::X86Avx512Vbmi: - return select_avx512vbmi_decompress_block_f64(bit_width, block_size, sign_values); + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_block_f64(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_X86_AVX2) - case Backend::X86Avx2: - return select_avx2_decompress_block_f64(bit_width, block_size, sign_values); + case Backend::X86Avx2: + return select_avx2_decompress_block_f64(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_X86_BMI2) - case Backend::X86Bmi2: - return select_bmi2_decompress_block_f64(bit_width, block_size, sign_values); + case Backend::X86Bmi2: + return select_bmi2_decompress_block_f64(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_ARM64_NEON) - case Backend::Arm64Neon: - return select_neon_decompress_block_f64(bit_width, block_size, sign_values); + case Backend::Arm64Neon: + return select_neon_decompress_block_f64(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - case Backend::Arm64Sve: - return select_sve2_decompress_block_f64(bit_width, block_size, sign_values); + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_decompress_block_f64(bit_width, block_size, sign_values); + } + return {"invalid_backend", nullptr}; + } #endif - default: - return {"invalid_backend", nullptr}; + default: + return {"invalid_backend", nullptr}; + } } -} -Kernel select_decompress_blocks_f64(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values) { - switch (backend) { - case Backend::Auto: - return select_auto_decompress_blocks_f64(bit_width, block_size, sign_values); - case Backend::Fallback: - return select_fallback_decompress_blocks_f64(bit_width, block_size, sign_values); + Kernel select_decompress_blocks_f64(Backend backend, u8 bit_width, u32 block_size, + bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_blocks_f64(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_blocks_f64(bit_width, block_size, sign_values); #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - case Backend::X86Avx512Vbmi: - return select_avx512vbmi_decompress_blocks_f64(bit_width, block_size, sign_values); + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_blocks_f64(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_X86_AVX2) - case Backend::X86Avx2: - return select_avx2_decompress_blocks_f64(bit_width, block_size, sign_values); + case Backend::X86Avx2: + return select_avx2_decompress_blocks_f64(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_X86_BMI2) - case Backend::X86Bmi2: - return select_bmi2_decompress_blocks_f64(bit_width, block_size, sign_values); + case Backend::X86Bmi2: + return select_bmi2_decompress_blocks_f64(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_ARM64_NEON) - case Backend::Arm64Neon: - return select_neon_decompress_blocks_f64(bit_width, block_size, sign_values); + case Backend::Arm64Neon: + return select_neon_decompress_blocks_f64(bit_width, block_size, sign_values); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - case Backend::Arm64Sve: - return select_sve2_decompress_blocks_f64(bit_width, block_size, sign_values); + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_decompress_blocks_f64(bit_width, block_size, sign_values); + } + return {"invalid_backend", nullptr}; + } #endif - default: - return {"invalid_backend", nullptr}; + default: + return {"invalid_backend", nullptr}; + } } -} -Kernel select_auto_compress_block_f32(uint8_t bit_width, uint32_t block_size) { + Kernel select_auto_compress_block_f32(u8 bit_width, u32 block_size) { #if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) - const CpuFeatures features = get_cached_cpu_features(); + const CpuFeatures features = get_cached_cpu_features(); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - if ( - features.avx512f && - features.avx512dq && - features.avx512bw && - features.avx512vl && - features.avx512vbmi - ) { - if (auto kernel = select_avx512vbmi_compress_block_f32(bit_width, block_size)) { - return kernel; + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_compress_block_f32(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_AVX2) - if (features.avx2) { - if (auto kernel = select_avx2_compress_block_f32(bit_width, block_size)) { - return kernel; + if (features.avx2) { + if (auto kernel = select_avx2_compress_block_f32(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_BMI2) - if (features.bmi2) { - if (auto kernel = select_bmi2_compress_block_f32(bit_width, block_size)) { - return kernel; + if (features.bmi2) { + if (auto kernel = select_bmi2_compress_block_f32(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_NEON) - if (features.neon) { - if (auto kernel = select_neon_compress_block_f32(bit_width, block_size)) { - return kernel; + if (features.neon) { + if (auto kernel = select_neon_compress_block_f32(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - if (features.sve) { - if (auto kernel = select_sve2_compress_block_f32(bit_width, block_size)) { - return kernel; + if (features.sve2) { + if (auto kernel = select_sve2_compress_block_f32(bit_width, block_size)) { + return kernel; + } } - } #endif - return select_fallback_compress_block_f32(bit_width, block_size); -} + return select_fallback_compress_block_f32(bit_width, block_size); + } -Kernel select_auto_compress_blocks_f32(uint8_t bit_width, uint32_t block_size) { + Kernel select_auto_compress_blocks_f32(u8 bit_width, u32 block_size) { #if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) - const CpuFeatures features = get_cached_cpu_features(); + const CpuFeatures features = get_cached_cpu_features(); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - if ( - features.avx512f && - features.avx512dq && - features.avx512bw && - features.avx512vl && - features.avx512vbmi - ) { - if (auto kernel = select_avx512vbmi_compress_blocks_f32(bit_width, block_size)) { - return kernel; + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_AVX2) - if (features.avx2) { - if (auto kernel = select_avx2_compress_blocks_f32(bit_width, block_size)) { - return kernel; + if (features.avx2) { + if (auto kernel = select_avx2_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_BMI2) - if (features.bmi2) { - if (auto kernel = select_bmi2_compress_blocks_f32(bit_width, block_size)) { - return kernel; + if (features.bmi2) { + if (auto kernel = select_bmi2_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_NEON) - if (features.neon) { - if (auto kernel = select_neon_compress_blocks_f32(bit_width, block_size)) { - return kernel; + if (features.neon) { + if (auto kernel = select_neon_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - if (features.sve) { - if (auto kernel = select_sve2_compress_blocks_f32(bit_width, block_size)) { - return kernel; + if (features.sve2) { + if (auto kernel = select_sve2_compress_blocks_f32(bit_width, block_size)) { + return kernel; + } } - } #endif - return select_fallback_compress_blocks_f32(bit_width, block_size); -} + return select_fallback_compress_blocks_f32(bit_width, block_size); + } -Kernel select_auto_compress_block_f64(uint8_t bit_width, uint32_t block_size) { + Kernel select_auto_compress_block_f64(u8 bit_width, u32 block_size) { #if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) - const CpuFeatures features = get_cached_cpu_features(); + const CpuFeatures features = get_cached_cpu_features(); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - if ( - features.avx512f && - features.avx512dq && - features.avx512bw && - features.avx512vl && - features.avx512vbmi - ) { - if (auto kernel = select_avx512vbmi_compress_block_f64(bit_width, block_size)) { - return kernel; + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_compress_block_f64(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_AVX2) - if (features.avx2) { - if (auto kernel = select_avx2_compress_block_f64(bit_width, block_size)) { - return kernel; + if (features.avx2) { + if (auto kernel = select_avx2_compress_block_f64(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_BMI2) - if (features.bmi2) { - if (auto kernel = select_bmi2_compress_block_f64(bit_width, block_size)) { - return kernel; + if (features.bmi2) { + if (auto kernel = select_bmi2_compress_block_f64(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_NEON) - if (features.neon) { - if (auto kernel = select_neon_compress_block_f64(bit_width, block_size)) { - return kernel; + if (features.neon) { + if (auto kernel = select_neon_compress_block_f64(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - if (features.sve) { - if (auto kernel = select_sve2_compress_block_f64(bit_width, block_size)) { - return kernel; + if (features.sve2) { + if (auto kernel = select_sve2_compress_block_f64(bit_width, block_size)) { + return kernel; + } } - } #endif - return select_fallback_compress_block_f64(bit_width, block_size); -} + return select_fallback_compress_block_f64(bit_width, block_size); + } -Kernel select_auto_compress_blocks_f64(uint8_t bit_width, uint32_t block_size) { + Kernel select_auto_compress_blocks_f64(u8 bit_width, u32 block_size) { #if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) - const CpuFeatures features = get_cached_cpu_features(); + const CpuFeatures features = get_cached_cpu_features(); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - if ( - features.avx512f && - features.avx512dq && - features.avx512bw && - features.avx512vl && - features.avx512vbmi - ) { - if (auto kernel = select_avx512vbmi_compress_blocks_f64(bit_width, block_size)) { - return kernel; + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_AVX2) - if (features.avx2) { - if (auto kernel = select_avx2_compress_blocks_f64(bit_width, block_size)) { - return kernel; + if (features.avx2) { + if (auto kernel = select_avx2_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_BMI2) - if (features.bmi2) { - if (auto kernel = select_bmi2_compress_blocks_f64(bit_width, block_size)) { - return kernel; + if (features.bmi2) { + if (auto kernel = select_bmi2_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_NEON) - if (features.neon) { - if (auto kernel = select_neon_compress_blocks_f64(bit_width, block_size)) { - return kernel; + if (features.neon) { + if (auto kernel = select_neon_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - if (features.sve) { - if (auto kernel = select_sve2_compress_blocks_f64(bit_width, block_size)) { - return kernel; + if (features.sve2) { + if (auto kernel = select_sve2_compress_blocks_f64(bit_width, block_size)) { + return kernel; + } } - } #endif - return select_fallback_compress_blocks_f64(bit_width, block_size); -} + return select_fallback_compress_blocks_f64(bit_width, block_size); + } -Kernel select_auto_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values) { + Kernel select_auto_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values) { #if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) - const CpuFeatures features = get_cached_cpu_features(); + const CpuFeatures features = get_cached_cpu_features(); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - if ( - features.avx512f && - features.avx512dq && - features.avx512bw && - features.avx512vl && - features.avx512vbmi - ) { - if (auto kernel = select_avx512vbmi_decompress_block_f32(bit_width, block_size, sign_values)) { - return kernel; + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_AVX2) - if (features.avx2) { - if (auto kernel = select_avx2_decompress_block_f32(bit_width, block_size, sign_values)) { - return kernel; + if (features.avx2) { + if (auto kernel = select_avx2_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_BMI2) - if (features.bmi2) { - if (auto kernel = select_bmi2_decompress_block_f32(bit_width, block_size, sign_values)) { - return kernel; + if (features.bmi2) { + if (auto kernel = select_bmi2_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_NEON) - if (features.neon) { - if (auto kernel = select_neon_decompress_block_f32(bit_width, block_size, sign_values)) { - return kernel; + if (features.neon) { + if (auto kernel = select_neon_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - if (features.sve) { - if (auto kernel = select_sve2_decompress_block_f32(bit_width, block_size, sign_values)) { - return kernel; + if (features.sve2) { + if (auto kernel = select_sve2_decompress_block_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif - return select_fallback_decompress_block_f32(bit_width, block_size, sign_values); -} + return select_fallback_decompress_block_f32(bit_width, block_size, sign_values); + } -Kernel select_auto_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values) { + Kernel select_auto_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values) { #if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) - const CpuFeatures features = get_cached_cpu_features(); + const CpuFeatures features = get_cached_cpu_features(); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - if ( - features.avx512f && - features.avx512dq && - features.avx512bw && - features.avx512vl && - features.avx512vbmi - ) { - if (auto kernel = select_avx512vbmi_decompress_blocks_f32(bit_width, block_size, sign_values)) { - return kernel; + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_AVX2) - if (features.avx2) { - if (auto kernel = select_avx2_decompress_blocks_f32(bit_width, block_size, sign_values)) { - return kernel; + if (features.avx2) { + if (auto kernel = select_avx2_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_BMI2) - if (features.bmi2) { - if (auto kernel = select_bmi2_decompress_blocks_f32(bit_width, block_size, sign_values)) { - return kernel; + if (features.bmi2) { + if (auto kernel = select_bmi2_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_NEON) - if (features.neon) { - if (auto kernel = select_neon_decompress_blocks_f32(bit_width, block_size, sign_values)) { - return kernel; + if (features.neon) { + if (auto kernel = select_neon_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - if (features.sve) { - if (auto kernel = select_sve2_decompress_blocks_f32(bit_width, block_size, sign_values)) { - return kernel; + if (features.sve2) { + if (auto kernel = select_sve2_decompress_blocks_f32(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif - return select_fallback_decompress_blocks_f32(bit_width, block_size, sign_values); -} + return select_fallback_decompress_blocks_f32(bit_width, block_size, sign_values); + } -Kernel select_auto_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values) { + Kernel select_auto_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values) { #if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) - const CpuFeatures features = get_cached_cpu_features(); + const CpuFeatures features = get_cached_cpu_features(); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - if ( - features.avx512f && - features.avx512dq && - features.avx512bw && - features.avx512vl && - features.avx512vbmi - ) { - if (auto kernel = select_avx512vbmi_decompress_block_f64(bit_width, block_size, sign_values)) { - return kernel; + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_AVX2) - if (features.avx2) { - if (auto kernel = select_avx2_decompress_block_f64(bit_width, block_size, sign_values)) { - return kernel; + if (features.avx2) { + if (auto kernel = select_avx2_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_BMI2) - if (features.bmi2) { - if (auto kernel = select_bmi2_decompress_block_f64(bit_width, block_size, sign_values)) { - return kernel; + if (features.bmi2) { + if (auto kernel = select_bmi2_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_NEON) - if (features.neon) { - if (auto kernel = select_neon_decompress_block_f64(bit_width, block_size, sign_values)) { - return kernel; + if (features.neon) { + if (auto kernel = select_neon_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - if (features.sve) { - if (auto kernel = select_sve2_decompress_block_f64(bit_width, block_size, sign_values)) { - return kernel; + if (features.sve2) { + if (auto kernel = select_sve2_decompress_block_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif - return select_fallback_decompress_block_f64(bit_width, block_size, sign_values); -} + return select_fallback_decompress_block_f64(bit_width, block_size, sign_values); + } -Kernel select_auto_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values) { + Kernel select_auto_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values) { #if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2) - const CpuFeatures features = get_cached_cpu_features(); + const CpuFeatures features = get_cached_cpu_features(); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) - if ( - features.avx512f && - features.avx512dq && - features.avx512bw && - features.avx512vl && - features.avx512vbmi - ) { - if (auto kernel = select_avx512vbmi_decompress_blocks_f64(bit_width, block_size, sign_values)) { - return kernel; + if ( + features.avx512f && + features.avx512dq && + features.avx512bw && + features.avx512vl && + features.avx512vbmi + ) { + if (auto kernel = select_avx512vbmi_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_AVX2) - if (features.avx2) { - if (auto kernel = select_avx2_decompress_blocks_f64(bit_width, block_size, sign_values)) { - return kernel; + if (features.avx2) { + if (auto kernel = select_avx2_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_X86_BMI2) - if (features.bmi2) { - if (auto kernel = select_bmi2_decompress_blocks_f64(bit_width, block_size, sign_values)) { - return kernel; + if (features.bmi2) { + if (auto kernel = select_bmi2_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_NEON) - if (features.neon) { - if (auto kernel = select_neon_decompress_blocks_f64(bit_width, block_size, sign_values)) { - return kernel; + if (features.neon) { + if (auto kernel = select_neon_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif #if defined(PERNIX_BUILD_ARM64_SVE2) - if (features.sve) { - if (auto kernel = select_sve2_decompress_blocks_f64(bit_width, block_size, sign_values)) { - return kernel; + if (features.sve2) { + if (auto kernel = select_sve2_decompress_blocks_f64(bit_width, block_size, sign_values)) { + return kernel; + } } - } #endif - return select_fallback_decompress_blocks_f64(bit_width, block_size, sign_values); -} + return select_fallback_decompress_blocks_f64(bit_width, block_size, sign_values); + } } diff --git a/src/fallback/fallback_compression.cpp b/src/fallback/fallback_compression.cpp index eeaa34d..4f1e955 100644 --- a/src/fallback/fallback_compression.cpp +++ b/src/fallback/fallback_compression.cpp @@ -135,53 +135,53 @@ case N: return Kernel("fallback", &compress_blocks_fallback default: return {"fallback", nullptr}; \ } -Kernel select_fallback_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); - default: - return {"fallback", nullptr}; + Kernel select_fallback_compress_block_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"fallback", nullptr}; + } } -} -Kernel select_fallback_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); - default: - return {"fallback", nullptr}; + Kernel select_fallback_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"fallback", nullptr}; + } } -} -Kernel select_fallback_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); - default: - return {"fallback", nullptr}; + Kernel select_fallback_compress_block_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"fallback", nullptr}; + } } -} -Kernel select_fallback_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); - default: - return {"fallback", nullptr}; + Kernel select_fallback_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"fallback", nullptr}; + } } -} #undef PERNIX_CASE_COMPRESS_BLOCK_32 #undef PERNIX_CASE_COMPRESS_BLOCKS_32 diff --git a/src/fallback/fallback_decompression.cpp b/src/fallback/fallback_decompression.cpp index 232b831..184397f 100644 --- a/src/fallback/fallback_decompression.cpp +++ b/src/fallback/fallback_decompression.cpp @@ -146,49 +146,53 @@ case N: \ } \ break -Kernel select_fallback_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); - default: return {"fallback", nullptr}; + Kernel select_fallback_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"fallback", nullptr}; + } } -} -Kernel select_fallback_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); - default: return {"fallback", nullptr}; + Kernel select_fallback_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"fallback", nullptr}; + } } -} -Kernel select_fallback_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); - default: return {"fallback", nullptr}; + Kernel select_fallback_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"fallback", nullptr}; + } } -} -Kernel select_fallback_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); - default: return {"fallback", nullptr}; + Kernel select_fallback_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"fallback", nullptr}; + } } -} #undef PERNIX_CASE_DECOMPRESS_BLOCK_32 #undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 diff --git a/src/internal/pernix/arm64/neon/common.h b/src/internal/pernix/arm64/neon/common.h index 2677416..6cd9ad0 100644 --- a/src/internal/pernix/arm64/neon/common.h +++ b/src/internal/pernix/arm64/neon/common.h @@ -10,9 +10,9 @@ namespace pernix::arm64::neon::internal { float64x2_t val[8]; }; - static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; + static constexpr u32 tail_bytes(const u8 bit_width, const u32 remaining_elements) { + const u32 tail_bits = remaining_elements * bit_width; + const u32 tail_bytes = (tail_bits + 7u) / 8u; return tail_bytes; } @@ -110,110 +110,110 @@ __always_inline float64x2x8_t neon_dequantize_epi32_f64(const int32x4x4_t &input }; } -__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t *input, const uint32_t tail_bytes_count) { - uint8_t buffer[16] = {0}; +__always_inline uint8x16_t neon_load_tail_elements_int8(const u8 *input, const u32 tail_bytes_count) { + u8 buffer[16] = {0}; std::memcpy(buffer, input, tail_bytes_count); return vld1q_u8(buffer); } -__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t *input, const uint32_t tail_bytes_count) { - uint16_t buffer[8] = {0}; +__always_inline uint16x8_t neon_load_tail_elements_int16(const u8 *input, const u32 tail_bytes_count) { + u16 buffer[8] = {0}; std::memcpy(buffer, input, tail_bytes_count); return vld1q_u16(buffer); } -__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t *input, const uint32_t tail_bytes_count) { - uint32_t buffer[4] = {0}; +__always_inline uint32x4_t neon_load_tail_elements_int32(const u8 *input, const u32 tail_bytes_count) { + u32 buffer[4] = {0}; std::memcpy(buffer, input, tail_bytes_count); return vld1q_u32(buffer); } -__always_inline float32x4_t neon_load_tail_elements_f32(const uint8_t *input, const uint32_t tail_elements) { +__always_inline float32x4_t neon_load_tail_elements_f32(const u8 *input, const u32 tail_elements) { float32_t buffer[4] = {0.0f}; std::memcpy(buffer, input, tail_elements * sizeof(float32_t)); return vld1q_f32(buffer); } -__always_inline float64x2_t neon_load_tail_elements_f64(const uint8_t *input, const uint32_t tail_elements) { +__always_inline float64x2_t neon_load_tail_elements_f64(const u8 *input, const u32 tail_elements) { float64_t buffer[2] = {0.0}; std::memcpy(buffer, input, tail_elements * sizeof(float64_t)); return vld1q_f64(buffer); } -__always_inline void neon_store_tail_elements_int8(uint8_t *output, const uint8x16x4_t &data, - const uint32_t tail_elements) { - uint8_t buffer[16 * 4]; - for (uint32_t i = 0; i < 4; ++i) { +__always_inline void neon_store_tail_elements_int8(u8 *output, const uint8x16x4_t &data, + const u32 tail_elements) { + u8 buffer[16 * 4]; + for (u32 i = 0; i < 4; ++i) { vst1q_u8(buffer + i * 16, data.val[i]); } - std::memcpy(output, buffer, tail_elements * sizeof(uint8_t)); + std::memcpy(output, buffer, tail_elements * sizeof(u8)); } -__always_inline void neon_store_tail_elements_int16(uint16_t *output, const uint16x8x4_t &data, - const uint32_t tail_elements) { - uint16_t buffer[8 * 4]; - for (uint32_t i = 0; i < 4; ++i) { +__always_inline void neon_store_tail_elements_int16(u16 *output, const uint16x8x4_t &data, + const u32 tail_elements) { + u16 buffer[8 * 4]; + for (u32 i = 0; i < 4; ++i) { vst1q_u16(buffer + i * 8, data.val[i]); } - std::memcpy(output, buffer, tail_elements * sizeof(uint16_t)); + std::memcpy(output, buffer, tail_elements * sizeof(u16)); } -__always_inline void neon_store_tail_elements_int32(uint32_t *output, const uint32x4x4_t &data, - const uint32_t tail_elements) { - uint32_t buffer[4 * 4]; - for (uint32_t i = 0; i < 4; ++i) { +__always_inline void neon_store_tail_elements_int32(u32 *output, const uint32x4x4_t &data, + const u32 tail_elements) { + u32 buffer[4 * 4]; + for (u32 i = 0; i < 4; ++i) { vst1q_u32(buffer + i * 4, data.val[i]); } - std::memcpy(output, buffer, tail_elements * sizeof(uint32_t)); + std::memcpy(output, buffer, tail_elements * sizeof(u32)); } __always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4x4_t &data, - const uint32_t tail_elements) { + const u32 tail_elements) { float32_t buffer[16 * 4]; - for (uint32_t i = 0; i < 4; ++i) { + for (u32 i = 0; i < 4; ++i) { vst1q_f32(buffer + i * 4, data.val[i]); } std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); } __always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4x2_t &data, - const uint32_t tail_elements) { + const u32 tail_elements) { float32_t buffer[8 * 2]; - for (uint32_t i = 0; i < 2; ++i) { + for (u32 i = 0; i < 2; ++i) { vst1q_f32(buffer + i * 4, data.val[i]); } std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); } __always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4_t &data, - const uint32_t tail_elements) { + const u32 tail_elements) { float32_t buffer[4]; vst1q_f32(buffer, data); std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); } __always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x4_t &data, - const uint32_t tail_elements) { + const u32 tail_elements) { float64_t buffer[2 * 4]; - for (uint32_t i = 0; i < 4; ++i) { + for (u32 i = 0; i < 4; ++i) { vst1q_f64(buffer + i * 2, data.val[i]); } std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); } __always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x2_t &data, - const uint32_t tail_elements) { + const u32 tail_elements) { float64_t buffer[2 * 2]; - for (uint32_t i = 0; i < 2; ++i) { + for (u32 i = 0; i < 2; ++i) { vst1q_f64(buffer + i * 2, data.val[i]); } std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); } __always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x8_t &data, - const uint32_t tail_elements) { + const u32 tail_elements) { float64_t buffer[2 * 8]; - for (uint32_t i = 0; i < 8; ++i) { + for (u32 i = 0; i < 8; ++i) { vst1q_f64(buffer + i * 2, data.val[i]); } std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); diff --git a/src/internal/pernix/arm64/neon/compression.h b/src/internal/pernix/arm64/neon/compression.h index f88fbb6..83d46c3 100644 --- a/src/internal/pernix/arm64/neon/compression.h +++ b/src/internal/pernix/arm64/neon/compression.h @@ -8,114 +8,114 @@ #include namespace pernix::arm64::neon { -namespace internal { -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; -} -} // namespace internal - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_compress_block(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::neon_compress_block_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::neon_compress_block_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::neon_compress_block_17to24(input, scale, output); + namespace internal { + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + } // namespace internal + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; } - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_compress_block(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::neon_compress_block_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::neon_compress_block_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::neon_compress_block_17to24(input, scale, output); - } - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - neon_compress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; } - return 0; -} + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int neon_compress_blocks(const u8 * __restrict__ input, const f32 scale, f32 * __restrict__ output, + const u32 blocks) { + const u8 *block_input = input; + f32 *block_output = output; -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; + for (u32 block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } - for (uint32_t block = 0; block < blocks; ++block) { - neon_compress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int neon_compress_blocks(const u8 * __restrict__ input, const f64 scale, f64 * __restrict__ output, + const u32 blocks) { + const u8 *block_input = input; + f64 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } - return 0; -} } // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_COMPRESSION_H diff --git a/src/internal/pernix/arm64/neon/tables.h b/src/internal/pernix/arm64/neon/tables.h index d085551..beb7c47 100644 --- a/src/internal/pernix/arm64/neon/tables.h +++ b/src/internal/pernix/arm64/neon/tables.h @@ -7,206 +7,208 @@ #include namespace pernix::arm64::neon::internal { -namespace detail { -inline constexpr std::size_t neon_vector_width = 128; -inline constexpr uint8_t inactive_lane = 0xff; - -template -constexpr bool table_indices_are_valid(const std::array& table) { - return std::ranges::all_of(table, [](const uint8_t index) { - return index == inactive_lane || index < Elements; - }); -} - -template -constexpr std::array make_primary_permute() { - static_assert(LANE_BITS % 8 == 0); - - constexpr std::size_t lane_bytes = LANE_BITS / 8; - static_assert(ELEMENTS % lane_bytes == 0); - - std::array table{}; - table.fill(inactive_lane); - - for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { - const std::size_t bit_start = entry * BIT_WIDTH; - const std::size_t first_byte = bit_start / 8; - const std::size_t base = entry * lane_bytes; - - for (std::size_t lane_byte = 0; lane_byte < lane_bytes; ++lane_byte) { - table[base + lane_byte] = static_cast(first_byte + lane_byte); + namespace detail { + inline constexpr std::size_t neon_vector_width = 128; + inline constexpr u8 inactive_lane = 0xff; + + template + constexpr bool table_indices_are_valid(const std::array &table) { + return std::ranges::all_of(table, [](const u8 index) { + return index == inactive_lane || index < Elements; + }); } - } - return table; -} + template + constexpr std::array make_primary_permute() { + static_assert(LANE_BITS % 8 == 0); -template -constexpr std::array make_spill_permute() { - static_assert(LANE_BITS % 8 == 0); + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); - constexpr std::size_t lane_bytes = LANE_BITS / 8; - static_assert(ELEMENTS % lane_bytes == 0); + std::array table{}; + table.fill(inactive_lane); - std::array table{}; - table.fill(inactive_lane); + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t base = entry * lane_bytes; - for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { - const std::size_t bit_start = entry * BIT_WIDTH; - const std::size_t first_byte = bit_start / 8; - const std::size_t bit_offset = bit_start % 8; - const std::size_t base = entry * lane_bytes; + for (std::size_t lane_byte = 0; lane_byte < lane_bytes; ++lane_byte) { + table[base + lane_byte] = static_cast(first_byte + lane_byte); + } + } - if (bit_offset + BIT_WIDTH > LANE_BITS) { - table[base] = static_cast(first_byte + lane_bytes); + return table; } - } - return table; -} + template + constexpr std::array make_spill_permute() { + static_assert(LANE_BITS % 8 == 0); -template -constexpr std::array make_shift_right() { - std::array table{}; - table.fill(0); + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); - for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { - const std::size_t bit_start = entry * BIT_WIDTH; - const std::size_t bit_offset = bit_start % 8u; + std::array table{}; + table.fill(inactive_lane); - table[entry] = -static_cast(bit_offset); - } + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t bit_offset = bit_start % 8; + const std::size_t base = entry * lane_bytes; - return table; -} + if (bit_offset + BIT_WIDTH > LANE_BITS) { + table[base] = static_cast(first_byte + lane_bytes); + } + } -template -constexpr std::array make_shift_left_for_spill() { - std::array table{}; - table.fill(0); + return table; + } + + template + constexpr std::array make_shift_right() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + + table[entry] = -static_cast(bit_offset); + } + + return table; + } + + template + constexpr std::array make_shift_left_for_spill() { + std::array table{}; + table.fill(0); - for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { - const std::size_t bit_start = entry * BIT_WIDTH; - const std::size_t bit_offset = bit_start % 8u; - const bool spills = bit_offset + BIT_WIDTH > LANE_BITS; + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + const bool spills = bit_offset + BIT_WIDTH > LANE_BITS; - table[entry] = spills ? static_cast(LANE_BITS - bit_offset) : 0; - } + table[entry] = spills ? static_cast(LANE_BITS - bit_offset) : 0; + } - return table; -} + return table; + } + + template + constexpr std::array make_contiguous_permute_32() { + static_assert(ELEMENTS % 4 == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / 4; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + const std::size_t bit_end = bit_start + BIT_WIDTH - 1; + const std::size_t first_byte = bit_start / 8; + const std::size_t last_byte = bit_end / 8; + const std::size_t base = entry * 4; + + for (std::size_t byte = first_byte; byte <= last_byte; ++byte) { + table[base + (byte - first_byte)] = static_cast(byte); + } + } + + return table; + } -template -constexpr std::array make_contiguous_permute_32() { - static_assert(ELEMENTS % 4 == 0); + template + constexpr std::array make_shift_right_32() { + std::array table{}; + table.fill(0); - std::array table{}; - table.fill(inactive_lane); + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; - for (std::size_t entry = 0; entry < ELEMENTS / 4; ++entry) { - const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; - const std::size_t bit_end = bit_start + BIT_WIDTH - 1; - const std::size_t first_byte = bit_start / 8; - const std::size_t last_byte = bit_end / 8; - const std::size_t base = entry * 4; + table[entry] = -static_cast(bit_start % 8u); + } - for (std::size_t byte = first_byte; byte <= last_byte; ++byte) { - table[base + (byte - first_byte)] = static_cast(byte); + return table; } - } - - return table; -} - -template -constexpr std::array make_shift_right_32() { - std::array table{}; - table.fill(0); - - for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { - const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; - - table[entry] = -static_cast(bit_start % 8u); - } - - return table; -} -} // namespace detail - -template -struct table_unpacking; - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && VECTOR_WIDTH == detail::neon_vector_width) -struct table_unpacking { -private: - static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; - static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 8; - -public: - static constexpr uint8_t bit_width = BIT_WIDTH; - - alignas(64) static constexpr std::array permute1 = - detail::make_primary_permute(); - alignas(64) static constexpr std::array permute2 = - detail::make_spill_permute(); - alignas(64) static constexpr std::array shift1 = detail::make_shift_right(); - alignas(64) static constexpr std::array shift2 = - detail::make_shift_left_for_spill(); - - static_assert(PERMUTE_ELEMENTS == 16); - static_assert(SHIFT_ELEMENTS == 16); - static_assert(detail::table_indices_are_valid(permute1)); - static_assert(detail::table_indices_are_valid(permute2)); -}; - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && VECTOR_WIDTH == detail::neon_vector_width) -struct table_unpacking { -private: - static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; - static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 16; - -public: - static constexpr uint8_t bit_width = BIT_WIDTH; - - alignas(64) static constexpr std::array permute1 = - detail::make_primary_permute(); - alignas(64) static constexpr std::array permute2 = - detail::make_spill_permute(); - alignas(64) static constexpr std::array shift1 = - detail::make_shift_right(); - alignas(64) static constexpr std::array shift2 = - detail::make_shift_left_for_spill(); - - static_assert(PERMUTE_ELEMENTS == 16); - static_assert(SHIFT_ELEMENTS == 8); - static_assert(detail::table_indices_are_valid(permute1)); - static_assert(detail::table_indices_are_valid(permute2)); -}; - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && VECTOR_WIDTH == detail::neon_vector_width && START_BIT_OFFSET < 8) -struct table_unpacking { -private: - static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; - static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 32; - -public: - static constexpr uint8_t bit_width = BIT_WIDTH; - - alignas(64) static constexpr std::array permute = - detail::make_contiguous_permute_32(); - alignas(64) static constexpr std::array shift = - detail::make_shift_right_32(); - - static_assert(PERMUTE_ELEMENTS == 16); - static_assert(SHIFT_ELEMENTS == 4); - static_assert(detail::table_indices_are_valid(permute)); -}; - -template -struct table_packing; + } // namespace detail + + template + struct table_unpacking; + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && VECTOR_WIDTH == detail::neon_vector_width) + struct table_unpacking { + private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 8; + + public: + static constexpr u8 bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 16); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); + }; + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && VECTOR_WIDTH == detail::neon_vector_width) + struct table_unpacking { + private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 16; + + public: + static constexpr u8 bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = + detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 8); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); + }; + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && VECTOR_WIDTH == detail::neon_vector_width && START_BIT_OFFSET < + 8) + struct table_unpacking { + private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 32; + + public: + static constexpr u8 bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute = + detail::make_contiguous_permute_32(); + alignas(64) static constexpr std::array shift = + detail::make_shift_right_32(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 4); + static_assert(detail::table_indices_are_valid(permute)); + }; + + template + struct table_packing; } // namespace pernix::arm64::internal #endif // PERNIX_ARM64_NEON_TABLES_H diff --git a/src/internal/pernix/arm64/neon/unpacking.h b/src/internal/pernix/arm64/neon/unpacking.h index e70fbbc..cc9bcf0 100644 --- a/src/internal/pernix/arm64/neon/unpacking.h +++ b/src/internal/pernix/arm64/neon/unpacking.h @@ -7,7 +7,7 @@ using namespace pernix::arm64::neon::internal; namespace pernix::arm64::neon::internal::b128 { - template + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) __always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t &input) { if constexpr (BIT_WIDTH == 8) { @@ -43,7 +43,7 @@ __always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t &input) { } } - template + template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline int16x8_t neon_unpack_epi16_9to16(const uint16x8_t &input) { if constexpr (BIT_WIDTH == 16) { @@ -77,7 +77,7 @@ __always_inline int16x8_t neon_unpack_epi16_9to16(const uint16x8_t &input) { } } - template + template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) __always_inline int32x4_t neon_unpack_epi32_17to24(const uint32x4_t &input) { using tables = table_unpacking; @@ -92,7 +92,7 @@ __always_inline int32x4_t neon_unpack_epi32_17to24(const uint32x4_t &input) { constexpr int sign_shift = 32 - BIT_WIDTH; return vshrq_n_s32(vreinterpretq_s32_u32(vshlq_n_u32(value, sign_shift)), sign_shift); } else { - constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + constexpr u32 mask = (u32{1} << BIT_WIDTH) - 1u; return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); } } diff --git a/src/internal/pernix/arm64/sve2/compression.h b/src/internal/pernix/arm64/sve2/compression.h index 72d9229..33f48fe 100644 --- a/src/internal/pernix/arm64/sve2/compression.h +++ b/src/internal/pernix/arm64/sve2/compression.h @@ -7,42 +7,42 @@ #include namespace pernix { -namespace internal { -template -inline constexpr bool sve2_compression_unimplemented_v = false; -} // namespace internal - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_compress_block(const float_t*, float_t, uint8_t*) { - static_assert(internal::sve2_compression_unimplemented_v, - "ARM64 SVE2 compression is not implemented yet"); - return -1; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_compress_block(const double_t*, double_t, uint8_t*) { - static_assert(internal::sve2_compression_unimplemented_v, - "ARM64 SVE2 compression is not implemented yet"); - return -1; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { - static_assert(internal::sve2_compression_unimplemented_v, - "ARM64 SVE2 compression is not implemented yet"); - return -1; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { - static_assert(internal::sve2_compression_unimplemented_v, - "ARM64 SVE2 compression is not implemented yet"); - return -1; -} + namespace internal { + template + inline constexpr bool sve2_compression_unimplemented_v = false; + } // namespace internal + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int sve2_compress_block(const f32 *, f32, u8 *) { + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); + return -1; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int sve2_compress_block(const f64 *, f64, u8 *) { + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); + return -1; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int sve2_compress_blocks(const f32 *, f32, u8 *, u32) { + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); + return -1; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int sve2_compress_blocks(const f64 *, f64, u8 *, u32) { + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); + return -1; + } } // namespace pernix #endif // PERNIX_ARM64_SVE2_COMPRESSION_H diff --git a/src/internal/pernix/arm64/sve2/packing.h b/src/internal/pernix/arm64/sve2/packing.h index 5cf2355..7d644f2 100644 --- a/src/internal/pernix/arm64/sve2/packing.h +++ b/src/internal/pernix/arm64/sve2/packing.h @@ -4,7 +4,7 @@ #include namespace pernix::arm64::sve2::internal { - template + template inline constexpr bool packing_unimplemented_v = false; } // namespace pernix::arm64::sve2::internal diff --git a/src/internal/pernix/arm64/sve2/tables.h b/src/internal/pernix/arm64/sve2/tables.h index 8ef61d8..a11fcde 100644 --- a/src/internal/pernix/arm64/sve2/tables.h +++ b/src/internal/pernix/arm64/sve2/tables.h @@ -3,117 +3,115 @@ #include -#include - namespace pernix::arm64::sve2::internal { - template - struct table_unpacking { - static constexpr uint8_t bit_width = BIT_WIDTH; - - static svbool_t pg_b8() { return svptrue_b8(); } - - static svbool_t pg_b16() { return svptrue_b16(); } - - static svbool_t pg_b32() { return svptrue_b32(); } - }; - - template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) - struct table_unpacking { - static constexpr uint8_t bit_width = BIT_WIDTH; - - static svuint8_t permute() { - const svbool_t pg = svptrue_b8(); - return svlsr_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 3); - } - - static svuint8_t spill_permute() { - const svbool_t pg = svptrue_b8(); - return svadd_n_u8_x(pg, permute(), 1); - } - - static svuint8_t shift() { - const svbool_t pg = svptrue_b8(); - return svand_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 7); +template +struct table_unpacking { + static constexpr u8 bit_width = BIT_WIDTH; + + static svbool_t pg_b8() { return svptrue_b8(); } + + static svbool_t pg_b16() { return svptrue_b16(); } + + static svbool_t pg_b32() { return svptrue_b32(); } +}; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +struct table_unpacking { + static constexpr u8 bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + return svlsr_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 3); + } + + static svuint8_t spill_permute() { + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 1); + } + + static svuint8_t shift() { + const svbool_t pg = svptrue_b8(); + return svand_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 7); + } + + static svuint8_t spill_shift() { + const svbool_t pg = svptrue_b8(); + return svsub_u8_x(pg, svdup_n_u8(8), shift()); + } +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +struct table_unpacking { + static constexpr u8 bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 1); + const svuint8_t byte = svand_n_u8_x(pg, lane, 1); + + svuint8_t first; + if constexpr (BIT_WIDTH == 16) { + first = svlsl_n_u8_x(pg, elem, 1); + } else { + constexpr u8 extra_bits = BIT_WIDTH - 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low = svlsr_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), 3); + first = svadd_u8_x(pg, elem, svadd_u8_x(pg, high, low)); } - static svuint8_t spill_shift() { - const svbool_t pg = svptrue_b8(); - return svsub_u8_x(pg, svdup_n_u8(8), shift()); - } - }; - - template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) - struct table_unpacking { - static constexpr uint8_t bit_width = BIT_WIDTH; - - static svuint8_t permute() { - const svbool_t pg = svptrue_b8(); - const svuint8_t lane = svindex_u8(0, 1); - const svuint8_t elem = svlsr_n_u8_x(pg, lane, 1); - const svuint8_t byte = svand_n_u8_x(pg, lane, 1); - - svuint8_t first; - if constexpr (BIT_WIDTH == 16) { - first = svlsl_n_u8_x(pg, elem, 1); - } else { - constexpr uint8_t extra_bits = BIT_WIDTH - 8u; - const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); - const svuint8_t low = svlsr_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), 3); - first = svadd_u8_x(pg, elem, svadd_u8_x(pg, high, low)); - } - - return svadd_u8_x(pg, first, byte); + return svadd_u8_x(pg, first, byte); + } + + static svuint8_t spill_permute() { + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 2); + } + + static svuint16_t shift() { + const svbool_t pg = svptrue_b16(); + return svand_n_u16_x(pg, svmul_n_u16_x(pg, svindex_u16(0, 1), BIT_WIDTH), 7); + } + + static svuint16_t spill_shift() { + const svbool_t pg = svptrue_b16(); + const svuint16_t bit_shift = shift(); + const svuint16_t spill = svsub_u16_x(pg, svdup_n_u16(16), bit_shift); + return svsel_u16(svcmpgt_n_u16(pg, bit_shift, 16u - BIT_WIDTH), spill, svdup_n_u16(16)); + } +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && START_BIT_OFFSET < 8) +struct table_unpacking { + static constexpr u8 bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 2); + const svuint8_t byte = svand_n_u8_x(pg, lane, 3); + + svuint8_t first = svmul_n_u8_x(pg, elem, BIT_WIDTH / 8u); + if constexpr (BIT_WIDTH % 8u != 0) { + constexpr u8 extra_bits = BIT_WIDTH % 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low_bits = + svadd_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), START_BIT_OFFSET); + first = svadd_u8_x(pg, first, svadd_u8_x(pg, high, svlsr_n_u8_x(pg, low_bits, 3))); } - static svuint8_t spill_permute() { - const svbool_t pg = svptrue_b8(); - return svadd_n_u8_x(pg, permute(), 2); - } - - static svuint16_t shift() { - const svbool_t pg = svptrue_b16(); - return svand_n_u16_x(pg, svmul_n_u16_x(pg, svindex_u16(0, 1), BIT_WIDTH), 7); - } + return svadd_u8_x(pg, first, byte); + } - static svuint16_t spill_shift() { - const svbool_t pg = svptrue_b16(); - const svuint16_t bit_shift = shift(); - const svuint16_t spill = svsub_u16_x(pg, svdup_n_u16(16), bit_shift); - return svsel_u16(svcmpgt_n_u16(pg, bit_shift, 16u - BIT_WIDTH), spill, svdup_n_u16(16)); - } - }; - - template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && START_BIT_OFFSET < 8) - struct table_unpacking { - static constexpr uint8_t bit_width = BIT_WIDTH; - - static svuint8_t permute() { - const svbool_t pg = svptrue_b8(); - const svuint8_t lane = svindex_u8(0, 1); - const svuint8_t elem = svlsr_n_u8_x(pg, lane, 2); - const svuint8_t byte = svand_n_u8_x(pg, lane, 3); - - svuint8_t first = svmul_n_u8_x(pg, elem, BIT_WIDTH / 8u); - if constexpr (BIT_WIDTH % 8u != 0) { - constexpr uint8_t extra_bits = BIT_WIDTH % 8u; - const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); - const svuint8_t low_bits = - svadd_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), START_BIT_OFFSET); - first = svadd_u8_x(pg, first, svadd_u8_x(pg, high, svlsr_n_u8_x(pg, low_bits, 3))); - } - - return svadd_u8_x(pg, first, byte); - } - - static svuint32_t shift() { - const svbool_t pg = svptrue_b32(); - return svand_n_u32_x(pg, svadd_n_u32_x(pg, svmul_n_u32_x(pg, svindex_u32(0, 1), BIT_WIDTH), - START_BIT_OFFSET), 7); - } - }; + static svuint32_t shift() { + const svbool_t pg = svptrue_b32(); + return svand_n_u32_x(pg, svadd_n_u32_x(pg, svmul_n_u32_x(pg, svindex_u32(0, 1), BIT_WIDTH), + START_BIT_OFFSET), 7); + } +}; } // namespace pernix::arm64::sve2::internal #endif // PERNIX_ARM64_SVE2_TABLES_H diff --git a/src/internal/pernix/arm64/sve2/unpacking.h b/src/internal/pernix/arm64/sve2/unpacking.h index 8bb7e3c..3d0825a 100644 --- a/src/internal/pernix/arm64/sve2/unpacking.h +++ b/src/internal/pernix/arm64/sve2/unpacking.h @@ -6,89 +6,89 @@ #include "tables.h" namespace pernix::arm64::sve2::internal { - template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) __always_inline svint8_t sve2_unpack_epi8_1to8(const svuint8_t input, const svuint8_t permute, const svuint8_t shift, const svuint8_t spill_permute, const svuint8_t spill_shift) { - if constexpr (BIT_WIDTH == 8) { - return svreinterpret_s8(input); - } else { - const svbool_t pg = svptrue_b8(); + if constexpr (BIT_WIDTH == 8) { + return svreinterpret_s8(input); + } else { + const svbool_t pg = svptrue_b8(); + + const svuint8_t permuted = svtbl_u8(input, permute); + svuint8_t unpacked = svlsr_u8_x(pg, permuted, shift); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const svuint8_t spill_permuted_values = svtbl_u8(input, spill_permute); + const svuint8_t spill_shifted = svlsl_u8_x(pg, spill_permuted_values, spill_shift); + unpacked = svorr_u8_x(pg, unpacked, spill_shifted); + } - const svuint8_t permuted = svtbl_u8(input, permute); - svuint8_t unpacked = svlsr_u8_x(pg, permuted, shift); + if constexpr (BIT_WIDTH == 1) { + unpacked = svand_n_u8_x(pg, unpacked, 1); + return svreinterpret_s8(unpacked); + } else { + constexpr int sign_shift = 8 - BIT_WIDTH; - if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { - const svuint8_t spill_permuted_values = svtbl_u8(input, spill_permute); - const svuint8_t spill_shifted = svlsl_u8_x(pg, spill_permuted_values, spill_shift); - unpacked = svorr_u8_x(pg, unpacked, spill_shifted); - } + unpacked = svlsl_n_u8_x(pg, unpacked, sign_shift); - if constexpr (BIT_WIDTH == 1) { - unpacked = svand_n_u8_x(pg, unpacked, 1); - return svreinterpret_s8(unpacked); + if constexpr (SIGN_VALUES) { + return svasr_n_s8_x(pg, svreinterpret_s8_u8(unpacked), sign_shift); } else { - constexpr int sign_shift = 8 - BIT_WIDTH; - - unpacked = svlsl_n_u8_x(pg, unpacked, sign_shift); - - if constexpr (SIGN_VALUES) { - return svasr_n_s8_x(pg, svreinterpret_s8_u8(unpacked), sign_shift); - } else { - return svreinterpret_s8_u8(svlsr_n_u8_x(pg, unpacked, sign_shift)); - } + return svreinterpret_s8_u8(svlsr_n_u8_x(pg, unpacked, sign_shift)); } } } +} - template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline svint16_t sve2_unpack_epi16_9to16(const svuint16_t input, const svuint8_t permute, const svuint16_t shift, const svuint8_t spill_permute, const svuint16_t spill_shift) { - if constexpr (BIT_WIDTH == 16) { - return svreinterpret_s16(input); - } else { - const svbool_t pg = svptrue_b16(); - - const svuint8_t permuted = svtbl_u8(svreinterpret_u8_u16(input), permute); - svuint16_t shifted = svlsr_u16_x(pg, svreinterpret_u16_u8(permuted), shift); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const svuint8_t spill_permuted_values = svtbl_u8(svreinterpret_u8_u16(input), spill_permute); - const svuint16_t spill_shifted = svlsl_u16_x(pg, svreinterpret_u16_u8(spill_permuted_values), - spill_shift); - shifted = svorr_u16_x(pg, shifted, spill_shifted); - } - - constexpr int sign_shift = 16 - BIT_WIDTH; - shifted = svlsl_n_u16_x(pg, shifted, sign_shift); - - if constexpr (SIGN_VALUES) { - return svasr_n_s16_x(pg, svreinterpret_s16_u16(shifted), sign_shift); - } else { - return svreinterpret_s16_u16(svlsr_n_u16_x(pg, shifted, sign_shift)); - } + if constexpr (BIT_WIDTH == 16) { + return svreinterpret_s16(input); + } else { + const svbool_t pg = svptrue_b16(); + + const svuint8_t permuted = svtbl_u8(svreinterpret_u8_u16(input), permute); + svuint16_t shifted = svlsr_u16_x(pg, svreinterpret_u16_u8(permuted), shift); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const svuint8_t spill_permuted_values = svtbl_u8(svreinterpret_u8_u16(input), spill_permute); + const svuint16_t spill_shifted = svlsl_u16_x(pg, svreinterpret_u16_u8(spill_permuted_values), + spill_shift); + shifted = svorr_u16_x(pg, shifted, spill_shifted); } - } - template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline svint32_t sve2_unpack_epi32_17to24(const svuint8_t input) { - using table = table_unpacking; - - const svbool_t pg = svptrue_b32(); - const svuint8_t permuted = svtbl_u8(input, table::permute()); - const svuint32_t unpacked = svlsr_u32_x(pg, svreinterpret_u32_u8(permuted), table::shift()); + constexpr int sign_shift = 16 - BIT_WIDTH; + shifted = svlsl_n_u16_x(pg, shifted, sign_shift); if constexpr (SIGN_VALUES) { - constexpr int sign_shift = 32 - BIT_WIDTH; - return svasr_n_s32_x(pg, svreinterpret_s32_u32(svlsl_n_u32_x(pg, unpacked, sign_shift)), sign_shift); + return svasr_n_s16_x(pg, svreinterpret_s16_u16(shifted), sign_shift); } else { - constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; - return svreinterpret_s32_u32(svand_n_u32_x(pg, unpacked, mask)); + return svreinterpret_s16_u16(svlsr_n_u16_x(pg, shifted, sign_shift)); } } +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline svint32_t sve2_unpack_epi32_17to24(const svuint8_t input) { + using table = table_unpacking; + + const svbool_t pg = svptrue_b32(); + const svuint8_t permuted = svtbl_u8(input, table::permute()); + const svuint32_t unpacked = svlsr_u32_x(pg, svreinterpret_u32_u8(permuted), table::shift()); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return svasr_n_s32_x(pg, svreinterpret_s32_u32(svlsl_n_u32_x(pg, unpacked, sign_shift)), sign_shift); + } else { + constexpr u32 mask = (u32{1} << BIT_WIDTH) - 1u; + return svreinterpret_s32_u32(svand_n_u32_x(pg, unpacked, mask)); + } +} } // namespace pernix::arm64::sve2::internal #endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/src/internal/pernix/dispatch/kernel.h b/src/internal/pernix/dispatch/kernel.h index 1d51c64..b7c0618 100644 --- a/src/internal/pernix/dispatch/kernel.h +++ b/src/internal/pernix/dispatch/kernel.h @@ -1,14 +1,14 @@ #ifndef PERNIX_KERNEL_H #define PERNIX_KERNEL_H -#include +#include #include namespace pernix::internal { -using KernelBlockF32Func = int (*)(const void*, float, void*); -using KernelBlocksF32Func = int (*)(const void*, float, void*, unsigned int); -using KernelBlockF64Func = int (*)(const void*, double, void*); -using KernelBlocksF64Func = int (*)(const void*, double, void*, unsigned int); +using KernelBlockF32Func = i32 (*)(const void*, f32, void*); +using KernelBlocksF32Func = i32 (*)(const void*, f32, void*, u32); +using KernelBlockF64Func = i32 (*)(const void*, f64, void*); +using KernelBlocksF64Func = i32 (*)(const void*, f64, void*, u32); template struct Kernel { diff --git a/src/internal/pernix/dispatch/select.h b/src/internal/pernix/dispatch/select.h index 152117a..d00357c 100644 --- a/src/internal/pernix/dispatch/select.h +++ b/src/internal/pernix/dispatch/select.h @@ -5,153 +5,165 @@ #include namespace pernix::internal { -Kernel select_compress_block_f32(Backend backend, uint8_t bit_width, uint32_t block_size); + Kernel select_compress_block_f32(Backend backend, u8 bit_width, u32 block_size); -Kernel select_compress_blocks_f32(Backend backend, uint8_t bit_width, uint32_t block_size); + Kernel select_compress_blocks_f32(Backend backend, u8 bit_width, u32 block_size); -Kernel select_compress_block_f64(Backend backend, uint8_t bit_width, uint32_t block_size); + Kernel select_compress_block_f64(Backend backend, u8 bit_width, u32 block_size); -Kernel select_compress_blocks_f64(Backend backend, uint8_t bit_width, uint32_t block_size); + Kernel select_compress_blocks_f64(Backend backend, u8 bit_width, u32 block_size); -Kernel select_decompress_block_f32(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_decompress_block_f32(Backend backend, u8 bit_width, u32 block_size, + bool sign_values); -Kernel select_decompress_blocks_f32(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_decompress_blocks_f32(Backend backend, u8 bit_width, u32 block_size, + bool sign_values); -Kernel select_decompress_block_f64(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_decompress_block_f64(Backend backend, u8 bit_width, u32 block_size, + bool sign_values); -Kernel select_decompress_blocks_f64(Backend backend, uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_decompress_blocks_f64(Backend backend, u8 bit_width, u32 block_size, + bool sign_values); -Kernel select_auto_compress_block_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_auto_compress_block_f32(u8 bit_width, u32 block_size); -Kernel select_auto_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_auto_compress_blocks_f32(u8 bit_width, u32 block_size); -Kernel select_auto_compress_block_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_auto_compress_block_f64(u8 bit_width, u32 block_size); -Kernel select_auto_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_auto_compress_blocks_f64(u8 bit_width, u32 block_size); -Kernel select_auto_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_auto_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_auto_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_auto_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_auto_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_auto_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_auto_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_auto_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_fallback_compress_block_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_fallback_compress_block_f32(u8 bit_width, u32 block_size); -Kernel select_fallback_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_fallback_compress_blocks_f32(u8 bit_width, u32 block_size); -Kernel select_fallback_compress_block_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_fallback_compress_block_f64(u8 bit_width, u32 block_size); -Kernel select_fallback_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_fallback_compress_blocks_f64(u8 bit_width, u32 block_size); -Kernel select_fallback_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel + select_fallback_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_fallback_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_fallback_decompress_blocks_f32(u8 bit_width, u32 block_size, + bool sign_values); -Kernel select_fallback_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel + select_fallback_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_fallback_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_fallback_decompress_blocks_f64(u8 bit_width, u32 block_size, + bool sign_values); #if defined(PERNIX_BUILD_X86_AVX2) -Kernel select_avx2_compress_block_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_avx2_compress_block_f32(u8 bit_width, u32 block_size); -Kernel select_avx2_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_avx2_compress_blocks_f32(u8 bit_width, u32 block_size); -Kernel select_avx2_compress_block_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_avx2_compress_block_f64(u8 bit_width, u32 block_size); -Kernel select_avx2_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_avx2_compress_blocks_f64(u8 bit_width, u32 block_size); -Kernel select_avx2_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_avx2_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_avx2_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_avx2_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_avx2_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_avx2_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_avx2_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_avx2_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); #endif #if defined(PERNIX_BUILD_X86_BMI2) -Kernel select_bmi2_compress_block_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_bmi2_compress_block_f32(u8 bit_width, u32 block_size); -Kernel select_bmi2_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_bmi2_compress_blocks_f32(u8 bit_width, u32 block_size); -Kernel select_bmi2_compress_block_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_bmi2_compress_block_f64(u8 bit_width, u32 block_size); -Kernel select_bmi2_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_bmi2_compress_blocks_f64(u8 bit_width, u32 block_size); -Kernel select_bmi2_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_bmi2_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_bmi2_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_bmi2_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_bmi2_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_bmi2_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_bmi2_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_bmi2_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); #endif #if defined(PERNIX_BUILD_X86_AVX512_VBMI) -Kernel select_avx512vbmi_compress_block_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_avx512vbmi_compress_block_f32(u8 bit_width, u32 block_size); -Kernel select_avx512vbmi_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_avx512vbmi_compress_blocks_f32(u8 bit_width, u32 block_size); -Kernel select_avx512vbmi_compress_block_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_avx512vbmi_compress_block_f64(u8 bit_width, u32 block_size); -Kernel select_avx512vbmi_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_avx512vbmi_compress_blocks_f64(u8 bit_width, u32 block_size); -Kernel select_avx512vbmi_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_avx512vbmi_decompress_block_f32(u8 bit_width, u32 block_size, + bool sign_values); -Kernel select_avx512vbmi_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_avx512vbmi_decompress_blocks_f32(u8 bit_width, u32 block_size, + bool sign_values); -Kernel select_avx512vbmi_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_avx512vbmi_decompress_block_f64(u8 bit_width, u32 block_size, + bool sign_values); -Kernel select_avx512vbmi_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_avx512vbmi_decompress_blocks_f64(u8 bit_width, u32 block_size, + bool sign_values); #endif #if defined(PERNIX_BUILD_ARM64_NEON) -Kernel select_neon_compress_block_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_neon_compress_block_f32(u8 bit_width, u32 block_size); -Kernel select_neon_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_neon_compress_blocks_f32(u8 bit_width, u32 block_size); -Kernel select_neon_compress_block_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_neon_compress_block_f64(u8 bit_width, u32 block_size); -Kernel select_neon_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_neon_compress_blocks_f64(u8 bit_width, u32 block_size); -Kernel select_neon_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_neon_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_neon_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_neon_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_neon_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_neon_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_neon_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_neon_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); #endif #if defined(PERNIX_BUILD_ARM64_SVE2) -Kernel select_sve2_compress_block_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_sve2_compress_block_f32(u8 bit_width, u32 block_size); -Kernel select_sve2_compress_blocks_f32(uint8_t bit_width, uint32_t block_size); + Kernel select_sve2_compress_blocks_f32(u8 bit_width, u32 block_size); -Kernel select_sve2_compress_block_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_sve2_compress_block_f64(u8 bit_width, u32 block_size); -Kernel select_sve2_compress_blocks_f64(uint8_t bit_width, uint32_t block_size); + Kernel select_sve2_compress_blocks_f64(u8 bit_width, u32 block_size); -Kernel select_sve2_decompress_block_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_sve2_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_sve2_decompress_blocks_f32(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_sve2_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_sve2_decompress_block_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_sve2_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); -Kernel select_sve2_decompress_blocks_f64(uint8_t bit_width, uint32_t block_size, bool sign_values); + Kernel select_sve2_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); #endif } diff --git a/src/internal/pernix/fallback/avx2_compression.h b/src/internal/pernix/fallback/avx2_compression.h index c7f03bc..e2c7200 100644 --- a/src/internal/pernix/fallback/avx2_compression.h +++ b/src/internal/pernix/fallback/avx2_compression.h @@ -11,69 +11,69 @@ #include namespace pernix { -namespace internal { -/** - * @brief Quantize a single float value to int32_t using the provided scale. + namespace internal { + /** + * @brief Quantize a single float value to i32 using the provided scale. * * @param input input float value to be quantized. * @param scale scaling factor used during quantization. - * @return int32_t quantized integer value. + * @return i32 quantized integer value. */ -__always_inline int32_t quantize_ps_epi32(const float input, const float scale) { - return static_cast(std::lroundf(input * scale)); -} +__always_inline i32 quantize_ps_epi32(const float input, const float scale) { + return static_cast(std::lroundf(input * scale)); + } -/** - * @brief Quantize a single double value to int64_t using the provided scale. + /** + * @brief Quantize a single double value to i64 using the provided scale. * * @param input input double value to be quantized. * @param scale scaling factor used during quantization. - * @return int64_t quantized integer value. + * @return i64 quantized integer value. */ -__always_inline int64_t quantize_pd_epi64(const double_t input, const double_t scale) { - return std::llround(input * scale); -} +__always_inline i64 quantize_pd_epi64(const f64 input, const f64 scale) { + return std::llround(input * scale); + } -/** + /** * @brief Quantize and clamp without narrowing through an out-of-range integer type. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_floating_point_v) -__always_inline int32_t quantize_clamped(const T input, const T scale) { - constexpr int64_t min_value = BIT_WIDTH == 1 ? 0 : -(int64_t{1} << (BIT_WIDTH - 1)); - constexpr int64_t max_value = BIT_WIDTH == 1 ? 1 : ((int64_t{1} << (BIT_WIDTH - 1)) - 1); - - const long double scaled = static_cast(input) * static_cast(scale); - if (std::isnan(scaled)) { - return 0; - } - if (scaled <= static_cast(min_value)) { - return static_cast(min_value); - } - if (scaled >= static_cast(max_value)) { - return static_cast(max_value); - } - - return static_cast(std::llround(scaled)); -} + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_floating_point_v) +__always_inline i32 quantize_clamped(const T input, const T scale) { + constexpr i64 min_value = BIT_WIDTH == 1 ? 0 : -(i64{1} << (BIT_WIDTH - 1)); + constexpr i64 max_value = BIT_WIDTH == 1 ? 1 : ((i64{1} << (BIT_WIDTH - 1)) - 1); + + const long double scaled = static_cast(input) * static_cast(scale); + if (std::isnan(scaled)) { + return 0; + } + if (scaled <= static_cast(min_value)) { + return static_cast(min_value); + } + if (scaled >= static_cast(max_value)) { + return static_cast(max_value); + } + + return static_cast(std::llround(scaled)); + } -/** + /** * @brief Clamp a signed quantized value to the representable range of BIT_WIDTH bits. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__always_inline int32_t clamp_signed_quantized(const int64_t value) { - if constexpr (BIT_WIDTH == 1) { - // 1-bit fallback is treated as binary quantization (0/1). - return static_cast(std::clamp(value, 0, 1)); - } - - constexpr int32_t min_value = -(1 << (BIT_WIDTH - 1)); - constexpr int32_t max_value = (1 << (BIT_WIDTH - 1)) - 1; - return static_cast(std::clamp(value, min_value, max_value)); -} + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) +__always_inline i32 clamp_signed_quantized(const i64 value) { + if constexpr (BIT_WIDTH == 1) { + // 1-bit fallback is treated as binary quantization (0/1). + return static_cast(std::clamp(value, 0, 1)); + } + + constexpr i32 min_value = -(1 << (BIT_WIDTH - 1)); + constexpr i32 max_value = (1 << (BIT_WIDTH - 1)) - 1; + return static_cast(std::clamp(value, min_value, max_value)); + } -/** + /** * @brief Append packed scalar values into an output buffer using the selected * storage width. * @@ -83,58 +83,61 @@ __always_inline int32_t clamp_signed_quantized(const int64_t value) { * @param bit_offset starting bit offset in the destination buffer. * @param destination pointer to the output buffer. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) -void pack_epi32_fallback_inner(const std::vector& input, const uint8_t bit_offset, uint8_t* __restrict__ destination) { - constexpr uint32_t bits_in_type = sizeof(T) * 8; - constexpr uint32_t bitmask = BIT_WIDTH == bits_in_type ? std::numeric_limits::max() : (1U << BIT_WIDTH) - 1U; - - std::size_t idx = 0; - std::size_t bits_in_buffer = bit_offset; - uint64_t buffer = bit_offset ? static_cast(destination[0] & ((1U << bit_offset) - 1U)) : 0; + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) + void pack_epi32_fallback_inner(const std::vector &input, const u8 bit_offset, + u8 * __restrict__ destination) { + constexpr u32 bits_in_type = sizeof(T) * 8; + constexpr u32 bitmask = BIT_WIDTH == bits_in_type + ? std::numeric_limits::max() + : (1U << BIT_WIDTH) - 1U; + + std::size_t idx = 0; + std::size_t bits_in_buffer = bit_offset; + u64 buffer = bit_offset ? static_cast(destination[0] & ((1U << bit_offset) - 1U)) : 0; #pragma GCC unroll 64 - for (uint32_t raw_value : input) { - const uint32_t next_value = raw_value & bitmask; - - buffer |= static_cast(next_value) << bits_in_buffer; - bits_in_buffer += BIT_WIDTH; - - while (bits_in_buffer >= 8) { - destination[idx++] = static_cast(buffer & 0xFFU); - buffer >>= 8; - bits_in_buffer -= 8; + for (u32 raw_value: input) { + const u32 next_value = raw_value & bitmask; + + buffer |= static_cast(next_value) << bits_in_buffer; + bits_in_buffer += BIT_WIDTH; + + while (bits_in_buffer >= 8) { + destination[idx++] = static_cast(buffer & 0xFFU); + buffer >>= 8; + bits_in_buffer -= 8; + } + } + + if (bits_in_buffer > 0) { + destination[idx] = static_cast(buffer & 0xFFU); + } } - } - - if (bits_in_buffer > 0) { - destination[idx] = static_cast(buffer & 0xFFU); - } -} -/** - * @brief Pack a vector of uint32_t values into a compact byte representation using fallback scalar implementation. + /** + * @brief Pack a vector of u32 values into a compact byte representation using fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). * - * @param input vector of uint32_t values to be packed. + * @param input vector of u32 values to be packed. * @param destination pointer to the output buffer where packed bytes will be stored. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -void pack_epi32_fallback(const std::vector& input, uint8_t* __restrict__ destination) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::pack_epi32_fallback_inner(input, 0, destination); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::pack_epi32_fallback_inner(input, 0, destination); - } else { - return internal::pack_epi32_fallback_inner(input, 0, destination); - } -} -} // namespace internal + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + void pack_epi32_fallback(const std::vector &input, u8 * __restrict__ destination) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::pack_epi32_fallback_inner(input, 0, destination); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::pack_epi32_fallback_inner(input, 0, destination); + } else { + return internal::pack_epi32_fallback_inner(input, 0, destination); + } + } + } // namespace internal -/** + /** * @brief Compress a single 512-bit block using fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -145,28 +148,29 @@ void pack_epi32_fallback(const std::vector& input, uint8_t* __restrict * @param output pointer to the output buffer where compressed bytes will be stored. * @return int status code (0 for success). */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_block_fallback(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int compress_block_fallback(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - std::memset(output, 0, BLOCK_SIZE); + std::memset(output, 0, BLOCK_SIZE); - std::vector block_values(elements_per_block); + std::vector block_values(elements_per_block); #pragma GCC unroll 64 - for (uint32_t i = 0; i < elements_per_block; i++) { - const int32_t quantized = internal::quantize_clamped(input[i], scale); - block_values[i] = static_cast(quantized); - } + for (u32 i = 0; i < elements_per_block; i++) { + const i32 quantized = internal::quantize_clamped(input[i], scale); + block_values[i] = static_cast(quantized); + } - internal::pack_epi32_fallback(block_values, output); - return 0; -} + internal::pack_epi32_fallback(block_values, output); + return 0; + } -/** + /** * @brief Compress a single block of double values using the fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -176,28 +180,29 @@ int compress_block_fallback(const void* __restrict__ input_ptr, const float_t sc * @param output pointer to the output buffer where compressed bytes will be stored. * @return int status code (0 for success). */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_block_fallback(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int compress_block_fallback(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - std::memset(output, 0, BLOCK_SIZE); + std::memset(output, 0, BLOCK_SIZE); - std::vector block_values(elements_per_block); + std::vector block_values(elements_per_block); #pragma GCC unroll 32 - for (uint32_t i = 0; i < elements_per_block; i++) { - const int32_t quantized = internal::quantize_clamped(input[i], scale); - block_values[i] = static_cast(quantized); - } + for (u32 i = 0; i < elements_per_block; i++) { + const i32 quantized = internal::quantize_clamped(input[i], scale); + block_values[i] = static_cast(quantized); + } - internal::pack_epi32_fallback(block_values, output); - return 0; -} + internal::pack_epi32_fallback(block_values, output); + return 0; + } -/** + /** * @brief Compress multiple 512-bit blocks using fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -209,25 +214,26 @@ int compress_block_fallback(const void* __restrict__ input_ptr, const double_t s * @param blocks number of 512-bit blocks to compress. * @return int status code (0 for success). */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_blocks_fallback(const void* __restrict__ input_ptr, float scale, void* __restrict__ output_ptr, uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const float_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - compress_block_fallback(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int compress_blocks_fallback(const void * __restrict__ input_ptr, float scale, void * __restrict__ output_ptr, + u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f32 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + compress_block_fallback(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } - return 0; -} + return 0; + } -/** + /** * @brief Compress multiple blocks of double values using the fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -238,22 +244,23 @@ int compress_blocks_fallback(const void* __restrict__ input_ptr, float scale, vo * @param blocks number of blocks to compress. * @return int status code (0 for success). */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_blocks_fallback(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, - const unsigned int blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const double_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - compress_block_fallback(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int compress_blocks_fallback(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const unsigned int blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f64 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + compress_block_fallback(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } + return 0; } - return 0; -} } // namespace pernix #endif // PERNIX_FALLBACK_COMPRESSION_H diff --git a/src/internal/pernix/fallback/avx2_decompression.h b/src/internal/pernix/fallback/avx2_decompression.h index f8f4680..24431e5 100644 --- a/src/internal/pernix/fallback/avx2_decompression.h +++ b/src/internal/pernix/fallback/avx2_decompression.h @@ -9,50 +9,50 @@ #include namespace pernix { -namespace internal { -/** -* @brief Dequantize a single int32_t value to float using the provided scale. + namespace internal { + /** +* @brief Dequantize a single i32 value to float using the provided scale. * -* @param input input int32_t value to be dequantized. +* @param input input i32 value to be dequantized. * @param scale scaling factor used during quantization. * @return float dequantized float value. */ -__always_inline float dequantize_epi32(const int32_t input, const float scale) { - return static_cast(input) * scale; -} +__always_inline float dequantize_epi32(const i32 input, const float scale) { + return static_cast(input) * scale; + } -/** -* @brief Dequantize a single int64_t value to double using the provided scale. + /** +* @brief Dequantize a single i64 value to double using the provided scale. * -* @param input input int64_t value to be dequantized. +* @param input input i64 value to be dequantized. * @param scale scaling factor used during quantization. -* @return double_t dequantized double value. +* @return f64 dequantized double value. */ -__always_inline double_t dequantize_epi64(const int64_t input, const double_t scale) { - return static_cast(input) * scale; -} +__always_inline f64 dequantize_epi64(const i64 input, const f64 scale) { + return static_cast(input) * scale; + } -/** + /** * @brief Sign-extend a packed integer value stored in the low bits of a 32-bit word. * * @tparam BIT_WIDTH number of significant bits in the encoded value. * @param value unsigned packed value. -* @return int32_t sign-extended value. +* @return i32 sign-extended value. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__always_inline auto sign_extend(const uint32_t value) -> int32_t { - if constexpr (BIT_WIDTH == 1) { - return static_cast(value & 1U); - } - - constexpr uint32_t sign_bit = uint32_t{1} << (BIT_WIDTH - 1); - constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1; - const uint32_t masked = value & mask; - return static_cast((static_cast(masked ^ sign_bit)) - static_cast(sign_bit)); -} + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __always_inline auto sign_extend(const u32 value) -> i32 { + if constexpr (BIT_WIDTH == 1) { + return static_cast(value & 1U); + } + + constexpr u32 sign_bit = u32{1} << (BIT_WIDTH - 1); + constexpr u32 mask = (u32{1} << BIT_WIDTH) - 1; + const u32 masked = value & mask; + return static_cast((static_cast(masked ^ sign_bit)) - static_cast(sign_bit)); + } -/** + /** * @brief Unpack bit-packed values from a typed input span into signed 32-bit integers. * * @tparam T unsigned integer type used to read the source buffer. @@ -61,70 +61,70 @@ __always_inline auto sign_extend(const uint32_t value) -> int32_t { * @param input pointer to the typed packed input buffer. * @param bit_offset starting bit offset in the first input word. * @param elements number of values to unpack. -* @return std::vector unpacked values. +* @return std::vector unpacked values. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) -__always_inline auto unpack_epi32_fallback_inner(const uint8_t* __restrict__ input, const uint8_t bit_offset, - const std::size_t elements) - -> std::vector { - constexpr uint32_t bits_in_type = sizeof(T) * 8; - constexpr uint32_t bitmask = BIT_WIDTH == bits_in_type - ? std::numeric_limits::max() - : (1U << BIT_WIDTH) - 1U; - - std::vector output(elements); - - std::size_t idx = 0; - uint8_t bits_in_buffer = 8 - bit_offset; - uint64_t buffer = static_cast(input[idx++]) >> bit_offset; + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) + __always_inline auto unpack_epi32_fallback_inner(const u8 * __restrict__ input, const u8 bit_offset, + const std::size_t elements) + -> std::vector { + constexpr u32 bits_in_type = sizeof(T) * 8; + constexpr u32 bitmask = BIT_WIDTH == bits_in_type + ? std::numeric_limits::max() + : (1U << BIT_WIDTH) - 1U; + + std::vector output(elements); + + std::size_t idx = 0; + u8 bits_in_buffer = 8 - bit_offset; + u64 buffer = static_cast(input[idx++]) >> bit_offset; #pragma GCC unroll 64 - for (uint32_t i = 0; i < elements; i++) { - while (BIT_WIDTH > bits_in_buffer) { - const auto next_value = static_cast(input[idx++]) << bits_in_buffer; - buffer |= next_value; - bits_in_buffer += 8; - } - - const uint32_t raw_value = static_cast(buffer & bitmask); - if constexpr (SIGN_VALUES) { - output[i] = sign_extend(raw_value); - } else { - output[i] = static_cast(raw_value); + for (u32 i = 0; i < elements; i++) { + while (BIT_WIDTH > bits_in_buffer) { + const auto next_value = static_cast(input[idx++]) << bits_in_buffer; + buffer |= next_value; + bits_in_buffer += 8; + } + + const u32 raw_value = static_cast(buffer & bitmask); + if constexpr (SIGN_VALUES) { + output[i] = sign_extend(raw_value); + } else { + output[i] = static_cast(raw_value); + } + + buffer >>= BIT_WIDTH; + bits_in_buffer -= BIT_WIDTH; + } + + return output; } - buffer >>= BIT_WIDTH; - bits_in_buffer -= BIT_WIDTH; - } - - return output; -} - -/** -* @brief Unpack packed int32_t values from the input buffer using fallback scalar implementation. + /** +* @brief Unpack packed i32 values from the input buffer using fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). * @tparam SIGN_VALUES whether the values are signed or unsigned. * @param input pointer to the start of the packed data. * @param elements number of elements to unpack. -* @return std::vector unpacked int32_t values. +* @return std::vector unpacked i32 values. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__always_inline auto unpack_epi32_fallback(const uint8_t* __restrict__ input, - const std::size_t elements) -> std::vector { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return unpack_epi32_fallback_inner(input, 0, elements); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return unpack_epi32_fallback_inner(input, 0, elements); - } else { - return unpack_epi32_fallback_inner(input, 0, elements); - } -} -} // namespace internal + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __always_inline auto unpack_epi32_fallback(const u8 * __restrict__ input, + const std::size_t elements) -> std::vector { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return unpack_epi32_fallback_inner(input, 0, elements); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return unpack_epi32_fallback_inner(input, 0, elements); + } else { + return unpack_epi32_fallback_inner(input, 0, elements); + } + } + } // namespace internal -/** + /** * @brief Decompress a single 512\-bit block using fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -134,27 +134,27 @@ __always_inline auto unpack_epi32_fallback(const uint8_t* __restrict__ input, * @param output pointer to the output buffer where decompressed float values will be stored. * @return int status code (0 for success). */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_block_fallback(const void* __restrict__ input_ptr, const float_t scale, - void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) + int decompress_block_fallback(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - const std::vector block_values = internal::unpack_epi32_fallback( - input, elements_per_block); + const std::vector block_values = internal::unpack_epi32_fallback( + input, elements_per_block); #pragma GCC unroll 512 - for (uint32_t i = 0; i < elements_per_block; i++) { - output[i] = internal::dequantize_epi32(block_values[i], scale); - } + for (u32 i = 0; i < elements_per_block; i++) { + output[i] = internal::dequantize_epi32(block_values[i], scale); + } - return 0; -} + return 0; + } -/** + /** * @brief Decompress a single block to double values using the fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -164,27 +164,27 @@ int decompress_block_fallback(const void* __restrict__ input_ptr, const float_t * @param output pointer to the output buffer where decompressed double values will be stored. * @return int status code (0 for success). */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_block_fallback(const void* __restrict__ input_ptr, const double_t scale, - void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) + int decompress_block_fallback(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - const std::vector block_values = internal::unpack_epi32_fallback( - input, elements_per_block); + const std::vector block_values = internal::unpack_epi32_fallback( + input, elements_per_block); #pragma GCC unroll 512 - for (uint32_t i = 0; i < elements_per_block; i++) { - output[i] = internal::dequantize_epi64(block_values[i], scale); - } + for (u32 i = 0; i < elements_per_block; i++) { + output[i] = internal::dequantize_epi64(block_values[i], scale); + } - return 0; -} + return 0; + } -/** + /** * @brief Decompress multiple 512\-bit blocks using fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -195,26 +195,26 @@ int decompress_block_fallback(const void* __restrict__ input_ptr, const double_t * @param blocks number of 512-bit blocks to decompress. * @return int status code (0 for success). */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int decompress_blocks_fallback(const void* __restrict__ input_ptr, const float_t scale, - void* __restrict__ output_ptr, const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - decompress_block_fallback(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int decompress_blocks_fallback(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f32 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + decompress_block_fallback(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } - return 0; -} + return 0; + } -/** + /** * @brief Decompress multiple blocks to double values using the fallback scalar implementation. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -225,24 +225,24 @@ int decompress_blocks_fallback(const void* __restrict__ input_ptr, const float_t * @param blocks number of blocks to decompress. * @return int status code (0 for success). */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_blocks_fallback(const void* __restrict__ input_ptr, const double_t scale, - void* __restrict__ output_ptr, const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - decompress_block_fallback(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) + int decompress_blocks_fallback(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f64 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + decompress_block_fallback(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } - return 0; -} + return 0; + } } // namespace pernix #endif // PERNIX_FALLBACK_DECOMPRESSION_H diff --git a/src/internal/pernix/simd_compat.h b/src/internal/pernix/simd_compat.h index b9dac8e..d55aafa 100644 --- a/src/internal/pernix/simd_compat.h +++ b/src/internal/pernix/simd_compat.h @@ -22,19 +22,6 @@ #include #endif -// #ifndef __mmask8 -// typedef uint8_t __mmask8; -// #endif -// #ifndef __mmask16 -// typedef uint16_t __mmask16; -// #endif -// #ifndef __mmask32 -// typedef uint32_t __mmask32; -// #endif -// #ifndef __mmask64 -// typedef uint64_t __mmask64; -// #endif - #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #include #elif defined(__aarch64__) || defined(__arm64ec__) diff --git a/src/internal/pernix/x86/avx2/avx2_compression.h b/src/internal/pernix/x86/avx2/avx2_compression.h index 3486f02..01c9626 100644 --- a/src/internal/pernix/x86/avx2/avx2_compression.h +++ b/src/internal/pernix/x86/avx2/avx2_compression.h @@ -11,68 +11,68 @@ #include namespace pernix { -namespace internal { -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + namespace internal { + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) __always_inline __m256i mm256_clamp_signed_epi32(__m256i input) { - constexpr int32_t min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); - constexpr int32_t max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); - return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), - _mm256_set1_epi32(max_value)); -} + constexpr i32 min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); + constexpr i32 max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); + return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), + _mm256_set1_epi32(max_value)); + } -/** + /** * @brief Quantize four float values into signed 32-bit integers. * * @param input source float lane values. * @param scale per-lane scale factor. * @return __m128i quantized values. */ -__always_inline __m128i mm_quantize_ps_epi32(const __m128& input, const __m128& scale) { - const __m128 scaled = _mm_mul_ps(input, scale); - // const __m128 rounded = _mm_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - return _mm_cvtps_epi32(scaled); -} +__always_inline __m128i mm_quantize_ps_epi32(const __m128 &input, const __m128 &scale) { + const __m128 scaled = _mm_mul_ps(input, scale); + // const __m128 rounded = _mm_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + return _mm_cvtps_epi32(scaled); + } -/** + /** * @brief Quantize two double values into a partially filled 128-bit integer register. * * @param input source double lane values. * @param scale per-lane scale factor. * @return __m128i quantized values in the low lanes. */ -__always_inline __m128i mm_quantize_pd_epi32(const __m128d& input, const __m128d& scale) { - const __m128d scaled = _mm_mul_pd(input, scale); - return _mm_cvtpd_epi32(scaled); -} +__always_inline __m128i mm_quantize_pd_epi32(const __m128d &input, const __m128d &scale) { + const __m128d scaled = _mm_mul_pd(input, scale); + return _mm_cvtpd_epi32(scaled); + } -/** + /** * @brief Quantize eight float values into signed 32-bit integers. * * @param input source float lane values. * @param scale per-lane scale factor. * @return __m256i quantized values. */ -__always_inline __m256i mm256_quantize_ps_epi32(const __m256& input, const __m256& scale) { - const __m256 scaled = _mm256_mul_ps(input, scale); - // const __m256 rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - return _mm256_cvtps_epi32(scaled); -} +__always_inline __m256i mm256_quantize_ps_epi32(const __m256 &input, const __m256 &scale) { + const __m256 scaled = _mm256_mul_ps(input, scale); + // const __m256 rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + return _mm256_cvtps_epi32(scaled); + } -/** + /** * @brief Quantize four double values into signed 32-bit integers. * * @param input source double lane values. * @param scale per-lane scale factor. * @return __m128i quantized values in the low lanes. */ -__always_inline __m128i mm256_quantize_pd_epi32(const __m256d& input, const __m256d& scale) { - const __m256d scaled = _mm256_mul_pd(input, scale); - return _mm256_cvtpd_epi32(scaled); -} +__always_inline __m128i mm256_quantize_pd_epi32(const __m256d &input, const __m256d &scale) { + const __m256d scaled = _mm256_mul_pd(input, scale); + return _mm256_cvtpd_epi32(scaled); + } #ifndef PERNIX_USE_SIMDE -/** + /** * @brief Emulate per-16-bit left shifts on AVX2. * * @param a source values. @@ -80,350 +80,350 @@ __always_inline __m128i mm256_quantize_pd_epi32(const __m256d& input, const __m2 * @return __m128i shifted values. */ __always_inline static __m128i _mm_sllv_epi16(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi32(0xffff0000); - const __m128i low_half = _mm_sllv_epi32(a, _mm_andnot_si128(mask, count)); - const __m128i high_half = _mm_sllv_epi32(_mm_and_si128(mask, a), _mm_srli_epi32(count, 16)); - return _mm_blend_epi16(low_half, high_half, 0xaa); -} + const __m128i mask = _mm_set1_epi32(0xffff0000); + const __m128i low_half = _mm_sllv_epi32(a, _mm_andnot_si128(mask, count)); + const __m128i high_half = _mm_sllv_epi32(_mm_and_si128(mask, a), _mm_srli_epi32(count, 16)); + return _mm_blend_epi16(low_half, high_half, 0xaa); + } -/** + /** * @brief Emulate per-16-bit right shifts on AVX2. */ __always_inline static __m128i _mm_srlv_epi16(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi32(0x0000ffff); - const __m128i low_half = _mm_srlv_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); - const __m128i high_half = _mm_srlv_epi32(a, _mm_srli_epi32(count, 16)); - return _mm_blend_epi16(low_half, high_half, 0xaa); -} + const __m128i mask = _mm_set1_epi32(0x0000ffff); + const __m128i low_half = _mm_srlv_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); + const __m128i high_half = _mm_srlv_epi32(a, _mm_srli_epi32(count, 16)); + return _mm_blend_epi16(low_half, high_half, 0xaa); + } -/** + /** * @brief Emulate per-16-bit left shifts on 256-bit AVX2 vectors. */ __always_inline static __m256i _mm256_sllv_epi16(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi32(0xffff0000); - const __m256i low_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count)); - const __m256i high_half = _mm256_sllv_epi32(_mm256_and_si256(mask, a), _mm256_srli_epi32(count, 16)); - return _mm256_blend_epi16(low_half, high_half, 0xaa); -} + const __m256i mask = _mm256_set1_epi32(0xffff0000); + const __m256i low_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count)); + const __m256i high_half = _mm256_sllv_epi32(_mm256_and_si256(mask, a), _mm256_srli_epi32(count, 16)); + return _mm256_blend_epi16(low_half, high_half, 0xaa); + } -/** + /** * @brief Emulate per-16-bit right shifts on 256-bit AVX2 vectors. */ __always_inline static __m256i _mm256_srlv_epi16(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi32(0x0000ffff); - const __m256i low_half = _mm256_srlv_epi32(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); - const __m256i high_half = _mm256_srlv_epi32(a, _mm256_srli_epi32(count, 16)); - return _mm256_blend_epi16(low_half, high_half, 0xaa); -} + const __m256i mask = _mm256_set1_epi32(0x0000ffff); + const __m256i low_half = _mm256_srlv_epi32(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); + const __m256i high_half = _mm256_srlv_epi32(a, _mm256_srli_epi32(count, 16)); + return _mm256_blend_epi16(low_half, high_half, 0xaa); + } -/** + /** * @brief Blend 8-bit lanes by expanding a scalar mask value. */ -__always_inline static __m128i mm_blend_epi8(const __m128i X, const __m128i Y, const int8_t M) { - return _mm_blendv_epi8(X, Y, _mm_set1_epi8(M)); -} +__always_inline static __m128i mm_blend_epi8(const __m128i X, const __m128i Y, const i8 M) { + return _mm_blendv_epi8(X, Y, _mm_set1_epi8(M)); + } -/** + /** * @brief Blend 8-bit lanes in 256-bit vectors by expanding a scalar mask value. */ -__always_inline static __m256i mm256_blend_epi8(const __m256i X, const __m256i Y, const int8_t M) { - return _mm256_blendv_epi8(X, Y, _mm256_set1_epi8(M)); -} +__always_inline static __m256i mm256_blend_epi8(const __m256i X, const __m256i Y, const i8 M) { + return _mm256_blendv_epi8(X, Y, _mm256_set1_epi8(M)); + } -/** + /** * @brief Emulate per-byte left shifts on 128-bit vectors. */ __always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi16(0xff00); - const __m128i low_half = _mm_sllv_epi16(a, _mm_andnot_si128(mask, count)); - const __m128i high_half = _mm_sllv_epi16(_mm_and_si128(mask, a), _mm_srli_epi16(count, 8)); - return mm_blend_epi8(low_half, high_half, 0xaa); -} + const __m128i mask = _mm_set1_epi16(0xff00); + const __m128i low_half = _mm_sllv_epi16(a, _mm_andnot_si128(mask, count)); + const __m128i high_half = _mm_sllv_epi16(_mm_and_si128(mask, a), _mm_srli_epi16(count, 8)); + return mm_blend_epi8(low_half, high_half, 0xaa); + } -/** + /** * @brief Emulate per-byte right shifts on 128-bit vectors. */ __always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi16(0x00ff); - const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); - const __m128i high_half = _mm_srlv_epi16(a, _mm_srli_epi16(count, 8)); - return mm_blend_epi8(low_half, high_half, 0xaa); -} + const __m128i mask = _mm_set1_epi16(0x00ff); + const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); + const __m128i high_half = _mm_srlv_epi16(a, _mm_srli_epi16(count, 8)); + return mm_blend_epi8(low_half, high_half, 0xaa); + } -/** + /** * @brief Emulate per-byte left shifts on 256-bit vectors. */ __always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi16(0xff00); - const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count)); - const __m256i high_half = _mm256_sllv_epi16(_mm256_and_si256(mask, a), _mm256_srli_epi16(count, 8)); - return mm256_blend_epi8(low_half, high_half, 0xaa); -} + const __m256i mask = _mm256_set1_epi16(0xff00); + const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count)); + const __m256i high_half = _mm256_sllv_epi16(_mm256_and_si256(mask, a), _mm256_srli_epi16(count, 8)); + return mm256_blend_epi8(low_half, high_half, 0xaa); + } -/** + /** * @brief Emulate per-byte right shifts on 256-bit vectors. */ __always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi16(0x00ff); - const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); - const __m256i high_half = _mm256_srlv_epi16(a, _mm256_srli_epi16(count, 8)); - return mm256_blend_epi8(low_half, high_half, 0xaa); -} + const __m256i mask = _mm256_set1_epi16(0x00ff); + const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); + const __m256i high_half = _mm256_srlv_epi16(a, _mm256_srli_epi16(count, 8)); + return mm256_blend_epi8(low_half, high_half, 0xaa); + } #endif -/** + /** * @brief Pack four 32-bit values for bit widths 1 through 3. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3) -__always_inline auto mm_pack_epi32_avx2_1to3(__m128i& input) -> __m128i { - constexpr uint32_t bitmask = (1U << BIT_WIDTH) - 1U; - const __m128i masked = _mm_and_si128(input, _mm_set1_epi32(static_cast(bitmask))); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3) + __always_inline auto mm_pack_epi32_avx2_1to3(__m128i &input) -> __m128i { + constexpr u32 bitmask = (1U << BIT_WIDTH) - 1U; + const __m128i masked = _mm_and_si128(input, _mm_set1_epi32(static_cast(bitmask))); - alignas(16) uint32_t lanes[4]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(lanes), masked); + alignas(16) u32 lanes[4]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(lanes), masked); - const uint32_t packed = (lanes[0] & bitmask) | ((lanes[1] & bitmask) << BIT_WIDTH) | ( - (lanes[2] & bitmask) << (2 * BIT_WIDTH)) | - ((lanes[3] & bitmask) << (3 * BIT_WIDTH)); + const u32 packed = (lanes[0] & bitmask) | ((lanes[1] & bitmask) << BIT_WIDTH) | ( + (lanes[2] & bitmask) << (2 * BIT_WIDTH)) | + ((lanes[3] & bitmask) << (3 * BIT_WIDTH)); - return _mm_cvtsi32_si128(static_cast(packed)); -} + return _mm_cvtsi32_si128(static_cast(packed)); + } -/** + /** * @brief Pack eight 32-bit values for bit widths 1 through 3. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3) -__always_inline __m256i mm256_pack_epi32_avx2_1to3(const __m256i& input) { - constexpr uint32_t bitmask = (1u << BIT_WIDTH) - 1u; + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3) +__always_inline __m256i mm256_pack_epi32_avx2_1to3(const __m256i &input) { + constexpr u32 bitmask = (1u << BIT_WIDTH) - 1u; - const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(static_cast(bitmask))); + const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(static_cast(bitmask))); - const __m256i shifts = _mm256_setr_epi32(0 * BIT_WIDTH, 1 * BIT_WIDTH, 2 * BIT_WIDTH, 3 * BIT_WIDTH, - 4 * BIT_WIDTH, 5 * BIT_WIDTH, - 6 * BIT_WIDTH, 7 * BIT_WIDTH); + const __m256i shifts = _mm256_setr_epi32(0 * BIT_WIDTH, 1 * BIT_WIDTH, 2 * BIT_WIDTH, 3 * BIT_WIDTH, + 4 * BIT_WIDTH, 5 * BIT_WIDTH, + 6 * BIT_WIDTH, 7 * BIT_WIDTH); - const __m256i shifted = _mm256_sllv_epi32(masked, shifts); + const __m256i shifted = _mm256_sllv_epi32(masked, shifts); - __m128i x = _mm_or_si128(_mm256_castsi256_si128(shifted), _mm256_extracti128_si256(shifted, 1)); + __m128i x = _mm_or_si128(_mm256_castsi256_si128(shifted), _mm256_extracti128_si256(shifted, 1)); - x = _mm_or_si128(x, _mm_srli_si128(x, 8)); - x = _mm_or_si128(x, _mm_srli_si128(x, 4)); + x = _mm_or_si128(x, _mm_srli_si128(x, 8)); + x = _mm_or_si128(x, _mm_srli_si128(x, 4)); - return _mm256_castsi128_si256(x); -} + return _mm256_castsi128_si256(x); + } -__always_inline __m256i mm256_pack_epi32_avx2_4(const __m256i& input) { - const __m256i zero = _mm256_setzero_si256(); +__always_inline __m256i mm256_pack_epi32_avx2_4(const __m256i &input) { + const __m256i zero = _mm256_setzero_si256(); - const __m256i packed16 = _mm256_packus_epi32(input, zero); - const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); - const __m256i packed8 = _mm256_packus_epi16(permuted, zero); + const __m256i packed16 = _mm256_packus_epi32(input, zero); + const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i packed8 = _mm256_packus_epi16(permuted, zero); - const __m256i combined = _mm256_or_si256(packed8, _mm256_srli_epi16(packed8, 4)); + const __m256i combined = _mm256_or_si256(packed8, _mm256_srli_epi16(packed8, 4)); - const __m256i shuffled = _mm256_shuffle_epi8(combined, _mm256_setr_epi8( - 0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1, - 0, 2, - 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1)); + const __m256i shuffled = _mm256_shuffle_epi8(combined, _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1, + 0, 2, + 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1)); - return shuffled; -} + return shuffled; + } -/** + /** * @brief Pack four 32-bit values for bit widths 9 through 16. */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline auto mm_pack_epi32_avx2_9to16(__m128i& input) -> __m128i { - using tables = pack_tables_avx2_16; - constexpr uint16_t bitmask = (1 << BIT_WIDTH) - 1; - const __m128i masked = _mm_and_si128(input, _mm_set1_epi16(bitmask)); - __m128i combined; - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); - const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); - - const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); - const __m128i shifted2 = _mm_srlv_epi16(shuffled2, tables::get_shift2()); - - combined = _mm_or_si128(shifted1, shifted2); - } else { - const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); - const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); - const __m128i shuffled3 = _mm_shuffle_epi8(masked, tables::get_permute3()); - - const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi16(shuffled2, tables::get_shift2()); - const __m128i shifted3 = _mm_srlv_epi16(shuffled3, tables::get_shift3()); - - combined = _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); - } - return combined; -} + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) + __always_inline auto mm_pack_epi32_avx2_9to16(__m128i &input) -> __m128i { + using tables = pack_tables_avx2_16; + constexpr u16 bitmask = (1 << BIT_WIDTH) - 1; + const __m128i masked = _mm_and_si128(input, _mm_set1_epi16(bitmask)); + __m128i combined; + + if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); + const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); + + const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); + const __m128i shifted2 = _mm_srlv_epi16(shuffled2, tables::get_shift2()); + + combined = _mm_or_si128(shifted1, shifted2); + } else { + const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); + const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); + const __m128i shuffled3 = _mm_shuffle_epi8(masked, tables::get_permute3()); + + const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); + const __m128i shifted2 = _mm_sllv_epi16(shuffled2, tables::get_shift2()); + const __m128i shifted3 = _mm_srlv_epi16(shuffled3, tables::get_shift3()); + + combined = _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); + } + return combined; + } -/** + /** * @brief Pack eight 32-bit values for bit widths 9 through 16. */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline auto mm256_pack_epi32_avx2_9to16(const __m256i& input) -> __m256i { - using tables = pack_tables_avx2_16; - constexpr uint16_t bitmask = (1 << BIT_WIDTH) - 1; - const __m128i packed = _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1)); - const __m128i masked = _mm_and_si128(packed, _mm_set1_epi16(bitmask)); - __m128i combined; - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15 || BIT_WIDTH == 16) { - const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); - const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); - - const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); - const __m128i shifted2 = _mm_srlv_epi16(shuffled2, tables::get_shift2()); - - combined = _mm_or_si128(shifted1, shifted2); - } else { - const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); - const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); - const __m128i shuffled3 = _mm_shuffle_epi8(masked, tables::get_permute3()); - - const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi16(shuffled2, tables::get_shift2()); - const __m128i shifted3 = _mm_srlv_epi16(shuffled3, tables::get_shift3()); - - combined = _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); - } - return _mm256_castsi128_si256(combined); -} + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) + __always_inline auto mm256_pack_epi32_avx2_9to16(const __m256i &input) -> __m256i { + using tables = pack_tables_avx2_16; + constexpr u16 bitmask = (1 << BIT_WIDTH) - 1; + const __m128i packed = _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1)); + const __m128i masked = _mm_and_si128(packed, _mm_set1_epi16(bitmask)); + __m128i combined; + + if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15 || BIT_WIDTH == 16) { + const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); + const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); + + const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); + const __m128i shifted2 = _mm_srlv_epi16(shuffled2, tables::get_shift2()); + + combined = _mm_or_si128(shifted1, shifted2); + } else { + const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); + const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); + const __m128i shuffled3 = _mm_shuffle_epi8(masked, tables::get_permute3()); + + const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); + const __m128i shifted2 = _mm_sllv_epi16(shuffled2, tables::get_shift2()); + const __m128i shifted3 = _mm_srlv_epi16(shuffled3, tables::get_shift3()); + + combined = _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); + } + return _mm256_castsi128_si256(combined); + } -template - requires(BIT_WIDTH >= 5 && BIT_WIDTH <= 7) -__always_inline auto mm256_pack_epi32_avx2_5to7(const __m256i& input) -> __m256i { - const __m256i zero = _mm256_setzero_si256(); + template + requires(BIT_WIDTH >= 5 && BIT_WIDTH <= 7) + __always_inline auto mm256_pack_epi32_avx2_5to7(const __m256i &input) -> __m256i { + const __m256i zero = _mm256_setzero_si256(); - const __m256i packed16 = _mm256_packus_epi32(input, zero); - const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); - const __m256i packed8 = _mm256_packus_epi16(permuted, zero); + const __m256i packed16 = _mm256_packus_epi32(input, zero); + const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i packed8 = _mm256_packus_epi16(permuted, zero); - const __m256i even = _mm256_and_si256(packed8, _mm256_set1_epi16(0x00FF)); - const __m256i odd = _mm256_and_si256(packed8, _mm256_set1_epi16(0xFF00)); + const __m256i even = _mm256_and_si256(packed8, _mm256_set1_epi16(0x00FF)); + const __m256i odd = _mm256_and_si256(packed8, _mm256_set1_epi16(0xFF00)); - const __m256i pair16 = _mm256_or_si256(even, _mm256_srli_epi16(odd, 8 - BIT_WIDTH)); - const __m256i extended = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(pair16)); + const __m256i pair16 = _mm256_or_si256(even, _mm256_srli_epi16(odd, 8 - BIT_WIDTH)); + const __m256i extended = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(pair16)); - return mm256_pack_epi32_avx2_9to16<2 * BIT_WIDTH>(extended); -} + return mm256_pack_epi32_avx2_9to16<2 * BIT_WIDTH>(extended); + } -/** + /** * @brief Pack eight 32-bit values for bit widths 17 through 24. */ -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline auto mm256_pack_epi32_avx2_17to24(const __m256i& input) -> __m256i { - using tables = pack_tables_avx2_24; - constexpr uint32_t bitmask = (1 << BIT_WIDTH) - 1; - const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(bitmask)); - __m256i combined; - - if constexpr (BIT_WIDTH == 24) { - const __m256i shuffled1 = _mm256_permutevar8x32_epi32(masked, tables::get_permute1()); - const __m256i shuffled2 = _mm256_permutevar8x32_epi32(masked, tables::get_permute2()); - - const __m256i shifted1 = _mm256_sllv_epi32(shuffled1, tables::get_shift1()); - const __m256i shifted2 = _mm256_srlv_epi32(shuffled2, tables::get_shift2()); - - combined = _mm256_or_si256(shifted1, shifted2); - } else { - const __m256i shuffled1 = _mm256_permutevar8x32_epi32(masked, tables::get_permute1()); - const __m256i shuffled2 = _mm256_permutevar8x32_epi32(masked, tables::get_permute2()); - const __m256i shuffled3 = _mm256_permutevar8x32_epi32(masked, tables::get_permute3()); - - const __m256i shifted1 = _mm256_sllv_epi32(shuffled1, tables::get_shift1()); - const __m256i shifted2 = _mm256_sllv_epi32(shuffled2, tables::get_shift2()); - const __m256i shifted3 = _mm256_srlv_epi32(shuffled3, tables::get_shift3()); - - combined = _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); - } - - return combined; -} + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) + __always_inline auto mm256_pack_epi32_avx2_17to24(const __m256i &input) -> __m256i { + using tables = pack_tables_avx2_24; + constexpr u32 bitmask = (1 << BIT_WIDTH) - 1; + const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(bitmask)); + __m256i combined; + + if constexpr (BIT_WIDTH == 24) { + const __m256i shuffled1 = _mm256_permutevar8x32_epi32(masked, tables::get_permute1()); + const __m256i shuffled2 = _mm256_permutevar8x32_epi32(masked, tables::get_permute2()); + + const __m256i shifted1 = _mm256_sllv_epi32(shuffled1, tables::get_shift1()); + const __m256i shifted2 = _mm256_srlv_epi32(shuffled2, tables::get_shift2()); + + combined = _mm256_or_si256(shifted1, shifted2); + } else { + const __m256i shuffled1 = _mm256_permutevar8x32_epi32(masked, tables::get_permute1()); + const __m256i shuffled2 = _mm256_permutevar8x32_epi32(masked, tables::get_permute2()); + const __m256i shuffled3 = _mm256_permutevar8x32_epi32(masked, tables::get_permute3()); + + const __m256i shifted1 = _mm256_sllv_epi32(shuffled1, tables::get_shift1()); + const __m256i shifted2 = _mm256_sllv_epi32(shuffled2, tables::get_shift2()); + const __m256i shifted3 = _mm256_srlv_epi32(shuffled3, tables::get_shift3()); + + combined = _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); + } + + return combined; + } -/** + /** * @brief Pack aligned 8-bit or 16-bit values from four 32-bit lanes. */ -template - requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) -auto mm_pack_aligned_epi32_avx2(__m128i& input) -> __m128i { - if constexpr (BIT_WIDTH == 8) { - return _mm_packus_epi16(_mm_packs_epi32(input, _mm_setzero_si128()), _mm_setzero_si128()); - } else { - return _mm_packs_epi32(input, _mm_setzero_si128()); - } -} + template + requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) + auto mm_pack_aligned_epi32_avx2(__m128i &input) -> __m128i { + if constexpr (BIT_WIDTH == 8) { + return _mm_packus_epi16(_mm_packs_epi32(input, _mm_setzero_si128()), _mm_setzero_si128()); + } else { + return _mm_packs_epi32(input, _mm_setzero_si128()); + } + } -/** + /** * @brief Dispatch to the appropriate 128-bit AVX2 packer for the selected bit width. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 16) -auto mm_pack_epi32_avx2(__m128i& input) -> __m128i { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 3) { - return internal::mm_pack_epi32_avx2_1to3(input); - } else if constexpr (BIT_WIDTH >= 4 && BIT_WIDTH <= 8) { - // TODO: implementation for 4-8 bits - return _mm_setzero_si128(); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::mm_pack_epi32_avx2_9to16(input); - } else { - return _mm_setzero_si128(); - } -} + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 16) + auto mm_pack_epi32_avx2(__m128i &input) -> __m128i { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 3) { + return internal::mm_pack_epi32_avx2_1to3(input); + } else if constexpr (BIT_WIDTH >= 4 && BIT_WIDTH <= 8) { + // TODO: implementation for 4-8 bits + return _mm_setzero_si128(); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::mm_pack_epi32_avx2_9to16(input); + } else { + return _mm_setzero_si128(); + } + } -/** + /** * @brief Pack aligned 8-bit or 16-bit values from eight 32-bit lanes. */ -template - requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) -__m256i mm256_pack_aligned_epi32_avx2(const __m256i& input) { - if constexpr (BIT_WIDTH == 8) { - const __m128i packed16 = _mm_packs_epi32(_mm256_castsi256_si128(input), - _mm256_extracti128_si256(input, 1)); - const __m128i packed8 = _mm_packs_epi16(packed16, _mm_setzero_si128()); - return _mm256_castsi128_si256(packed8); - } else { - return _mm256_castsi128_si256( - _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1))); - } -} + template + requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) + __m256i mm256_pack_aligned_epi32_avx2(const __m256i &input) { + if constexpr (BIT_WIDTH == 8) { + const __m128i packed16 = _mm_packs_epi32(_mm256_castsi256_si128(input), + _mm256_extracti128_si256(input, 1)); + const __m128i packed8 = _mm_packs_epi16(packed16, _mm_setzero_si128()); + return _mm256_castsi128_si256(packed8); + } else { + return _mm256_castsi128_si256( + _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1))); + } + } -/** + /** * @brief Dispatch to the appropriate 256-bit AVX2 packer for the selected bit width. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__m256i mm256_pack_epi32_avx2(const __m256i& input) { - if constexpr (BIT_WIDTH == 8 || BIT_WIDTH == 16) { - return internal::mm256_pack_aligned_epi32_avx2(input); - } else { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 3) { - return internal::mm256_pack_epi32_avx2_1to3(input); - } else if constexpr (BIT_WIDTH == 4) { - return mm256_pack_epi32_avx2_4(input); - } else if constexpr (BIT_WIDTH >= 5 && BIT_WIDTH <= 7) { - return internal::mm256_pack_epi32_avx2_5to7(input); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 15) { - return internal::mm256_pack_epi32_avx2_9to16(input); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::mm256_pack_epi32_avx2_17to24(input); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __m256i mm256_pack_epi32_avx2(const __m256i &input) { + if constexpr (BIT_WIDTH == 8 || BIT_WIDTH == 16) { + return internal::mm256_pack_aligned_epi32_avx2(input); + } else { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 3) { + return internal::mm256_pack_epi32_avx2_1to3(input); + } else if constexpr (BIT_WIDTH == 4) { + return mm256_pack_epi32_avx2_4(input); + } else if constexpr (BIT_WIDTH >= 5 && BIT_WIDTH <= 7) { + return internal::mm256_pack_epi32_avx2_5to7(input); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 15) { + return internal::mm256_pack_epi32_avx2_9to16(input); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::mm256_pack_epi32_avx2_17to24(input); + } + } + return _mm256_setzero_si256(); } - } - return _mm256_setzero_si256(); -} -} // namespace internal + } // namespace internal -/** + /** * @brief Compress a single block of float using AVX2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -436,48 +436,48 @@ __m256i mm256_pack_epi32_avx2(const __m256i& input) { * * @note This function requires AVX2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_avx2(const void* __restrict__ input_ptr, const float_t scale, - void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_compress_block_avx2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; - std::memset(output, 0, BLOCK_SIZE); + std::memset(output, 0, BLOCK_SIZE); - const __m256 scale_v = _mm256_set1_ps(scale); + const __m256 scale_v = _mm256_set1_ps(scale); #pragma GCC unroll 8 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256 source = _mm256_loadu_ps(input); - const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v); - const __m256i packed_input = internal::mm256_clamp_signed_epi32(quantized); - const __m256i packed = internal::mm256_pack_epi32_avx2(packed_input); - std::memcpy(output, &packed, BIT_WIDTH); - - input += 8; - output += BIT_WIDTH; - } + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256 source = _mm256_loadu_ps(input); + const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v); + const __m256i packed_input = internal::mm256_clamp_signed_epi32(quantized); + const __m256i packed = internal::mm256_pack_epi32_avx2(packed_input); + std::memcpy(output, &packed, BIT_WIDTH); + + input += 8; + output += BIT_WIDTH; + } - if constexpr (remaining) { - std::vector block_values(remaining); + if constexpr (remaining) { + std::vector block_values(remaining); #pragma GCC unroll 8 - for (uint32_t i = 0; i < remaining; i++) { - block_values[i] = - static_cast(internal::clamp_signed_quantized( - internal::quantize_ps_epi32(input[i], scale))); + for (u32 i = 0; i < remaining; i++) { + block_values[i] = + static_cast(internal::clamp_signed_quantized < BIT_WIDTH > ( + internal::quantize_ps_epi32(input[i], scale))); + } + + internal::pack_epi32_fallback < BIT_WIDTH > (block_values, output); } - internal::pack_epi32_fallback(block_values, output); + return 0; } - return 0; -} - -/** + /** * @brief Compress a single block of double using AVX2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -490,52 +490,52 @@ int mm256_compress_block_avx2(const void* __restrict__ input_ptr, const float_t * * @note This function requires AVX2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_avx2(const void* __restrict__ input_ptr, const double_t scale, - void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_compress_block_avx2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; - std::memset(output, 0, BLOCK_SIZE); + std::memset(output, 0, BLOCK_SIZE); - const __m256d scale_v = _mm256_set1_pd(scale); + const __m256d scale_v = _mm256_set1_pd(scale); #pragma GCC unroll 8 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256d source1 = _mm256_loadu_pd(input); - const __m256d source2 = _mm256_loadu_pd(input + 4); - const __m128i quantized1 = internal::mm256_quantize_pd_epi32(source1, scale_v); - const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v); - __m256i combined = _mm256_castsi128_si256(quantized1); - combined = _mm256_inserti128_si256(combined, quantized2, 1); - const __m256i packed = internal::mm256_pack_epi32_avx2( - internal::mm256_clamp_signed_epi32(combined)); - // _mm_storeu_si128(reinterpret_cast<__m128i*>(output), _mm256_castsi256_si128(packed)); - std::memcpy(output, &packed, BIT_WIDTH); - input += 8; - output += BIT_WIDTH; - } + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256d source1 = _mm256_loadu_pd(input); + const __m256d source2 = _mm256_loadu_pd(input + 4); + const __m128i quantized1 = internal::mm256_quantize_pd_epi32(source1, scale_v); + const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v); + __m256i combined = _mm256_castsi128_si256(quantized1); + combined = _mm256_inserti128_si256(combined, quantized2, 1); + const __m256i packed = internal::mm256_pack_epi32_avx2( + internal::mm256_clamp_signed_epi32(combined)); + // _mm_storeu_si128(reinterpret_cast<__m128i*>(output), _mm256_castsi256_si128(packed)); + std::memcpy(output, &packed, BIT_WIDTH); + input += 8; + output += BIT_WIDTH; + } - if constexpr (remaining) { - std::vector block_values(remaining); + if constexpr (remaining) { + std::vector block_values(remaining); #pragma GCC unroll 8 - for (uint32_t i = 0; i < remaining; i++) { - block_values[i] = - static_cast(internal::clamp_signed_quantized( - internal::quantize_pd_epi64(input[i], scale))); + for (u32 i = 0; i < remaining; i++) { + block_values[i] = + static_cast(internal::clamp_signed_quantized < BIT_WIDTH > ( + internal::quantize_pd_epi64(input[i], scale))); + } + + internal::pack_epi32_fallback < BIT_WIDTH > (block_values, output); } - internal::pack_epi32_fallback(block_values, output); + return 0; } - return 0; -} - -/** + /** * @brief Compress multiple blocks using AVX2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -549,27 +549,27 @@ int mm256_compress_block_avx2(const void* __restrict__ input_ptr, const double_t * * @note This function requires AVX2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_avx2(const void* __restrict__ input_ptr, const float_t scale, - void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const float_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_compress_block_avx2(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_compress_blocks_avx2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f32 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_compress_block_avx2(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } - return 0; -} + return 0; + } -/** + /** * @brief Compress multiple blocks using AVX2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -583,25 +583,25 @@ int mm256_compress_blocks_avx2(const void* __restrict__ input_ptr, const float_t * * @note This function requires AVX2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_avx2(const void* __restrict__ input_ptr, const double_t scale, - void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const double_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_compress_block_avx2(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_compress_blocks_avx2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f64 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_compress_block_avx2(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } - return 0; -} + return 0; + } } // namespace pernix #endif // PERNIX_AVX2_COMPRESSION_H diff --git a/src/internal/pernix/x86/avx2/avx2_decompression.h b/src/internal/pernix/x86/avx2/avx2_decompression.h index 7c530d8..2f50018 100644 --- a/src/internal/pernix/x86/avx2/avx2_decompression.h +++ b/src/internal/pernix/x86/avx2/avx2_decompression.h @@ -11,173 +11,176 @@ #include namespace pernix { -namespace internal { -/** + namespace internal { + /** * @brief Convert an 8-lane mask to the lane representation used by AVX2 float masked stores. */ __always_inline __m256i mm256_convert_vmask_epi32(const __mmask8 mask8) { - return _mm256_setr_epi32((mask8 & 0x1) ? -1 : 0, (mask8 & 0x2) ? -1 : 0, (mask8 & 0x4) ? -1 : 0, (mask8 & 0x8) ? -1 : 0, - (mask8 & 0x10) ? -1 : 0, (mask8 & 0x20) ? -1 : 0, (mask8 & 0x40) ? -1 : 0, (mask8 & 0x80) ? -1 : 0); -} + return _mm256_setr_epi32((mask8 & 0x1) ? -1 : 0, (mask8 & 0x2) ? -1 : 0, (mask8 & 0x4) ? -1 : 0, + (mask8 & 0x8) ? -1 : 0, + (mask8 & 0x10) ? -1 : 0, (mask8 & 0x20) ? -1 : 0, (mask8 & 0x40) ? -1 : 0, + (mask8 & 0x80) ? -1 : 0); + } -/** + /** * @brief Convert a 4-lane mask to the lane representation used by AVX2 double masked stores. */ __always_inline __m256i mm256_convert_vmask_epi64(const __mmask8 mask8) { - return _mm256_setr_epi64x((mask8 & 0x1) ? -1 : 0, (mask8 & 0x2) ? -1 : 0, (mask8 & 0x4) ? -1 : 0, (mask8 & 0x8) ? -1 : 0); -} + return _mm256_setr_epi64x((mask8 & 0x1) ? -1 : 0, (mask8 & 0x2) ? -1 : 0, (mask8 & 0x4) ? -1 : 0, + (mask8 & 0x8) ? -1 : 0); + } -/** + /** * @brief Dequantize four 32-bit integers to floats. */ -__always_inline __m128 mm_dequantize_epi32(const __m128i& input, const __m128& scale) { - const __m128 converted = _mm_cvtepi32_ps(input); - return _mm_mul_ps(converted, scale); -} +__always_inline __m128 mm_dequantize_epi32(const __m128i &input, const __m128 &scale) { + const __m128 converted = _mm_cvtepi32_ps(input); + return _mm_mul_ps(converted, scale); + } -/* https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx */ + /* https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx */ __always_inline __m128d convert_epi64_pd(const __m128i v) { - __m128i xH = _mm_srai_epi32(v, 16); - xH = _mm_blend_epi16(xH, _mm_setzero_si128(), 0x33); - xH = _mm_add_epi64(xH, _mm_castpd_si128(_mm_set1_pd(442721857769029238784.))); // 3*2^67 - const __m128i xL = _mm_blend_epi16(v, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0x88); // 2^52 - const __m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(442726361368656609280.)); // 3*2^67 + 2^52 - return _mm_add_pd(f, _mm_castsi128_pd(xL)); -} - -/** + __m128i xH = _mm_srai_epi32(v, 16); + xH = _mm_blend_epi16(xH, _mm_setzero_si128(), 0x33); + xH = _mm_add_epi64(xH, _mm_castpd_si128(_mm_set1_pd(442721857769029238784.))); // 3*2^67 + const __m128i xL = _mm_blend_epi16(v, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0x88); // 2^52 + const __m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(442726361368656609280.)); // 3*2^67 + 2^52 + return _mm_add_pd(f, _mm_castsi128_pd(xL)); + } + + /** * @brief Dequantize two 64-bit integers to doubles. */ -__always_inline __m128d mm_dequantize_epi64_pd(const __m128i& input, const __m128d& scale) { - const __m128d converted = convert_epi64_pd(input); - return _mm_mul_pd(converted, scale); -} +__always_inline __m128d mm_dequantize_epi64_pd(const __m128i &input, const __m128d &scale) { + const __m128d converted = convert_epi64_pd(input); + return _mm_mul_pd(converted, scale); + } -/** + /** * @brief Dequantize eight 32-bit integers to floats. */ -__always_inline __m256 mm256_dequantize_epi32(const __m256i& input, const __m256& scale) { - const __m256 converted = _mm256_cvtepi32_ps(input); - return _mm256_mul_ps(converted, scale); -} +__always_inline __m256 mm256_dequantize_epi32(const __m256i &input, const __m256 &scale) { + const __m256 converted = _mm256_cvtepi32_ps(input); + return _mm256_mul_ps(converted, scale); + } -/* https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx */ + /* https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx */ __always_inline __m256d convert_epi64_pd(__m256i v) { - __m256i xH = _mm256_srai_epi32(v, 16); - xH = _mm256_blend_epi16(xH, _mm256_setzero_si256(), 0x33); - xH = _mm256_add_epi64(xH, _mm256_castpd_si256(_mm256_set1_pd(442721857769029238784.))); - const __m256i xL = _mm256_blend_epi16(v, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)), 0x88); - const __m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(442726361368656609280.)); - return _mm256_add_pd(f, _mm256_castsi256_pd(xL)); -} - -/** + __m256i xH = _mm256_srai_epi32(v, 16); + xH = _mm256_blend_epi16(xH, _mm256_setzero_si256(), 0x33); + xH = _mm256_add_epi64(xH, _mm256_castpd_si256(_mm256_set1_pd(442721857769029238784.))); + const __m256i xL = _mm256_blend_epi16(v, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)), 0x88); + const __m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(442726361368656609280.)); + return _mm256_add_pd(f, _mm256_castsi256_pd(xL)); + } + + /** * @brief Dequantize four 64-bit integers to doubles. */ -__always_inline __m256d mm256_dequantize_epi64_pd(const __m256i& input, const __m256d& scale) { - const __m256d converted = convert_epi64_pd(input); - return _mm256_mul_pd(converted, scale); -} +__always_inline __m256d mm256_dequantize_epi64_pd(const __m256i &input, const __m256d &scale) { + const __m256d converted = convert_epi64_pd(input); + return _mm256_mul_pd(converted, scale); + } -/** + /** * @brief Unpack four aligned 8-bit or 16-bit values directly from the input buffer. */ -template - requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) -__m128i mm_unpack_aligned_epi32_avx2(const uint8_t* __restrict__ input) { - if constexpr (BIT_WIDTH == 8) { - const __m128i source = _mm_loadu_si32(input); - if constexpr (SIGN_VALUES) { - return _mm_cvtepi8_epi32(source); - } else { - return _mm_cvtepu8_epi32(source); - } - } else { - const __m128i source = _mm_loadu_si64(input); - if constexpr (SIGN_VALUES) { - return _mm_cvtepi16_epi32(source); - } else { - return _mm_cvtepu16_epi32(source); + template + requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) + __m128i mm_unpack_aligned_epi32_avx2(const u8 * __restrict__ input) { + if constexpr (BIT_WIDTH == 8) { + const __m128i source = _mm_loadu_si32(input); + if constexpr (SIGN_VALUES) { + return _mm_cvtepi8_epi32(source); + } else { + return _mm_cvtepu8_epi32(source); + } + } else { + const __m128i source = _mm_loadu_si64(input); + if constexpr (SIGN_VALUES) { + return _mm_cvtepi16_epi32(source); + } else { + return _mm_cvtepu16_epi32(source); + } + } } - } -} -/** + /** * @brief Unpack four values using the table-driven AVX2 shuffle path. */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -__m128i mm_unpack_epi32_avx2(const uint8_t* __restrict__ input) { - using unpack_table = unpack_tables_avx2; - constexpr std::size_t packed_bytes = (4 * BIT_WIDTH + 7) / 8; - - __m128i source = _mm_setzero_si128(); - std::memcpy(&source, input, packed_bytes); - - const __m128i shuffled = _mm_shuffle_epi8(source, unpack_table::get_shuffle()); - - constexpr uint16_t shift = 32 - BIT_WIDTH; - __m128i shifted = _mm_sllv_epi32(shuffled, unpack_table::get_shift()); - if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { - shifted = _mm_srai_epi32(shifted, shift); - } else { - shifted = _mm_srli_epi32(shifted, shift); - } - - return shifted; -} + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) + __m128i mm_unpack_epi32_avx2(const u8 * __restrict__ input) { + using unpack_table = unpack_tables_avx2; + constexpr std::size_t packed_bytes = (4 * BIT_WIDTH + 7) / 8; + + __m128i source = _mm_setzero_si128(); + std::memcpy(&source, input, packed_bytes); + + const __m128i shuffled = _mm_shuffle_epi8(source, unpack_table::get_shuffle()); + + constexpr u16 shift = 32 - BIT_WIDTH; + __m128i shifted = _mm_sllv_epi32(shuffled, unpack_table::get_shift()); + if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { + shifted = _mm_srai_epi32(shifted, shift); + } else { + shifted = _mm_srli_epi32(shifted, shift); + } + + return shifted; + } -/** + /** * @brief Unpack eight aligned 8-bit or 16-bit values directly from the input buffer. */ -template - requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) -__m256i mm256_unpack_aligned_epi32_avx2(const uint8_t* __restrict__ input) { - if constexpr (BIT_WIDTH == 8) { - const __m128i source = _mm_loadu_si64(input); - if constexpr (SIGN_VALUES) { - return _mm256_cvtepi8_epi32(source); - } else { - return _mm256_cvtepu8_epi32(source); + template + requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) + __m256i mm256_unpack_aligned_epi32_avx2(const u8 * __restrict__ input) { + if constexpr (BIT_WIDTH == 8) { + const __m128i source = _mm_loadu_si64(input); + if constexpr (SIGN_VALUES) { + return _mm256_cvtepi8_epi32(source); + } else { + return _mm256_cvtepu8_epi32(source); + } + } else { + const __m128i source = _mm_loadu_si128(reinterpret_cast(input)); + if constexpr (SIGN_VALUES) { + return _mm256_cvtepi16_epi32(source); + } else { + return _mm256_cvtepu16_epi32(source); + } + } } - } else { - const __m128i source = _mm_loadu_si128(reinterpret_cast(input)); - if constexpr (SIGN_VALUES) { - return _mm256_cvtepi16_epi32(source); - } else { - return _mm256_cvtepu16_epi32(source); - } - } -} -/** + /** * @brief Unpack eight values using the table-driven AVX2 shuffle path. */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -__m256i mm256_unpack_epi32_avx2(const uint8_t* __restrict__ input) { - using unpack_table = unpack_tables_avx2; - constexpr std::size_t packed_bytes = BIT_WIDTH; - - __m256i source = _mm256_setzero_si256(); - std::memcpy(&source, input, packed_bytes); - - const __m256i permuted = _mm256_permutevar8x32_epi32(source, unpack_table::get_permute()); - const __m256i shuffled = _mm256_shuffle_epi8(permuted, unpack_table::get_shuffle()); - - constexpr uint16_t shift = 32 - BIT_WIDTH; - __m256i shifted = _mm256_sllv_epi32(shuffled, unpack_table::get_shift()); - if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { - shifted = _mm256_srai_epi32(shifted, shift); - } else { - shifted = _mm256_srli_epi32(shifted, shift); - } - - return shifted; -} -} // namespace internal + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) + __m256i mm256_unpack_epi32_avx2(const u8 * __restrict__ input) { + using unpack_table = unpack_tables_avx2; + constexpr std::size_t packed_bytes = BIT_WIDTH; + + __m256i source = _mm256_setzero_si256(); + std::memcpy(&source, input, packed_bytes); + + const __m256i permuted = _mm256_permutevar8x32_epi32(source, unpack_table::get_permute()); + const __m256i shuffled = _mm256_shuffle_epi8(permuted, unpack_table::get_shuffle()); + + constexpr u16 shift = 32 - BIT_WIDTH; + __m256i shifted = _mm256_sllv_epi32(shuffled, unpack_table::get_shift()); + if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { + shifted = _mm256_srai_epi32(shifted, shift); + } else { + shifted = _mm256_srli_epi32(shifted, shift); + } + + return shifted; + } + } // namespace internal -/** + /** * @brief Decompress a single block to float using AVX2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -190,37 +193,40 @@ __m256i mm256_unpack_epi32_avx2(const uint8_t* __restrict__ input) { * * @note This function requires AVX2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_avx2(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - - const __m256 scale_v = _mm256_set1_ps(scale); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_block_avx2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; + + const __m256 scale_v = _mm256_set1_ps(scale); #pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256i unpacked = internal::mm256_unpack_epi32_avx2(input); - const __m256 dequantized = internal::mm256_dequantize_epi32(unpacked, scale_v); - _mm256_storeu_ps(output, dequantized); - input += BIT_WIDTH; - output += 8; - } + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256i unpacked = internal::mm256_unpack_epi32_avx2(input); + const __m256 dequantized = internal::mm256_dequantize_epi32(unpacked, scale_v); + _mm256_storeu_ps(output, dequantized); + input += BIT_WIDTH; + output += 8; + } - if constexpr (remaining > 0) { - const std::vector tail_values = internal::unpack_epi32_fallback(input, remaining); - for (uint32_t i = 0; i < remaining; i++) { - output[i] = internal::dequantize_epi32(tail_values[i], scale); + if constexpr (remaining > 0) { + const std::vector tail_values = internal::unpack_epi32_fallback < BIT_WIDTH, SIGN_VALUES + > + (input, remaining); + for (u32 i = 0; i < remaining; i++) { + output[i] = internal::dequantize_epi32(tail_values[i], scale); + } } - } - return 0; -} + return 0; + } -/** + /** * @brief Decompress a single block to double using AVX2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -233,42 +239,45 @@ int mm256_decompress_block_avx2(const void* __restrict__ input_ptr, const float_ * * @note This function requires AVX2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_avx2(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - const __m256d scale_v = _mm256_set1_pd(scale); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_block_avx2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; + const __m256d scale_v = _mm256_set1_pd(scale); #pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256i unpacked = internal::mm256_unpack_epi32_avx2(input); - const __m256i extend1 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(unpacked)); - const __m256i extend2 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(unpacked, 1)); + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256i unpacked = internal::mm256_unpack_epi32_avx2(input); + const __m256i extend1 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(unpacked)); + const __m256i extend2 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(unpacked, 1)); - const __m256d dequantized1 = internal::mm256_dequantize_epi64_pd(extend1, scale_v); - const __m256d dequantized2 = internal::mm256_dequantize_epi64_pd(extend2, scale_v); + const __m256d dequantized1 = internal::mm256_dequantize_epi64_pd(extend1, scale_v); + const __m256d dequantized2 = internal::mm256_dequantize_epi64_pd(extend2, scale_v); - _mm256_storeu_pd(output, dequantized1); - _mm256_storeu_pd(output + 4, dequantized2); + _mm256_storeu_pd(output, dequantized1); + _mm256_storeu_pd(output + 4, dequantized2); - input += BIT_WIDTH; - output += 8; - } + input += BIT_WIDTH; + output += 8; + } - if constexpr (remaining > 0) { - const std::vector tail_values = internal::unpack_epi32_fallback(input, remaining); - for (uint32_t i = 0; i < remaining; i++) { - output[i] = internal::dequantize_epi64(tail_values[i], scale); + if constexpr (remaining > 0) { + const std::vector tail_values = internal::unpack_epi32_fallback < BIT_WIDTH, SIGN_VALUES + > + (input, remaining); + for (u32 i = 0; i < remaining; i++) { + output[i] = internal::dequantize_epi64(tail_values[i], scale); + } } + return 0; } - return 0; -} -/** + /** * @brief Decompress multiple blocks to float using AVX2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -282,26 +291,27 @@ int mm256_decompress_block_avx2(const void* __restrict__ input_ptr, const double * * @note This function requires AVX2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_avx2(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_decompress_block_avx2(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_blocks_avx2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f32 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_decompress_block_avx2(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } - return 0; -} + return 0; + } -/** + /** * @brief Decompress multiple blocks to double using AVX2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -315,24 +325,25 @@ int mm256_decompress_blocks_avx2(const void* __restrict__ input_ptr, const float * * @note This function requires AVX2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_avx2(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_decompress_block_avx2(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_blocks_avx2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f64 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_decompress_block_avx2(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } - return 0; -} + return 0; + } } // namespace pernix #endif // PERNIX_AVX2_DECOMPRESSION_H diff --git a/src/internal/pernix/x86/avx2/avx2_tables.h b/src/internal/pernix/x86/avx2/avx2_tables.h index e21af27..f62250f 100644 --- a/src/internal/pernix/x86/avx2/avx2_tables.h +++ b/src/internal/pernix/x86/avx2/avx2_tables.h @@ -7,13 +7,13 @@ #include namespace pernix::internal { - template<__uint8_t BIT_WIDTH, typename T> - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && (std::is_same_v || std::is_same_v)) - struct pack_tables_avx2_16 { - alignas(64) inline static constexpr std::array permute1 = [] { +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && (std::is_same_v || std::is_same_v)) +struct pack_tables_avx2_16 { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 0, 1, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, @@ -25,7 +25,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 0, 1, 4, 5, -1, -1, 10, 11, 14, 15, -1, -1, @@ -37,7 +37,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, @@ -49,7 +49,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 12) { - return std::array{ + return std::array{ 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, @@ -61,7 +61,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 0, 1, -1, -1, 6, 7, 8, 9, 10, 11, -1, -1, @@ -73,7 +73,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 14) { - return std::array{ + return std::array{ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, @@ -85,7 +85,7 @@ namespace pernix::internal { 14, 15, -1, -1 }; } else if constexpr (BIT_WIDTH == 15) { - return std::array{ + return std::array{ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, @@ -97,14 +97,14 @@ namespace pernix::internal { 14, 15, -1, -1 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, @@ -116,7 +116,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, @@ -128,7 +128,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 2, 3, -1, -1, 8, 9, -1, -1, 14, 15, -1, -1, @@ -140,7 +140,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 12) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, @@ -152,7 +152,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 2, 3, 4, 5, -1, -1, -1, -1, 12, 13, 14, 15, @@ -164,7 +164,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 14) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -176,7 +176,7 @@ namespace pernix::internal { 12, 13, -1, -1 }; } else if constexpr (BIT_WIDTH == 15) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -188,14 +188,14 @@ namespace pernix::internal { 12, 13, 14, 15 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ -1, -1, 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, @@ -207,7 +207,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ -1, -1, 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, @@ -219,7 +219,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ -1, -1, 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, @@ -231,7 +231,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ -1, -1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, @@ -243,183 +243,183 @@ namespace pernix::internal { 14, 15, -1, -1 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 0, 2, 4, 6, 0, 1, 3, 5, 0, 2, 4, 6, 0, 1, 3, 5 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 0, 4, 0, 2, 6, 0, 4, 0, 0, 4, 0, 2, 6, 0, 4, 0 }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 0, 6, 1, 7, 2, 0, 3, 0, 0, 6, 1, 7, 2, 0, 3, 0 }; } else if constexpr (BIT_WIDTH == 12) { - return std::array{ + return std::array{ 12, 8, 4, 12, 8, 4, 12, 8, 12, 8, 4, 12, 8, 4, 12, 8 }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 0, 0, 7, 4, 1, 0, 0, 5, 0, 0, 7, 4, 1, 0, 0, 5 }; } else if constexpr (BIT_WIDTH == 14) { - return std::array{ + return std::array{ 14, 12, 10, 8, 6, 4, 2, 14, 14, 12, 10, 8, 6, 4, 2, 14 }; } else if constexpr (BIT_WIDTH == 15) { - return std::array{ + return std::array{ 15, 14, 13, 12, 11, 10, 9, 8, 15, 14, 13, 12, 11, 10, 9, 8 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 9, 11, 13, 15, 8, 10, 12, 14, 9, 11, 13, 15, 8, 10, 12, 14, }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 10, 14, 8, 12, 0, 10, 14, 8, 10, 14, 8, 12, 0, 10, 14, 8, }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 11, 8, 12, 8, 13, 8, 14, 9, 11, 8, 12, 8, 13, 8, 14, 9, }; } else if constexpr (BIT_WIDTH == 12) { - return std::array{ + return std::array{ 0, 4, 8, 0, 4, 8, 0, 4, 0, 4, 8, 0, 4, 8, 0, 4, }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 13, 10, 0, 0, 14, 11, 8, 0, 13, 10, 0, 0, 14, 11, 8, 0, }; } else if constexpr (BIT_WIDTH == 14) { - return std::array{ + return std::array{ 0, 2, 4, 6, 8, 10, 12, 0, 0, 2, 4, 6, 8, 10, 12, 0, }; } else if constexpr (BIT_WIDTH == 15) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift3 = [] { + alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 0, 7, 5, 3, 1, 8, 6, 4, 0, 7, 5, 3, 1, 8, 6, 4 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 0, 6, 2, 8, 4, 0, 6, 2, 0, 6, 2, 8, 4, 0, 6, 2, }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 0, 5, 10, 4, 9, 3, 8, 2, 0, 5, 10, 4, 9, 3, 8, 2, }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 0, 3, 6, 9, 12, 2, 5, 8, 0, 3, 6, 9, 12, 2, 5, 8, }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-attributes" __always_inline static T get_permute1() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute1.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute1.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute1.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute1.data())); } + return T{}; + } __always_inline static T get_permute2() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute2.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute2.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute2.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute2.data())); } + return T{}; + } __always_inline static T get_permute3() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute3.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute3.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute3.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute3.data())); } + return T{}; + } __always_inline static T get_shift1() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift1.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift1.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift1.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift1.data())); } + return T{}; + } __always_inline static T get_shift2() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift2.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift2.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift2.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift2.data())); } + return T{}; + } __always_inline static T get_shift3() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift3.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift3.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift3.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift3.data())); } + return T{}; + } #pragma GCC diagnostic pop - }; +}; - template<__uint8_t BIT_WIDTH, typename T> - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) - struct pack_tables_avx2_24 { - alignas(64) inline static constexpr std::array permute1 = [] { +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) +struct pack_tables_avx2_24 { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -454,11 +454,11 @@ namespace pernix::internal { 1, 2, 3, 5, 6, 7, -1, -1, }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -493,11 +493,11 @@ namespace pernix::internal { 0, 1, 2, 4, 5, 6, -1, -1, }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -528,185 +528,185 @@ namespace pernix::internal { 0, 1, 2, 4, 5, 6, 0, 0 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { - return std::array{ + return std::array{ 0, 2, 4, 6, 8, 10, 12, 14 }; } else if constexpr (BIT_WIDTH == 18) { - return std::array{ + return std::array{ 0, 4, 8, 12, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 19) { - return std::array{ + return std::array{ 0, 6, 12, 18, 5, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 20) { - return std::array{ + return std::array{ 0, 8, 16, 4, 12, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 21) { - return std::array{ + return std::array{ 0, 10, 20, 9, 19, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 22) { - return std::array{ + return std::array{ 0, 12, 2, 14, 4, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 23) { - return std::array{ + return std::array{ 0, 14, 5, 19, 10, 1, 32, 32 }; } else if constexpr (BIT_WIDTH == 24) { - return std::array{ + return std::array{ 24, 16, 8, 24, 16, 8, 24, 16 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { - return std::array{ + return std::array{ 17, 19, 21, 23, 25, 27, 29, 31 }; } else if constexpr (BIT_WIDTH == 18) { - return std::array{ + return std::array{ 18, 22, 26, 30, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 19) { - return std::array{ + return std::array{ 19, 25, 31, 32, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 20) { - return std::array{ + return std::array{ 20, 28, 32, 24, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 21) { - return std::array{ + return std::array{ 21, 31, 32, 30, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 22) { - return std::array{ + return std::array{ 22, 32, 24, 32, 26, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 23) { - return std::array{ + return std::array{ 23, 32, 28, 32, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 24) { - return std::array{ + return std::array{ 0, 8, 16, 0, 8, 16, 0, 8, }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift3 = [] { + alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { - return std::array{ + return std::array{ 0, 15, 13, 11, 9, 7, 5, 3 }; } else if constexpr (BIT_WIDTH == 18) { - return std::array{ + return std::array{ 32, 14, 10, 6, 2, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 19) { - return std::array{ + return std::array{ 32, 13, 7, 1, 14, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 20) { - return std::array{ + return std::array{ 32, 12, 4, 16, 8, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 21) { - return std::array{ + return std::array{ 32, 11, 1, 12, 2, 13, 32, 32 }; } else if constexpr (BIT_WIDTH == 22) { - return std::array{ + return std::array{ 32, 10, 20, 8, 18, 6, 32, 32 }; } else if constexpr (BIT_WIDTH == 23) { - return std::array{ + return std::array{ 32, 9, 18, 4, 13, 22, 32, 32 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-attributes" __always_inline static T get_permute1() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute1.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute1.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute1.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute1.data())); } + return T{}; + } __always_inline static T get_permute2() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute2.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute2.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute2.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute2.data())); } + return T{}; + } __always_inline static T get_permute3() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(permute3.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(permute3.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(permute3.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(permute3.data())); } + return T{}; + } __always_inline static T get_shift1() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift1.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift1.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift1.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift1.data())); } + return T{}; + } __always_inline static T get_shift2() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift2.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift2.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift2.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift2.data())); } + return T{}; + } __always_inline static T get_shift3() { - if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(shift3.data())); - } else if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift3.data())); - } - return T{}; + if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(shift3.data())); + } else if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift3.data())); } + return T{}; + } #pragma GCC diagnostic pop - }; +}; - template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) - struct unpack_tables_avx2 { - alignas(32) inline static constexpr std::array permute = [] { +template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) +struct unpack_tables_avx2 { + alignas(32) inline static constexpr std::array permute = [] { // clang-format off if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return std::array{0, -1, -1, -1, 0, 1, -1, -1}; @@ -715,72 +715,72 @@ namespace pernix::internal { } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { return std::array{0, 1, 2, -1, 2, 3, 4, 5}; } - // clang-format on - }(); - - alignas(32) inline static constexpr std::array shuffle = [] { - std::array shuffles{}; - shuffles.fill(-1); - constexpr std::size_t rebase_second_half = 4 * ((BIT_WIDTH - 1) / 8); - - for (std::size_t lane = 0; lane < 2; ++lane) { - for (std::size_t i = 0; i < 4; ++i) { - const std::size_t value_index = lane * 4 + i; - - const std::size_t bit_start = value_index * BIT_WIDTH; - const std::size_t byte_start = bit_start / 8; - const std::size_t bit_offset = bit_start % 8; - const std::size_t byte_count = (bit_offset + BIT_WIDTH + 7) / 8; - - const std::size_t rebase = (lane == 0) ? 0 : rebase_second_half; - const std::size_t rel_byte_start = byte_start - rebase; - - const std::size_t dst = (lane * 4 + i) * 4; - for (std::size_t k = 0; k < byte_count; ++k) { - shuffles[dst + k] = static_cast(rel_byte_start + k); - } + // clang-format on + }(); + + alignas(32) inline static constexpr std::array shuffle = [] { + std::array shuffles{}; + shuffles.fill(-1); + constexpr std::size_t rebase_second_half = 4 * ((BIT_WIDTH - 1) / 8); + + for (std::size_t lane = 0; lane < 2; ++lane) { + for (std::size_t i = 0; i < 4; ++i) { + const std::size_t value_index = lane * 4 + i; + + const std::size_t bit_start = value_index * BIT_WIDTH; + const std::size_t byte_start = bit_start / 8; + const std::size_t bit_offset = bit_start % 8; + const std::size_t byte_count = (bit_offset + BIT_WIDTH + 7) / 8; + + const std::size_t rebase = (lane == 0) ? 0 : rebase_second_half; + const std::size_t rel_byte_start = byte_start - rebase; + + const std::size_t dst = (lane * 4 + i) * 4; + for (std::size_t k = 0; k < byte_count; ++k) { + shuffles[dst + k] = static_cast(rel_byte_start + k); } } + } - return shuffles; - }(); + return shuffles; + }(); - alignas(64) inline static constexpr std::array shift = [] { - std::array shifts{}; + alignas(64) inline static constexpr std::array shift = [] { + std::array shifts{}; - for (std::size_t lane = 0; lane < 8; ++lane) { - const int bit_offset = lane * BIT_WIDTH; - const int bit_in_byte = bit_offset % 8; - const int left_shift = 32 - BIT_WIDTH - bit_in_byte; - shifts[lane] = left_shift; - } + for (std::size_t lane = 0; lane < 8; ++lane) { + const int bit_offset = lane * BIT_WIDTH; + const int bit_in_byte = bit_offset % 8; + const int left_shift = 32 - BIT_WIDTH - bit_in_byte; + shifts[lane] = left_shift; + } - return shifts; - }(); + return shifts; + }(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-attributes" __always_inline static __m256i get_permute() { - return _mm256_load_si256(reinterpret_cast(permute.data())); - } + return _mm256_load_si256(reinterpret_cast(permute.data())); + } __always_inline static T get_shuffle() { - if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shuffle.data())); - } else { - return _mm256_load_si256(reinterpret_cast(shuffle.data())); - } + if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shuffle.data())); + } else { + return _mm256_load_si256(reinterpret_cast(shuffle.data())); } + } __always_inline static T get_shift() { - if constexpr (std::is_same_v) { - return _mm_load_si128(reinterpret_cast(shift.data())); - } else { - return _mm256_load_si256(reinterpret_cast(shift.data())); - } + if constexpr (std::is_same_v) { + return _mm_load_si128(reinterpret_cast(shift.data())); + } else { + return _mm256_load_si256(reinterpret_cast(shift.data())); } + } #pragma GCC diagnostic pop - }; +}; } // namespace pernix::internal #endif // PERNIX_AVX2_TABLES_H diff --git a/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h index e19051e..dab8472 100644 --- a/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h +++ b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h @@ -12,543 +12,604 @@ using namespace pernix::x86::internal; namespace pernix { -namespace internal { -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -static __always_inline __m512i mm512_clamp_signed_epi32(__m512i input) { - constexpr int32_t min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); - constexpr int32_t max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); - return _mm512_min_epi32(_mm512_max_epi32(input, _mm512_set1_epi32(min_value)), _mm512_set1_epi32(max_value)); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -static __always_inline __m256i mm256_clamp_signed_epi32_avx512(__m256i input) { - constexpr int32_t min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); - constexpr int32_t max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); - return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), _mm256_set1_epi32(max_value)); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -static __always_inline __m512i mm512_clamp_signed_epi64(__m512i input) { - constexpr int64_t min_value = BIT_WIDTH == 1 ? 0 : -(int64_t{1} << (BIT_WIDTH - 1)); - constexpr int64_t max_value = BIT_WIDTH == 1 ? 1 : ((int64_t{1} << (BIT_WIDTH - 1)) - 1); - return _mm512_min_epi64(_mm512_max_epi64(input, _mm512_set1_epi64(min_value)), _mm512_set1_epi64(max_value)); -} - -/** - * @brief Quantize sixteen float values to 32-bit integers. - */ - -static __always_inline __m512i mm512_quantize_ps_epi32(const __m512& input, const __m512& scale) { - const __m512 scaled = _mm512_mul_ps(input, scale); - return _mm512_cvtps_epi32(scaled); -} - -static __always_inline __m512i mm512_quantize_pd_epi64(const __m512d& input, const __m512d& scale) { - const __m512d scaled = _mm512_mul_pd(input, scale); - return _mm512_cvtpd_epi64(scaled); -} - -static __always_inline __m256i mm512_quantize_pd_epi32(const __m512d& input, const __m512d& scale) { - const __m512d scaled = _mm512_mul_pd(input, scale); - return _mm512_cvtpd_epi32(scaled); -} - -static __always_inline __m512i make_m512i_from_2x256(const __m256i a, const __m256i b) { - __m512i result = _mm512_castsi256_si512(a); - result = _mm512_inserti64x4(result, b, 1); - return result; -} - -static __always_inline __m512i make_m512i_from_4x128(const __m128i a, const __m128i b, const __m128i c, const __m128i d) { - __m512i result = _mm512_castsi128_si512(a); - result = _mm512_inserti64x2(result, b, 1); - result = _mm512_inserti64x2(result, c, 2); - result = _mm512_inserti64x2(result, d, 3); - return result; -} - -static __always_inline __m512i make_m512i_from_8x64(const __m128i a, const __m128i b, const __m128i c, const __m128i d, const __m128i e, - const __m128i f, const __m128i g, const __m128i h) { - const __m128i ab = _mm_unpacklo_epi64(a, b); - const __m128i cd = _mm_unpacklo_epi64(c, d); - const __m128i ef = _mm_unpacklo_epi64(e, f); - const __m128i gh = _mm_unpacklo_epi64(g, h); - - __m512i x = _mm512_castsi128_si512(ab); - x = _mm512_inserti32x4(x, cd, 1); - x = _mm512_inserti32x4(x, ef, 2); - x = _mm512_inserti32x4(x, gh, 3); - return x; -} - -static __always_inline __m256i make_m256i_from_2x128(const __m128i a, const __m128i b) { - __m256i result = _mm256_castsi128_si256(a); - result = _mm256_inserti128_si256(result, b, 1); - return result; -} - -static __always_inline __m256i make_m256i_from_4x64(const __m128i a, const __m128i b, const __m128i c, const __m128i d) { - const __m128i ab = _mm_unpacklo_epi64(a, b); - const __m128i cd = _mm_unpacklo_epi64(c, d); - - __m256i x = _mm256_castsi128_si256(ab); - x = _mm256_inserti128_si256(x, cd, 1); - return x; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_64 = elements_per_block / 64; - constexpr uint32_t iterations_32 = (elements_per_block % 64) / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - - const __m512 scale_v = _mm512_set1_ps(scale); - - if constexpr (iterations_64 > 0) { -#pragma GCC unroll 8 - for (uint32_t iter = 0; iter < iterations_64; ++iter) { - const __m512 source1 = _mm512_loadu_ps(input); - const __m512 source2 = _mm512_loadu_ps(input + 16); - const __m512 source3 = _mm512_loadu_ps(input + 32); - const __m512 source4 = _mm512_loadu_ps(input + 48); + namespace internal { + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + static __always_inline __m512i mm512_clamp_signed_epi32(__m512i input) { + constexpr i32 min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); + constexpr i32 max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); + return _mm512_min_epi32(_mm512_max_epi32(input, _mm512_set1_epi32(min_value)), + _mm512_set1_epi32(max_value)); + } - const __m512i quantized1 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source2, scale_v)); - const __m512i quantized3 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source3, scale_v)); - const __m512i quantized4 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source4, scale_v)); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + static __always_inline __m256i mm256_clamp_signed_epi32_avx512(__m256i input) { + constexpr i32 min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); + constexpr i32 max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); + return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), + _mm256_set1_epi32(max_value)); + } - const __m128i converted1 = _mm512_cvtepi32_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi32_epi8(quantized2); - const __m128i converted3 = _mm512_cvtepi32_epi8(quantized3); - const __m128i converted4 = _mm512_cvtepi32_epi8(quantized4); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + static __always_inline __m512i mm512_clamp_signed_epi64(__m512i input) { + constexpr i64 min_value = BIT_WIDTH == 1 ? 0 : -(i64{1} << (BIT_WIDTH - 1)); + constexpr i64 max_value = BIT_WIDTH == 1 ? 1 : ((i64{1} << (BIT_WIDTH - 1)) - 1); + return _mm512_min_epi64(_mm512_max_epi64(input, _mm512_set1_epi64(min_value)), + _mm512_set1_epi64(max_value)); + } - const __m512i packed = - m512::mm512_pack_epi8_avx512vbmi_1to8(make_m512i_from_4x128(converted1, converted2, converted3, converted4)); + /** + * @brief Quantize sixteen float values to 32-bit integers. + */ - mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); + static __always_inline __m512i mm512_quantize_ps_epi32(const __m512 &input, const __m512 &scale) { + const __m512 scaled = _mm512_mul_ps(input, scale); + return _mm512_cvtps_epi32(scaled); + } - input += 64; - output += 8 * BIT_WIDTH; + static __always_inline __m512i mm512_quantize_pd_epi64(const __m512d &input, const __m512d &scale) { + const __m512d scaled = _mm512_mul_pd(input, scale); + return _mm512_cvtpd_epi64(scaled); } - } - if constexpr (iterations_32 > 0) { - const __m512 source1 = _mm512_loadu_ps(input); - const __m512 source2 = _mm512_loadu_ps(input + 16); + static __always_inline __m256i mm512_quantize_pd_epi32(const __m512d &input, const __m512d &scale) { + const __m512d scaled = _mm512_mul_pd(input, scale); + return _mm512_cvtpd_epi32(scaled); + } - const __m512i quantized1 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source2, scale_v)); + static __always_inline __m512i make_m512i_from_2x256(const __m256i a, const __m256i b) { + __m512i result = _mm512_castsi256_si512(a); + result = _mm512_inserti64x4(result, b, 1); + return result; + } - const __m128i converted1 = _mm512_cvtepi32_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi32_epi8(quantized2); + static __always_inline __m512i make_m512i_from_4x128(const __m128i a, const __m128i b, const __m128i c, + const __m128i d) { + __m512i result = _mm512_castsi128_si512(a); + result = _mm512_inserti64x2(result, b, 1); + result = _mm512_inserti64x2(result, c, 2); + result = _mm512_inserti64x2(result, d, 3); + return result; + } - const __m256i packed = m256::mm256_pack_epi8_avx512vbmi_1to8(make_m256i_from_2x128(converted1, converted2)); - mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); + static __always_inline __m512i make_m512i_from_8x64(const __m128i a, const __m128i b, const __m128i c, + const __m128i d, const __m128i e, + const __m128i f, const __m128i g, const __m128i h) { + const __m128i ab = _mm_unpacklo_epi64(a, b); + const __m128i cd = _mm_unpacklo_epi64(c, d); + const __m128i ef = _mm_unpacklo_epi64(e, f); + const __m128i gh = _mm_unpacklo_epi64(g, h); + + __m512i x = _mm512_castsi128_si512(ab); + x = _mm512_inserti32x4(x, cd, 1); + x = _mm512_inserti32x4(x, ef, 2); + x = _mm512_inserti32x4(x, gh, 3); + return x; + } - input += 32; - output += 4 * BIT_WIDTH; - } + static __always_inline __m256i make_m256i_from_2x128(const __m128i a, const __m128i b) { + __m256i result = _mm256_castsi128_si256(a); + result = _mm256_inserti128_si256(result, b, 1); + return result; + } - if constexpr (iterations_16 > 0) { - const __m512 source = _mm512_loadu_ps(input); - const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); - const __m128i converted = _mm512_cvtepi32_epi8(quantized); + static __always_inline __m256i make_m256i_from_4x64(const __m128i a, const __m128i b, const __m128i c, + const __m128i d) { + const __m128i ab = _mm_unpacklo_epi64(a, b); + const __m128i cd = _mm_unpacklo_epi64(c, d); - const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(converted); - mm_storeu_elements_epi16(output, BIT_WIDTH, packed); + __m256i x = _mm256_castsi128_si256(ab); + x = _mm256_inserti128_si256(x, cd, 1); + return x; + } - input += 16; - output += 2 * BIT_WIDTH; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_1to8(const f32 * __restrict__ input, const f32 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - if constexpr (remaining_elements > 0) { - const __m512 source = mm512_loadu_elements_ps(remaining_elements, input); - const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); - const __m128i converted = _mm512_cvtepi32_epi8(quantized); - const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(converted); + constexpr u32 iterations_64 = elements_per_block / 64; + constexpr u32 iterations_32 = (elements_per_block % 64) / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 remaining_elements = + elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } + const __m512 scale_v = _mm512_set1_ps(scale); - return 0; -} + if constexpr (iterations_64 > 0) { +#pragma GCC unroll 8 + for (u32 iter = 0; iter < iterations_64; ++iter) { + const __m512 source1 = _mm512_loadu_ps(input); + const __m512 source2 = _mm512_loadu_ps(input + 16); + const __m512 source3 = _mm512_loadu_ps(input + 32); + const __m512 source4 = _mm512_loadu_ps(input + 48); + + const __m512i quantized1 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source2, scale_v)); + const __m512i quantized3 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source3, scale_v)); + const __m512i quantized4 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source4, scale_v)); + + const __m128i converted1 = _mm512_cvtepi32_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi32_epi8(quantized2); + const __m128i converted3 = _mm512_cvtepi32_epi8(quantized3); + const __m128i converted4 = _mm512_cvtepi32_epi8(quantized4); + + const __m512i packed = + m512::mm512_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (make_m512i_from_4x128( + converted1, converted2, converted3, converted4)); + + mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); + + input += 64; + output += 8 * BIT_WIDTH; + } + } + + if constexpr (iterations_32 > 0) { + const __m512 source1 = _mm512_loadu_ps(input); + const __m512 source2 = _mm512_loadu_ps(input + 16); + + const __m512i quantized1 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source2, scale_v)); + + const __m128i converted1 = _mm512_cvtepi32_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi32_epi8(quantized2); + + const __m256i packed = m256::mm256_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (make_m256i_from_2x128( + converted1, converted2)); + mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); + + input += 32; + output += 4 * BIT_WIDTH; + } + + if constexpr (iterations_16 > 0) { + const __m512 source = _mm512_loadu_ps(input); + const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); + const __m128i converted = _mm512_cvtepi32_epi8(quantized); + + const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (converted); + mm_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m512 source = mm512_loadu_elements_ps(remaining_elements, input); + const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); + const __m128i converted = _mm512_cvtepi32_epi8(quantized); + const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (converted); + + mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; + } -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_9to16(const f32 * __restrict__ input, const f32 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_32 = elements_per_block / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; + constexpr u32 iterations_32 = elements_per_block / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = + elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - const __m512 scale_v = _mm512_set1_ps(scale); - const __m256 scale_v256 = _mm256_set1_ps(scale); + const __m512 scale_v = _mm512_set1_ps(scale); + const __m256 scale_v256 = _mm256_set1_ps(scale); - if constexpr (iterations_32 > 0) { + if constexpr (iterations_32 > 0) { #pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_32; ++iter) { - const __m512 source1 = _mm512_loadu_ps(input); - const __m512 source2 = _mm512_loadu_ps(input + 16); - - const __m512i quantized1 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source2, scale_v)); - - const __m256i converted1 = _mm512_cvtepi32_epi16(quantized1); - const __m256i converted2 = _mm512_cvtepi32_epi16(quantized2); - - const __m512i packed = m512::mm512_pack_epi16_avx512vbmi_9to16(make_m512i_from_2x256(converted1, converted2)); - mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - - input += 32; - output += 4 * BIT_WIDTH; + for (u32 iter = 0; iter < iterations_32; ++iter) { + const __m512 source1 = _mm512_loadu_ps(input); + const __m512 source2 = _mm512_loadu_ps(input + 16); + + const __m512i quantized1 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source2, scale_v)); + + const __m256i converted1 = _mm512_cvtepi32_epi16(quantized1); + const __m256i converted2 = _mm512_cvtepi32_epi16(quantized2); + + const __m512i packed = m512::mm512_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (make_m512i_from_2x256( + converted1, converted2)); + mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); + + input += 32; + output += 4 * BIT_WIDTH; + } + } + + if constexpr (iterations_16 > 0) { + const __m512 source = _mm512_loadu_ps(input); + const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); + const __m256i converted = _mm512_cvtepi32_epi16(quantized); + + const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + + if constexpr (iterations_8 > 0) { + const __m256 source = _mm256_loadu_ps(input); + const __m256i quantized = mm256_clamp_signed_epi32_avx512( + mm256_quantize_ps_epi32(source, scale_v256)); + const __m128i converted = _mm256_cvtepi32_epi16(quantized); + + const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + mm_storeu_elements_epi8(output, BIT_WIDTH, packed); + + input += 8; + output += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); + const __m256i quantized = mm256_clamp_signed_epi32_avx512( + mm256_quantize_ps_epi32(source, scale_v256)); + const __m128i converted = _mm256_cvtepi32_epi16(quantized); + const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + + mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; } - } - if constexpr (iterations_16 > 0) { - const __m512 source = _mm512_loadu_ps(input); - const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); - const __m256i converted = _mm512_cvtepi32_epi16(quantized); + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_17to24(const f32 * __restrict__ input, const f32 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16(converted); - mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); + constexpr u32 iterations_16 = elements_per_block / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - input += 16; - output += 2 * BIT_WIDTH; - } + const __m512 scale_v = _mm512_set1_ps(scale); + const __m256 scale_v256 = _mm256_set1_ps(scale); - if constexpr (iterations_8 > 0) { - const __m256 source = _mm256_loadu_ps(input); - const __m256i quantized = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m128i converted = _mm256_cvtepi32_epi16(quantized); - - const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); - mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); - const __m256i quantized = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m128i converted = _mm256_cvtepi32_epi16(quantized); - const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); - - mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_16 = elements_per_block / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - - const __m512 scale_v = _mm512_set1_ps(scale); - const __m256 scale_v256 = _mm256_set1_ps(scale); - - if constexpr (iterations_16 > 0) { + if constexpr (iterations_16 > 0) { #pragma GCC unroll 2 - for (uint32_t i = 0; i < iterations_16; ++i) { - const __m512 source = _mm512_loadu_ps(input); - const __m512i packed_input = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); - - const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(packed_input); - mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; - output += 2 * BIT_WIDTH; + for (u32 i = 0; i < iterations_16; ++i) { + const __m512 source = _mm512_loadu_ps(input); + const __m512i packed_input = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source, scale_v)); + + const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (packed_input); + mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); + input += 16; + output += 2 * BIT_WIDTH; + } + } + + if constexpr (iterations_8 > 0) { + const __m256 source = _mm256_loadu_ps(input); + const __m256i packed_input = mm256_clamp_signed_epi32_avx512( + mm256_quantize_ps_epi32(source, scale_v256)); + + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (packed_input); + mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); + + input += 8; + output += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); + const __m256i packed_input = mm256_clamp_signed_epi32_avx512( + mm256_quantize_ps_epi32(source, scale_v256)); + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (packed_input); + + mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; } - } - - if constexpr (iterations_8 > 0) { - const __m256 source = _mm256_loadu_ps(input); - const __m256i packed_input = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); - mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_1to8(const f64 * __restrict__ input, const f64 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - input += 8; - output += BIT_WIDTH; - } + constexpr u32 iterations_64 = elements_per_block / 64; + constexpr u32 iterations_32 = (elements_per_block % 64) / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 remaining_elements = + elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - if constexpr (remaining_elements > 0) { - const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); - const __m256i packed_input = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); + const __m512d scale_v = _mm512_set1_pd(scale); - mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_64 = elements_per_block / 64; - constexpr uint32_t iterations_32 = (elements_per_block % 64) / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - - const __m512d scale_v = _mm512_set1_pd(scale); - - if constexpr (iterations_64 > 0) { + if constexpr (iterations_64 > 0) { #pragma GCC unroll 8 - for (uint32_t iter = 0; iter < iterations_64; ++iter) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512d source3 = _mm512_loadu_pd(input + 16); - const __m512d source4 = _mm512_loadu_pd(input + 24); - const __m512d source5 = _mm512_loadu_pd(input + 32); - const __m512d source6 = _mm512_loadu_pd(input + 40); - const __m512d source7 = _mm512_loadu_pd(input + 48); - const __m512d source8 = _mm512_loadu_pd(input + 56); - - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - const __m512i quantized3 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source3, scale_v)); - const __m512i quantized4 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source4, scale_v)); - const __m512i quantized5 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source5, scale_v)); - const __m512i quantized6 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source6, scale_v)); - const __m512i quantized7 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source7, scale_v)); - const __m512i quantized8 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source8, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); - const __m128i converted3 = _mm512_cvtepi64_epi8(quantized3); - const __m128i converted4 = _mm512_cvtepi64_epi8(quantized4); - const __m128i converted5 = _mm512_cvtepi64_epi8(quantized5); - const __m128i converted6 = _mm512_cvtepi64_epi8(quantized6); - const __m128i converted7 = _mm512_cvtepi64_epi8(quantized7); - const __m128i converted8 = _mm512_cvtepi64_epi8(quantized8); - - const __m512i packed = m512::mm512_pack_epi8_avx512vbmi_1to8( - make_m512i_from_8x64(converted1, converted2, converted3, converted4, converted5, converted6, converted7, converted8)); - - mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - - input += 64; - output += 8 * BIT_WIDTH; + for (u32 iter = 0; iter < iterations_64; ++iter) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512d source3 = _mm512_loadu_pd(input + 16); + const __m512d source4 = _mm512_loadu_pd(input + 24); + const __m512d source5 = _mm512_loadu_pd(input + 32); + const __m512d source6 = _mm512_loadu_pd(input + 40); + const __m512d source7 = _mm512_loadu_pd(input + 48); + const __m512d source8 = _mm512_loadu_pd(input + 56); + + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + const __m512i quantized3 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source3, scale_v)); + const __m512i quantized4 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source4, scale_v)); + const __m512i quantized5 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source5, scale_v)); + const __m512i quantized6 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source6, scale_v)); + const __m512i quantized7 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source7, scale_v)); + const __m512i quantized8 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source8, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + const __m128i converted3 = _mm512_cvtepi64_epi8(quantized3); + const __m128i converted4 = _mm512_cvtepi64_epi8(quantized4); + const __m128i converted5 = _mm512_cvtepi64_epi8(quantized5); + const __m128i converted6 = _mm512_cvtepi64_epi8(quantized6); + const __m128i converted7 = _mm512_cvtepi64_epi8(quantized7); + const __m128i converted8 = _mm512_cvtepi64_epi8(quantized8); + + const __m512i packed = m512::mm512_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > ( + make_m512i_from_8x64(converted1, converted2, converted3, converted4, + converted5, converted6, converted7, converted8)); + + mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); + + input += 64; + output += 8 * BIT_WIDTH; + } + } + + if constexpr (iterations_32 > 0) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512d source3 = _mm512_loadu_pd(input + 16); + const __m512d source4 = _mm512_loadu_pd(input + 24); + + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + const __m512i quantized3 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source3, scale_v)); + const __m512i quantized4 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source4, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + const __m128i converted3 = _mm512_cvtepi64_epi8(quantized3); + const __m128i converted4 = _mm512_cvtepi64_epi8(quantized4); + + const __m256i packed = + m256::mm256_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (make_m256i_from_4x64( + converted1, converted2, converted3, converted4)); + + mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); + + input += 32; + output += 4 * BIT_WIDTH; + } + + if constexpr (iterations_16 > 0) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + + const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (_mm_unpacklo_epi64( + converted1, converted2)); + + mm_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + constexpr u32 source1_elements = remaining_elements > 8 ? 8 : remaining_elements; + constexpr u32 source2_elements = remaining_elements > 8 ? remaining_elements - 8 : 0; + + const __m512d source1 = mm512_loadu_elements_pd(source1_elements, input); + const __m512d source2 = source2_elements > 0 + ? mm512_loadu_elements_pd(source2_elements, input + 8) + : _mm512_setzero_pd(); + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + + const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (_mm_unpacklo_epi64( + converted1, converted2)); + + mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; } - } - if constexpr (iterations_32 > 0) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512d source3 = _mm512_loadu_pd(input + 16); - const __m512d source4 = _mm512_loadu_pd(input + 24); - - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - const __m512i quantized3 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source3, scale_v)); - const __m512i quantized4 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source4, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); - const __m128i converted3 = _mm512_cvtepi64_epi8(quantized3); - const __m128i converted4 = _mm512_cvtepi64_epi8(quantized4); - - const __m256i packed = - m256::mm256_pack_epi8_avx512vbmi_1to8(make_m256i_from_4x64(converted1, converted2, converted3, converted4)); - - mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - - input += 32; - output += 4 * BIT_WIDTH; - } + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_9to16(const f64 * __restrict__ input, const f64 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - if constexpr (iterations_16 > 0) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); + constexpr u32 iterations_32 = elements_per_block / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = + elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + const __m512d scale_v = _mm512_set1_pd(scale); - const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(_mm_unpacklo_epi64(converted1, converted2)); - - mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - - input += 16; - output += 2 * BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - constexpr uint32_t source1_elements = remaining_elements > 8 ? 8 : remaining_elements; - constexpr uint32_t source2_elements = remaining_elements > 8 ? remaining_elements - 8 : 0; - - const __m512d source1 = mm512_loadu_elements_pd(source1_elements, input); - const __m512d source2 = source2_elements > 0 ? mm512_loadu_elements_pd(source2_elements, input + 8) : _mm512_setzero_pd(); - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); - - const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(_mm_unpacklo_epi64(converted1, converted2)); - - mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_32 = elements_per_block / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - - const __m512d scale_v = _mm512_set1_pd(scale); - - if constexpr (iterations_32 > 0) { + if constexpr (iterations_32 > 0) { #pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_32; ++iter) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512d source3 = _mm512_loadu_pd(input + 16); - const __m512d source4 = _mm512_loadu_pd(input + 24); - - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - const __m512i quantized3 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source3, scale_v)); - const __m512i quantized4 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source4, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi16(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi16(quantized2); - const __m128i converted3 = _mm512_cvtepi64_epi16(quantized3); - const __m128i converted4 = _mm512_cvtepi64_epi16(quantized4); - - const __m512i packed = - m512::mm512_pack_epi16_avx512vbmi_9to16(make_m512i_from_4x128(converted1, converted2, converted3, converted4)); - - mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - - input += 32; - output += 4 * BIT_WIDTH; + for (u32 iter = 0; iter < iterations_32; ++iter) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512d source3 = _mm512_loadu_pd(input + 16); + const __m512d source4 = _mm512_loadu_pd(input + 24); + + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + const __m512i quantized3 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source3, scale_v)); + const __m512i quantized4 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source4, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi16(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi16(quantized2); + const __m128i converted3 = _mm512_cvtepi64_epi16(quantized3); + const __m128i converted4 = _mm512_cvtepi64_epi16(quantized4); + + const __m512i packed = + m512::mm512_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (make_m512i_from_4x128( + converted1, converted2, converted3, converted4)); + + mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); + + input += 32; + output += 4 * BIT_WIDTH; + } + } + + if constexpr (iterations_16 > 0) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi16(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi16(quantized2); + + const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (make_m256i_from_2x128( + converted1, converted2)); + + mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + + if constexpr (iterations_8 > 0) { + const __m512d source = _mm512_loadu_pd(input); + const __m512i quantized = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source, scale_v)); + const __m128i converted = _mm512_cvtepi64_epi16(quantized); + + const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + mm_storeu_elements_epi8(output, BIT_WIDTH, packed); + + input += 8; + output += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m512d source = mm512_loadu_elements_pd(remaining_elements, input); + const __m512i quantized = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source, scale_v)); + const __m128i converted = _mm512_cvtepi64_epi16(quantized); + + const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; } - } - - if constexpr (iterations_16 > 0) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi16(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi16(quantized2); - - const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16(make_m256i_from_2x128(converted1, converted2)); - - mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - - input += 16; - output += 2 * BIT_WIDTH; - } - - if constexpr (iterations_8 > 0) { - const __m512d source = _mm512_loadu_pd(input); - const __m512i quantized = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source, scale_v)); - const __m128i converted = _mm512_cvtepi64_epi16(quantized); - const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); - mm_storeu_elements_epi8(output, BIT_WIDTH, packed); + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_17to24(const f64 * __restrict__ input, const f64 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m512d source = mm512_loadu_elements_pd(remaining_elements, input); - const __m512i quantized = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source, scale_v)); - const __m128i converted = _mm512_cvtepi64_epi16(quantized); - - const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); - mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_16 = elements_per_block / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; + constexpr u32 iterations_16 = elements_per_block / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - const __m512d scale_v = _mm512_set1_pd(scale); + const __m512d scale_v = _mm512_set1_pd(scale); - if constexpr (iterations_16 > 0) { + if constexpr (iterations_16 > 0) { #pragma GCC unroll 2 - for (uint32_t i = 0; i < iterations_16; ++i) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - - const __m256i quantized1 = mm256_clamp_signed_epi32_avx512(mm512_quantize_pd_epi32(source1, scale_v)); - const __m256i quantized2 = mm256_clamp_signed_epi32_avx512(mm512_quantize_pd_epi32(source2, scale_v)); - - const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(make_m512i_from_2x256(quantized1, quantized2)); - mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - - input += 16; - output += 2 * BIT_WIDTH; + for (u32 i = 0; i < iterations_16; ++i) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + + const __m256i quantized1 = mm256_clamp_signed_epi32_avx512( + mm512_quantize_pd_epi32(source1, scale_v)); + const __m256i quantized2 = mm256_clamp_signed_epi32_avx512( + mm512_quantize_pd_epi32(source2, scale_v)); + + const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > ( + make_m512i_from_2x256(quantized1, quantized2)); + mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + } + + if constexpr (iterations_8 > 0) { + const __m512d source = _mm512_loadu_pd(input); + const __m256i quantized = mm256_clamp_signed_epi32_avx512( + mm512_quantize_pd_epi32(source, scale_v)); + + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (quantized); + mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); + + input += 8; + output += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m512d source = mm512_loadu_elements_pd(remaining_elements, input); + const __m256i quantized = mm256_clamp_signed_epi32_avx512( + mm512_quantize_pd_epi32(source, scale_v)); + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (quantized); + + mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; } - } + } // namespace internal - if constexpr (iterations_8 > 0) { - const __m512d source = _mm512_loadu_pd(input); - const __m256i quantized = mm256_clamp_signed_epi32_avx512(mm512_quantize_pd_epi32(source, scale_v)); - - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(quantized); - mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m512d source = mm512_loadu_elements_pd(remaining_elements, input); - const __m256i quantized = mm256_clamp_signed_epi32_avx512(mm512_quantize_pd_epi32(source, scale_v)); - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(quantized); - - mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} -} // namespace internal - -/** + /** * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -560,24 +621,25 @@ __always_inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __res * * @note This function requires AVX-512 and AVX-512-VBMI support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_block_avx512vbmi(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - std::memset(output, 0, BLOCK_SIZE); - - if constexpr (BIT_WIDTH <= 8) { - return internal::mm512_compress_block_avx512vbmi_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH <= 16) { - return internal::mm512_compress_block_avx512vbmi_9to16(input, scale, output); - } else { - return internal::mm512_compress_block_avx512vbmi_17to24(input, scale, output); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_compress_block_avx512vbmi(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + std::memset(output, 0, BLOCK_SIZE); + + if constexpr (BIT_WIDTH <= 8) { + return internal::mm512_compress_block_avx512vbmi_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH <= 16) { + return internal::mm512_compress_block_avx512vbmi_9to16(input, scale, output); + } else { + return internal::mm512_compress_block_avx512vbmi_17to24(input, scale, output); + } } -} -/** + /** * @brief Compress a single block of double values using AVX-512 and AVX-512-VBMI instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -588,24 +650,25 @@ int mm512_compress_block_avx512vbmi(const void* __restrict__ input_ptr, const fl * * @note This overload is declared for parity with the float path. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_block_avx512vbmi(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - std::memset(output, 0, BLOCK_SIZE); - - if constexpr (BIT_WIDTH <= 8) { - return internal::mm512_compress_block_avx512vbmi_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH <= 16) { - return internal::mm512_compress_block_avx512vbmi_9to16(input, scale, output); - } else { - return internal::mm512_compress_block_avx512vbmi_17to24(input, scale, output); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_compress_block_avx512vbmi(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + std::memset(output, 0, BLOCK_SIZE); + + if constexpr (BIT_WIDTH <= 8) { + return internal::mm512_compress_block_avx512vbmi_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH <= 16) { + return internal::mm512_compress_block_avx512vbmi_9to16(input, scale, output); + } else { + return internal::mm512_compress_block_avx512vbmi_17to24(input, scale, output); + } } -} -/** + /** * @brief Compress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -618,26 +681,27 @@ int mm512_compress_block_avx512vbmi(const void* __restrict__ input_ptr, const do * * @note This function requires AVX-512 and AVX-512-VBMI support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const float_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_compress_blocks_avx512vbmi(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f32 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + mm512_compress_block_avx512vbmi(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } - return 0; -} + return 0; + } -/** + /** * @brief Compress multiple blocks of double values using AVX-512 and AVX-512-VBMI instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -647,24 +711,25 @@ int mm512_compress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const f * @param blocks number of blocks to compress. * @return int status code. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const double_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_compress_blocks_avx512vbmi(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f64 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + mm512_compress_block_avx512vbmi(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } - return 0; -} + return 0; + } } // namespace pernix #endif // PERNIX_AVX512VBMI_COMPRESSION_H diff --git a/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h index d0f98f0..a004be5 100644 --- a/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h +++ b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h @@ -12,493 +12,541 @@ using namespace pernix::x86::internal; namespace pernix { -namespace internal { -/** + namespace internal { + /** * @brief Dequantize sixteen integer values to floats. */ -__always_inline __m512 mm512_dequantize_epi32(const __m512i& input, const __m512& scale) { - const __m512 converted = _mm512_cvtepi32_ps(input); - return _mm512_mul_ps(converted, scale); -} - -__always_inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m512d& scale) { - const __m512d converted = _mm512_cvtepi64_pd(input); - return _mm512_mul_pd(converted, scale); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - const uint32_t iterations_64 = elements_per_block / 64; - const uint32_t iterations_32 = (elements_per_block % 64) / 32; - const uint32_t iterations_16 = (elements_per_block % 32) / 16; - const uint32_t remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - - const __m512 scale_v = _mm512_set1_ps(scale); - - if constexpr (iterations_64 > 0) { -#pragma GCC unroll 8 - for (uint32_t i = 0; i < iterations_64; ++i) { - const __m512i source = mm512_loadu_elements_epi64(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi8_avx512vbmi_1to8(source); - - const __m512i converted1 = _mm512_cvtepi8_epi32(_mm512_castsi512_si128(unpacked)); - const __m512i converted2 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 1)); - const __m512i converted3 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 2)); - const __m512i converted4 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 3)); - - const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); - const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); - const __m512 dequantized3 = mm512_dequantize_epi32(converted3, scale_v); - const __m512 dequantized4 = mm512_dequantize_epi32(converted4, scale_v); - - _mm512_storeu_ps(output, dequantized1); - _mm512_storeu_ps(output + 16, dequantized2); - _mm512_storeu_ps(output + 32, dequantized3); - _mm512_storeu_ps(output + 48, dequantized4); - - output += 64; - input += 8 * BIT_WIDTH; +__always_inline __m512 mm512_dequantize_epi32(const __m512i &input, const __m512 &scale) { + const __m512 converted = _mm512_cvtepi32_ps(input); + return _mm512_mul_ps(converted, scale); } - } - - if constexpr (iterations_32 > 0) { - const __m256i source = mm256_loadu_elements_epi32(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi8_avx512vbmi_1to8(source); - - const __m512i converted1 = _mm512_cvtepi8_epi32(_mm256_castsi256_si128(unpacked)); - const __m512i converted2 = _mm512_cvtepi8_epi32(_mm256_extracti128_si256(unpacked, 1)); - - const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); - const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); - - _mm512_storeu_ps(output, dequantized1); - _mm512_storeu_ps(output + 16, dequantized2); - - output += 32; - input += 4 * BIT_WIDTH; - } - - if constexpr (iterations_16 > 0) { - const __m128i source = mm_loadu_elements_epi16(BIT_WIDTH, input); - const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8(source); - const __m512i converted = _mm512_cvtepi8_epi32(unpacked); - - const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); - - _mm512_storeu_ps(output, dequantized); - - output += 16; - input += 2 * BIT_WIDTH; - } +__always_inline __m512d mm512_dequantize_epi64(const __m512i &input, const __m512d &scale) { + const __m512d converted = _mm512_cvtepi64_pd(input); + return _mm512_mul_pd(converted, scale); + } - if constexpr (remaining_elements > 0) { - const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8(source); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - const __m512i converted = _mm512_cvtepi8_epi32(unpacked); + const u32 iterations_64 = elements_per_block / 64; + const u32 iterations_32 = (elements_per_block % 64) / 32; + const u32 iterations_16 = (elements_per_block % 32) / 16; + const u32 remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - + iterations_16 * 16; - const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); + const __m512 scale_v = _mm512_set1_ps(scale); - mm512_storeu_elements_ps(output, remaining_elements, dequantized); - } + if constexpr (iterations_64 > 0) { +#pragma GCC unroll 8 + for (u32 i = 0; i < iterations_64; ++i) { + const __m512i source = mm512_loadu_elements_epi64(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi8_epi32(_mm512_castsi512_si128(unpacked)); + const __m512i converted2 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 1)); + const __m512i converted3 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 2)); + const __m512i converted4 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 3)); + + const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); + const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); + const __m512 dequantized3 = mm512_dequantize_epi32(converted3, scale_v); + const __m512 dequantized4 = mm512_dequantize_epi32(converted4, scale_v); + + _mm512_storeu_ps(output, dequantized1); + _mm512_storeu_ps(output + 16, dequantized2); + _mm512_storeu_ps(output + 32, dequantized3); + _mm512_storeu_ps(output + 48, dequantized4); + + output += 64; + input += 8 * BIT_WIDTH; + } + } - return 0; -} + if constexpr (iterations_32 > 0) { + const __m256i source = mm256_loadu_elements_epi32(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + const __m512i converted1 = _mm512_cvtepi8_epi32(_mm256_castsi256_si128(unpacked)); + const __m512i converted2 = _mm512_cvtepi8_epi32(_mm256_extracti128_si256(unpacked, 1)); - const uint32_t iterations_64 = elements_per_block / 64; - const uint32_t iterations_32 = (elements_per_block % 64) / 32; - const uint32_t iterations_16 = (elements_per_block % 32) / 16; - const uint32_t remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; + const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); + const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); - const __m512d scale_v = _mm512_set1_pd(scale); + _mm512_storeu_ps(output, dequantized1); + _mm512_storeu_ps(output + 16, dequantized2); - if constexpr (iterations_64 > 0) { -#pragma GCC unroll 8 - for (uint32_t i = 0; i < iterations_64; ++i) { - const __m512i source = mm512_loadu_elements_epi64(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi8_avx512vbmi_1to8(source); - - const __m128i extracted1 = _mm512_castsi512_si128(unpacked); - const __m128i extracted2 = _mm512_extracti64x2_epi64(unpacked, 1); - const __m128i extracted3 = _mm512_extracti64x2_epi64(unpacked, 2); - const __m128i extracted4 = _mm512_extracti64x2_epi64(unpacked, 3); - - const __m512i converted1 = _mm512_cvtepi8_epi64(extracted1); - const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted1, 8)); - const __m512i converted3 = _mm512_cvtepi8_epi64(extracted2); - const __m512i converted4 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted2, 8)); - const __m512i converted5 = _mm512_cvtepi8_epi64(extracted3); - const __m512i converted6 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted3, 8)); - const __m512i converted7 = _mm512_cvtepi8_epi64(extracted4); - const __m512i converted8 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted4, 8)); - - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); - const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); - const __m512d dequantized5 = mm512_dequantize_epi64(converted5, scale_v); - const __m512d dequantized6 = mm512_dequantize_epi64(converted6, scale_v); - const __m512d dequantized7 = mm512_dequantize_epi64(converted7, scale_v); - const __m512d dequantized8 = mm512_dequantize_epi64(converted8, scale_v); - - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - _mm512_storeu_pd(output + 16, dequantized3); - _mm512_storeu_pd(output + 24, dequantized4); - _mm512_storeu_pd(output + 32, dequantized5); - _mm512_storeu_pd(output + 40, dequantized6); - _mm512_storeu_pd(output + 48, dequantized7); - _mm512_storeu_pd(output + 56, dequantized8); - - output += 64; - input += 8 * BIT_WIDTH; - } - - if constexpr (iterations_32 > 0) { - const __m256i source = mm256_loadu_elements_epi32(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi8_avx512vbmi_1to8(source); + output += 32; + input += 4 * BIT_WIDTH; + } - const __m128i extracted1 = _mm256_castsi256_si128(unpacked); - const __m128i extracted2 = _mm256_extracti64x2_epi64(unpacked, 1); + if constexpr (iterations_16 > 0) { + const __m128i source = mm_loadu_elements_epi16(BIT_WIDTH, input); + const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted1 = _mm512_cvtepi8_epi64(extracted1); - const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted1, 8)); - const __m512i converted3 = _mm512_cvtepi8_epi64(extracted2); - const __m512i converted4 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted2, 8)); + const __m512i converted = _mm512_cvtepi8_epi32(unpacked); - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); - const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); + const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - _mm512_storeu_pd(output + 16, dequantized3); - _mm512_storeu_pd(output + 24, dequantized4); + _mm512_storeu_ps(output, dequantized); - output += 32; - input += 4 * BIT_WIDTH; - } + output += 16; + input += 2 * BIT_WIDTH; + } - if constexpr (iterations_16 > 0) { - const __m128i source = mm_loadu_elements_epi16(BIT_WIDTH, input); - const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8(source); + if constexpr (remaining_elements > 0) { + const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted1 = _mm512_cvtepi8_epi64(unpacked); - const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(unpacked, 8)); + const __m512i converted = _mm512_cvtepi8_epi32(unpacked); - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); + mm512_storeu_elements_ps(output, remaining_elements, dequantized); + } - output += 16; - input += 2 * BIT_WIDTH; + return 0; } - if constexpr (remaining_elements > 0) { - const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8(source); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - const __m512i converted1 = _mm512_cvtepi8_epi64(unpacked); - const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(unpacked, 8)); + const u32 iterations_64 = elements_per_block / 64; + const u32 iterations_32 = (elements_per_block % 64) / 32; + const u32 iterations_16 = (elements_per_block % 32) / 16; + const u32 remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - + iterations_16 * 16; - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d scale_v = _mm512_set1_pd(scale); - mm512_storeu_elements_pd(output, remaining_elements < 8 ? remaining_elements : 8, dequantized1); - if constexpr (remaining_elements > 8) { - mm512_storeu_elements_pd(output + 8, remaining_elements - 8, dequantized2); + if constexpr (iterations_64 > 0) { +#pragma GCC unroll 8 + for (u32 i = 0; i < iterations_64; ++i) { + const __m512i source = mm512_loadu_elements_epi64(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m128i extracted1 = _mm512_castsi512_si128(unpacked); + const __m128i extracted2 = _mm512_extracti64x2_epi64(unpacked, 1); + const __m128i extracted3 = _mm512_extracti64x2_epi64(unpacked, 2); + const __m128i extracted4 = _mm512_extracti64x2_epi64(unpacked, 3); + + const __m512i converted1 = _mm512_cvtepi8_epi64(extracted1); + const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted1, 8)); + const __m512i converted3 = _mm512_cvtepi8_epi64(extracted2); + const __m512i converted4 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted2, 8)); + const __m512i converted5 = _mm512_cvtepi8_epi64(extracted3); + const __m512i converted6 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted3, 8)); + const __m512i converted7 = _mm512_cvtepi8_epi64(extracted4); + const __m512i converted8 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted4, 8)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); + const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); + const __m512d dequantized5 = mm512_dequantize_epi64(converted5, scale_v); + const __m512d dequantized6 = mm512_dequantize_epi64(converted6, scale_v); + const __m512d dequantized7 = mm512_dequantize_epi64(converted7, scale_v); + const __m512d dequantized8 = mm512_dequantize_epi64(converted8, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + _mm512_storeu_pd(output + 16, dequantized3); + _mm512_storeu_pd(output + 24, dequantized4); + _mm512_storeu_pd(output + 32, dequantized5); + _mm512_storeu_pd(output + 40, dequantized6); + _mm512_storeu_pd(output + 48, dequantized7); + _mm512_storeu_pd(output + 56, dequantized8); + + output += 64; + input += 8 * BIT_WIDTH; + } + + if constexpr (iterations_32 > 0) { + const __m256i source = mm256_loadu_elements_epi32(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m128i extracted1 = _mm256_castsi256_si128(unpacked); + const __m128i extracted2 = _mm256_extracti64x2_epi64(unpacked, 1); + + const __m512i converted1 = _mm512_cvtepi8_epi64(extracted1); + const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted1, 8)); + const __m512i converted3 = _mm512_cvtepi8_epi64(extracted2); + const __m512i converted4 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted2, 8)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); + const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + _mm512_storeu_pd(output + 16, dequantized3); + _mm512_storeu_pd(output + 24, dequantized4); + + output += 32; + input += 4 * BIT_WIDTH; + } + + if constexpr (iterations_16 > 0) { + const __m128i source = mm_loadu_elements_epi16(BIT_WIDTH, input); + const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi8_epi64(unpacked); + const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(unpacked, 8)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + + output += 16; + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi8_epi64(unpacked); + const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(unpacked, 8)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + + mm512_storeu_elements_pd(output, remaining_elements < 8 ? remaining_elements : 8, dequantized1); + if constexpr (remaining_elements > 8) { + mm512_storeu_elements_pd(output + 8, remaining_elements - 8, dequantized2); + } + } } - } - } - return 0; -} + return 0; + } -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_32 = elements_per_block / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; + constexpr u32 iterations_32 = elements_per_block / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = + elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - const __m512 scale_v = _mm512_set1_ps(scale); - const __m256 scale_v256 = _mm256_set1_ps(scale); + const __m512 scale_v = _mm512_set1_ps(scale); + const __m256 scale_v256 = _mm256_set1_ps(scale); - if constexpr (iterations_32 > 0) { + if constexpr (iterations_32 > 0) { #pragma GCC unroll 4 - for (uint32_t i = 0; i < iterations_32; ++i) { - const __m512i source = mm512_loadu_elements_epi32(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi16_avx512vbmi_9to16(source); + for (u32 i = 0; i < iterations_32; ++i) { + const __m512i source = mm512_loadu_elements_epi32(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted1 = _mm512_cvtepi16_epi32(_mm512_castsi512_si256(unpacked)); - const __m512i converted2 = _mm512_cvtepi16_epi32(_mm512_extracti32x8_epi32(unpacked, 1)); + const __m512i converted1 = _mm512_cvtepi16_epi32(_mm512_castsi512_si256(unpacked)); + const __m512i converted2 = _mm512_cvtepi16_epi32(_mm512_extracti32x8_epi32(unpacked, 1)); - const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); - const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); + const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); + const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); - _mm512_storeu_ps(output, dequantized1); - _mm512_storeu_ps(output + 16, dequantized2); + _mm512_storeu_ps(output, dequantized1); + _mm512_storeu_ps(output + 16, dequantized2); - output += 32; - input += 4 * BIT_WIDTH; - } - } + output += 32; + input += 4 * BIT_WIDTH; + } + } - if constexpr (iterations_16 > 0) { - const __m256i source = mm256_loadu_elements_epi16(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi16_avx512vbmi_9to16(source); + if constexpr (iterations_16 > 0) { + const __m256i source = mm256_loadu_elements_epi16(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted = _mm512_cvtepi16_epi32(unpacked); - const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); + const __m512i converted = _mm512_cvtepi16_epi32(unpacked); + const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); - _mm512_storeu_ps(output, dequantized); + _mm512_storeu_ps(output, dequantized); - output += 16; - input += 2 * BIT_WIDTH; - } + output += 16; + input += 2 * BIT_WIDTH; + } - if constexpr (iterations_8 > 0) { - const __m128i source = mm_loadu_elements_epi8(BIT_WIDTH, input); - const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16(source); + if constexpr (iterations_8 > 0) { + const __m128i source = mm_loadu_elements_epi8(BIT_WIDTH, input); + const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m256i converted = _mm256_cvtepi16_epi32(unpacked); - const __m256 dequantized = mm256_dequantize_epi32(converted, scale_v256); + const __m256i converted = _mm256_cvtepi16_epi32(unpacked); + const __m256 dequantized = mm256_dequantize_epi32(converted, scale_v256); - _mm256_storeu_ps(output, dequantized); + _mm256_storeu_ps(output, dequantized); - output += 8; - input += BIT_WIDTH; - } + output += 8; + input += BIT_WIDTH; + } - if constexpr (remaining_elements > 0) { - const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16(source); + if constexpr (remaining_elements > 0) { + const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m256i converted = _mm256_cvtepi16_epi32(unpacked); - const __m256 dequantized = mm256_dequantize_epi32(converted, scale_v256); + const __m256i converted = _mm256_cvtepi16_epi32(unpacked); + const __m256 dequantized = mm256_dequantize_epi32(converted, scale_v256); - mm256_storeu_elements_ps(output, remaining_elements, dequantized); - } + mm256_storeu_elements_ps(output, remaining_elements, dequantized); + } - return 0; -} + return 0; + } -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_32 = elements_per_block / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; + constexpr u32 iterations_32 = elements_per_block / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = + elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - const __m512d scale_v = _mm512_set1_pd(scale); + const __m512d scale_v = _mm512_set1_pd(scale); - if constexpr (iterations_32 > 0) { + if constexpr (iterations_32 > 0) { #pragma GCC unroll 4 - for (uint32_t i = 0; i < iterations_32; ++i) { - const __m512i source = mm512_loadu_elements_epi32(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi16_avx512vbmi_9to16(source); - - const __m512i converted1 = _mm512_cvtepi16_epi64(_mm512_castsi512_si128(unpacked)); - const __m512i converted2 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 1)); - const __m512i converted3 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 2)); - const __m512i converted4 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 3)); - - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); - const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); - - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - _mm512_storeu_pd(output + 16, dequantized3); - _mm512_storeu_pd(output + 24, dequantized4); - - output += 32; - input += 4 * BIT_WIDTH; - } - } + for (u32 i = 0; i < iterations_32; ++i) { + const __m512i source = mm512_loadu_elements_epi32(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi16_epi64(_mm512_castsi512_si128(unpacked)); + const __m512i converted2 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 1)); + const __m512i converted3 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 2)); + const __m512i converted4 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 3)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); + const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + _mm512_storeu_pd(output + 16, dequantized3); + _mm512_storeu_pd(output + 24, dequantized4); + + output += 32; + input += 4 * BIT_WIDTH; + } + } - if constexpr (iterations_16 > 0) { - const __m256i source = mm256_loadu_elements_epi16(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi16_avx512vbmi_9to16(source); + if constexpr (iterations_16 > 0) { + const __m256i source = mm256_loadu_elements_epi16(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted1 = _mm512_cvtepi16_epi64(_mm256_castsi256_si128(unpacked)); - const __m512i converted2 = _mm512_cvtepi16_epi64(_mm256_extracti64x2_epi64(unpacked, 1)); + const __m512i converted1 = _mm512_cvtepi16_epi64(_mm256_castsi256_si128(unpacked)); + const __m512i converted2 = _mm512_cvtepi16_epi64(_mm256_extracti64x2_epi64(unpacked, 1)); - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); - output += 16; - input += 2 * BIT_WIDTH; - } + output += 16; + input += 2 * BIT_WIDTH; + } - if constexpr (iterations_8 > 0) { - const __m128i source = mm_loadu_elements_epi8(BIT_WIDTH, input); - const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16(source); + if constexpr (iterations_8 > 0) { + const __m128i source = mm_loadu_elements_epi8(BIT_WIDTH, input); + const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted = _mm512_cvtepi16_epi64(unpacked); + const __m512i converted = _mm512_cvtepi16_epi64(unpacked); - const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); + const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); - _mm512_storeu_pd(output, dequantized); + _mm512_storeu_pd(output, dequantized); - output += 8; - input += BIT_WIDTH; - } + output += 8; + input += BIT_WIDTH; + } - if constexpr (remaining_elements > 0) { - const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16(source); + if constexpr (remaining_elements > 0) { + const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted = _mm512_cvtepi16_epi64(unpacked); + const __m512i converted = _mm512_cvtepi16_epi64(unpacked); - const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); + const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); - mm512_storeu_elements_pd(output, remaining_elements, dequantized); - } + mm512_storeu_elements_pd(output, remaining_elements, dequantized); + } - return 0; -} + return 0; + } -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_16 = elements_per_block / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; + constexpr u32 iterations_16 = elements_per_block / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - const __m512 scale_v = _mm512_set1_ps(scale); - const __m256 scale_v256 = _mm256_set1_ps(scale); + const __m512 scale_v = _mm512_set1_ps(scale); + const __m256 scale_v256 = _mm256_set1_ps(scale); - if constexpr (iterations_16 > 0) { + if constexpr (iterations_16 > 0) { #pragma GCC unroll 2 - for (uint32_t i = 0; i < iterations_16; ++i) { - const __m512i source = mm512_loadu_elements_epi16(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi32_avx512vbmi_17to24(source); + for (u32 i = 0; i < iterations_16; ++i) { + const __m512i source = mm512_loadu_elements_epi16(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512 dequantized = mm512_dequantize_epi32(unpacked, scale_v); + const __m512 dequantized = mm512_dequantize_epi32(unpacked, scale_v); - _mm512_storeu_ps(output, dequantized); + _mm512_storeu_ps(output, dequantized); - output += 16; - input += 2 * BIT_WIDTH; - } - } + output += 16; + input += 2 * BIT_WIDTH; + } + } - if constexpr (iterations_8 > 0) { - const __m256i source = mm256_loadu_elements_epi8(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24(source); + if constexpr (iterations_8 > 0) { + const __m256i source = mm256_loadu_elements_epi8(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m256 dequantized = mm256_dequantize_epi32(unpacked, scale_v256); + const __m256 dequantized = mm256_dequantize_epi32(unpacked, scale_v256); - _mm256_storeu_ps(output, dequantized); + _mm256_storeu_ps(output, dequantized); - output += 8; - input += BIT_WIDTH; - } + output += 8; + input += BIT_WIDTH; + } - if constexpr (remaining_elements > 0) { - const __m256i source = mm256_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24(source); + if constexpr (remaining_elements > 0) { + const __m256i source = mm256_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m256 dequantized = mm256_dequantize_epi32(unpacked, scale_v256); + const __m256 dequantized = mm256_dequantize_epi32(unpacked, scale_v256); - mm256_storeu_elements_ps(output, remaining_elements, dequantized); - } + mm256_storeu_elements_ps(output, remaining_elements, dequantized); + } - return 0; -} + return 0; + } -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_16 = elements_per_block / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; + constexpr u32 iterations_16 = elements_per_block / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - const __m512d scale_v = _mm512_set1_pd(scale); + const __m512d scale_v = _mm512_set1_pd(scale); - if constexpr (iterations_16 > 0) { + if constexpr (iterations_16 > 0) { #pragma GCC unroll 2 - for (uint32_t i = 0; i < iterations_16; ++i) { - const __m512i source = mm512_loadu_elements_epi16(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi32_avx512vbmi_17to24(source); + for (u32 i = 0; i < iterations_16; ++i) { + const __m512i source = mm512_loadu_elements_epi16(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted1 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(unpacked)); - const __m512i converted2 = _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(unpacked, 1)); + const __m512i converted1 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(unpacked)); + const __m512i converted2 = _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(unpacked, 1)); - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); - output += 16; - input += 2 * BIT_WIDTH; - } - } + output += 16; + input += 2 * BIT_WIDTH; + } + } - if constexpr (iterations_8 > 0) { - const __m256i source = mm256_loadu_elements_epi8(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24(source); + if constexpr (iterations_8 > 0) { + const __m256i source = mm256_loadu_elements_epi8(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted = _mm512_cvtepi32_epi64(unpacked); + const __m512i converted = _mm512_cvtepi32_epi64(unpacked); - const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); + const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); - _mm512_storeu_pd(output, dequantized); + _mm512_storeu_pd(output, dequantized); - output += 8; - input += BIT_WIDTH; - } + output += 8; + input += BIT_WIDTH; + } - if constexpr (remaining_elements > 0) { - const __m256i source = mm256_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24(source); + if constexpr (remaining_elements > 0) { + const __m256i source = mm256_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); - const __m512i converted = _mm512_cvtepi32_epi64(unpacked); + const __m512i converted = _mm512_cvtepi32_epi64(unpacked); - const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); + const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); - mm512_storeu_elements_pd(output, remaining_elements, dequantized); - } + mm512_storeu_elements_pd(output, remaining_elements, dequantized); + } - return 0; -} -} // namespace internal + return 0; + } + } // namespace internal -/** + /** * @brief Decompress a single 512\-bit block using AVX-512 and AVX-512-VBMI instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -511,24 +559,27 @@ __always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __re * * @note This function requires AVX-512 and AVX-512-VBMI support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi(const void* __restrict__ input_ptr, const float_t scale, - void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::mm512_decompress_block_avx512vbmi_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::mm512_decompress_block_avx512vbmi_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::mm512_decompress_block_avx512vbmi_17to24(input, scale, output); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::mm512_decompress_block_avx512vbmi_1to8( + input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::mm512_decompress_block_avx512vbmi_9to16( + input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::mm512_decompress_block_avx512vbmi_17to24( + input, scale, output); + } + return 0; } - return 0; -} -/** + /** * @brief Decompress a single block to double values using AVX-512 and AVX-512-VBMI instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -538,24 +589,27 @@ __always_inline int mm512_decompress_block_avx512vbmi(const void* __restrict__ i * @param output pointer to the output buffer where decompressed double values will be stored. * @return int status code. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi(const void* __restrict__ input_ptr, const double_t scale, - void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::mm512_decompress_block_avx512vbmi_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::mm512_decompress_block_avx512vbmi_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::mm512_decompress_block_avx512vbmi_17to24(input, scale, output); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::mm512_decompress_block_avx512vbmi_1to8( + input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::mm512_decompress_block_avx512vbmi_9to16( + input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::mm512_decompress_block_avx512vbmi_17to24( + input, scale, output); + } + return 0; } - return 0; -} -/** + /** * @brief Decompress multiple 512\-bit blocks using AVX-512 and AVX-512-VBMI instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -569,26 +623,27 @@ __always_inline int mm512_decompress_block_avx512vbmi(const void* __restrict__ i * * @note This function requires AVX-512 and AVX-512-VBMI support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_decompress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_decompress_blocks_avx512vbmi(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f32 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + mm512_decompress_block_avx512vbmi(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } - return 0; -} + return 0; + } -/** + /** * @brief Decompress multiple blocks to double values using AVX-512 and AVX-512-VBMI instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -599,23 +654,24 @@ int mm512_decompress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const * @param blocks number of blocks to decompress. * @return int status code. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_decompress_blocks_avx512vbmi(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_decompress_blocks_avx512vbmi(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f64 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + mm512_decompress_block_avx512vbmi(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } - return 0; -} } // namespace pernix #endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/src/internal/pernix/x86/avx512vbmi/compat.h b/src/internal/pernix/x86/avx512vbmi/compat.h index 6918d66..df9ca4b 100644 --- a/src/internal/pernix/x86/avx512vbmi/compat.h +++ b/src/internal/pernix/x86/avx512vbmi/compat.h @@ -7,377 +7,377 @@ #include namespace pernix::internal { - static __always_inline __mmask8 element_mask8(const uint32_t e) { + static __always_inline __mmask8 element_mask8(const u32 e) { return static_cast<__mmask8>(e >= 8 ? 0xFFu : ((1u << e) - 1u)); } - static __always_inline __mmask16 element_mask16(const uint32_t e) { + static __always_inline __mmask16 element_mask16(const u32 e) { return static_cast<__mmask16>(e >= 16 ? 0xFFFFu : ((1u << e) - 1u)); } - static __always_inline __mmask32 element_mask32(const uint32_t e) { + static __always_inline __mmask32 element_mask32(const u32 e) { return e >= 32 ? 0xFFFFFFFFu : (1u << e) - 1u; } - static __always_inline __mmask64 element_mask64(const uint32_t e) { + static __always_inline __mmask64 element_mask64(const u32 e) { return e >= 64 ? 0xFFFFFFFFFFFFFFFFull : (1ull << e) - 1ull; } - static __always_inline __m512i mm512_loadu_elements_epi64(const uint32_t e, const void *mem_addr) { + static __always_inline __m512i mm512_loadu_elements_epi64(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); + std::memcpy(&a, mem_addr, e * sizeof(i64)); return a; #else return _mm512_maskz_loadu_epi64(element_mask8(e), mem_addr); #endif } - static __always_inline __m256i mm256_loadu_elements_epi64(const uint32_t e, const void *mem_addr) { + static __always_inline __m256i mm256_loadu_elements_epi64(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); + std::memcpy(&a, mem_addr, e * sizeof(i64)); return a; #else return _mm256_maskz_loadu_epi64(element_mask8(e), mem_addr); #endif } - static __always_inline __m128i mm_loadu_elements_epi64(const uint32_t e, const void *mem_addr) { + static __always_inline __m128i mm_loadu_elements_epi64(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); + std::memcpy(&a, mem_addr, e * sizeof(i64)); return a; #else return _mm_maskz_loadu_epi64(element_mask8(e), mem_addr); #endif } - static __always_inline __m512i mm512_loadu_elements_epi32(const uint32_t e, const void *mem_addr) { + static __always_inline __m512i mm512_loadu_elements_epi32(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); + std::memcpy(&a, mem_addr, e * sizeof(i32)); return a; #else return _mm512_maskz_loadu_epi32(element_mask16(e), mem_addr); #endif } - static __always_inline __m256i mm256_loadu_elements_epi32(const uint32_t e, const void *mem_addr) { + static __always_inline __m256i mm256_loadu_elements_epi32(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); + std::memcpy(&a, mem_addr, e * sizeof(i32)); return a; #else return _mm256_maskz_loadu_epi32(element_mask8(e), mem_addr); #endif } - static __always_inline __m128i mm_loadu_elements_epi32(const uint32_t e, const void *mem_addr) { + static __always_inline __m128i mm_loadu_elements_epi32(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); + std::memcpy(&a, mem_addr, e * sizeof(i32)); return a; #else return _mm_maskz_loadu_epi32(element_mask8(e), mem_addr); #endif } - static __always_inline __m512i mm512_loadu_elements_epi16(const uint32_t e, const void *mem_addr) { + static __always_inline __m512i mm512_loadu_elements_epi16(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); + std::memcpy(&a, mem_addr, e * sizeof(i16)); return a; #else return _mm512_maskz_loadu_epi16(element_mask32(e), mem_addr); #endif } - static __always_inline __m256i mm256_loadu_elements_epi16(const uint32_t e, const void *mem_addr) { + static __always_inline __m256i mm256_loadu_elements_epi16(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); + std::memcpy(&a, mem_addr, e * sizeof(i16)); return a; #else return _mm256_maskz_loadu_epi16(element_mask16(e), mem_addr); #endif } - static __always_inline __m128i mm_loadu_elements_epi16(const uint32_t e, const void *mem_addr) { + static __always_inline __m128i mm_loadu_elements_epi16(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); + std::memcpy(&a, mem_addr, e * sizeof(i16)); return a; #else return _mm_maskz_loadu_epi16(element_mask8(e), mem_addr); #endif } - static __always_inline __m512i mm512_loadu_elements_epi8(const uint32_t e, const void *mem_addr) { + static __always_inline __m512i mm512_loadu_elements_epi8(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); + std::memcpy(&a, mem_addr, e * sizeof(i8)); return a; #else return _mm512_maskz_loadu_epi8(element_mask64(e), mem_addr); #endif } - static __always_inline __m256i mm256_loadu_elements_epi8(const uint32_t e, const void *mem_addr) { + static __always_inline __m256i mm256_loadu_elements_epi8(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); + std::memcpy(&a, mem_addr, e * sizeof(i8)); return a; #else return _mm256_maskz_loadu_epi8(element_mask32(e), mem_addr); #endif } - static __always_inline __m128i mm_loadu_elements_epi8(const uint32_t e, const void *mem_addr) { + static __always_inline __m128i mm_loadu_elements_epi8(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); + std::memcpy(&a, mem_addr, e * sizeof(i8)); return a; #else return _mm_maskz_loadu_epi8(element_mask16(e), mem_addr); #endif } - static __always_inline void mm512_storeu_elements_epi64(void *mem_addr, const uint32_t e, const __m512i a) { + static __always_inline void mm512_storeu_elements_epi64(void *mem_addr, const u32 e, const __m512i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; + alignas(64) u8 bytes[64]; _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i64)); #else _mm512_mask_storeu_epi64(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm256_storeu_elements_epi64(void *mem_addr, const uint32_t e, const __m256i a) { + static __always_inline void mm256_storeu_elements_epi64(void *mem_addr, const u32 e, const __m256i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; + alignas(32) u8 bytes[32]; _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i64)); #else _mm256_mask_storeu_epi64(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm_storeu_elements_epi64(void *mem_addr, const uint32_t e, const __m128i a) { + static __always_inline void mm_storeu_elements_epi64(void *mem_addr, const u32 e, const __m128i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; + alignas(16) u8 bytes[16]; _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i64)); #else _mm_mask_storeu_epi64(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm512_storeu_elements_epi32(void *mem_addr, const uint32_t e, const __m512i a) { + static __always_inline void mm512_storeu_elements_epi32(void *mem_addr, const u32 e, const __m512i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; + alignas(64) u8 bytes[64]; _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i32)); #else _mm512_mask_storeu_epi32(mem_addr, element_mask16(e), a); #endif } - static __always_inline void mm256_storeu_elements_epi32(void *mem_addr, const uint32_t e, const __m256i a) { + static __always_inline void mm256_storeu_elements_epi32(void *mem_addr, const u32 e, const __m256i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; + alignas(32) u8 bytes[32]; _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i32)); #else _mm256_mask_storeu_epi32(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm_storeu_elements_epi32(void *mem_addr, const uint32_t e, const __m128i a) { + static __always_inline void mm_storeu_elements_epi32(void *mem_addr, const u32 e, const __m128i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; + alignas(16) u8 bytes[16]; _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i32)); #else _mm_mask_storeu_epi32(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm512_storeu_elements_epi16(void *mem_addr, const uint32_t e, const __m512i a) { + static __always_inline void mm512_storeu_elements_epi16(void *mem_addr, const u32 e, const __m512i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; + alignas(64) u8 bytes[64]; _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i16)); #else _mm512_mask_storeu_epi16(mem_addr, element_mask32(e), a); #endif } - static __always_inline void mm256_storeu_elements_epi16(void *mem_addr, const uint32_t e, const __m256i a) { + static __always_inline void mm256_storeu_elements_epi16(void *mem_addr, const u32 e, const __m256i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; + alignas(32) u8 bytes[32]; _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i16)); #else _mm256_mask_storeu_epi16(mem_addr, element_mask16(e), a); #endif } - static __always_inline void mm_storeu_elements_epi16(void *mem_addr, const uint32_t e, const __m128i a) { + static __always_inline void mm_storeu_elements_epi16(void *mem_addr, const u32 e, const __m128i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; + alignas(16) u8 bytes[16]; _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i16)); #else _mm_mask_storeu_epi16(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm512_storeu_elements_epi8(void *mem_addr, const uint32_t e, const __m512i a) { + static __always_inline void mm512_storeu_elements_epi8(void *mem_addr, const u32 e, const __m512i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; + alignas(64) u8 bytes[64]; _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i8)); #else _mm512_mask_storeu_epi8(mem_addr, element_mask64(e), a); #endif } - static __always_inline void mm256_storeu_elements_epi8(void *mem_addr, const uint32_t e, const __m256i a) { + static __always_inline void mm256_storeu_elements_epi8(void *mem_addr, const u32 e, const __m256i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; + alignas(32) u8 bytes[32]; _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i8)); #else _mm256_mask_storeu_epi8(mem_addr, element_mask32(e), a); #endif } - static __always_inline void mm_storeu_elements_epi8(void *mem_addr, const uint32_t e, const __m128i a) { + static __always_inline void mm_storeu_elements_epi8(void *mem_addr, const u32 e, const __m128i a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; + alignas(16) u8 bytes[16]; _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); + std::memcpy(mem_addr, bytes, e * sizeof(i8)); #else _mm_mask_storeu_epi8(mem_addr, element_mask16(e), a); #endif } - static __always_inline __m512 mm512_loadu_elements_ps(const uint32_t e, const void *mem_addr) { + static __always_inline __m512 mm512_loadu_elements_ps(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m512 a = _mm512_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); + std::memcpy(&a, mem_addr, e * sizeof(f32)); return a; #else return _mm512_maskz_loadu_ps(element_mask16(e), mem_addr); #endif } - static __always_inline __m256 mm256_loadu_elements_ps(const uint32_t e, const void *mem_addr) { + static __always_inline __m256 mm256_loadu_elements_ps(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m256 a = _mm256_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); + std::memcpy(&a, mem_addr, e * sizeof(f32)); return a; #else return _mm256_maskz_loadu_ps(element_mask8(e), mem_addr); #endif } - static __always_inline __m128 mm_loadu_elements_ps(const uint32_t e, const void *mem_addr) { + static __always_inline __m128 mm_loadu_elements_ps(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m128 a = _mm_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); + std::memcpy(&a, mem_addr, e * sizeof(f32)); return a; #else return _mm_maskz_loadu_ps(element_mask8(e), mem_addr); #endif } - static __always_inline __m512d mm512_loadu_elements_pd(const uint32_t e, const void *mem_addr) { + static __always_inline __m512d mm512_loadu_elements_pd(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m512d a = _mm512_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); + std::memcpy(&a, mem_addr, e * sizeof(f64)); return a; #else return _mm512_maskz_loadu_pd(element_mask8(e), mem_addr); #endif } - static __always_inline __m256d mm256_loadu_elements_pd(const uint32_t e, const void *mem_addr) { + static __always_inline __m256d mm256_loadu_elements_pd(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m256d a = _mm256_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); + std::memcpy(&a, mem_addr, e * sizeof(f64)); return a; #else return _mm256_maskz_loadu_pd(element_mask8(e), mem_addr); #endif } - static __always_inline __m128d mm_loadu_elements_pd(const uint32_t e, const void *mem_addr) { + static __always_inline __m128d mm_loadu_elements_pd(const u32 e, const void *mem_addr) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) __m128d a = _mm_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); + std::memcpy(&a, mem_addr, e * sizeof(f64)); return a; #else return _mm_maskz_loadu_pd(element_mask8(e), mem_addr); #endif } - static __always_inline void mm512_storeu_elements_ps(void *mem_addr, const uint32_t e, const __m512 a) { + static __always_inline void mm512_storeu_elements_ps(void *mem_addr, const u32 e, const __m512 a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) float_t values[16]; + alignas(64) f32 values[16]; _mm512_storeu_ps(values, a); - std::memcpy(mem_addr, values, e * sizeof(float_t)); + std::memcpy(mem_addr, values, e * sizeof(f32)); #else _mm512_mask_storeu_ps(mem_addr, element_mask16(e), a); #endif } - static __always_inline void mm256_storeu_elements_ps(void *mem_addr, const uint32_t e, const __m256 a) { + static __always_inline void mm256_storeu_elements_ps(void *mem_addr, const u32 e, const __m256 a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) float_t values[8]; + alignas(32) f32 values[8]; _mm256_storeu_ps(values, a); - std::memcpy(mem_addr, values, e * sizeof(float_t)); + std::memcpy(mem_addr, values, e * sizeof(f32)); #else _mm256_mask_storeu_ps(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm_storeu_elements_ps(void *mem_addr, const uint32_t e, const __m128 a) { + static __always_inline void mm_storeu_elements_ps(void *mem_addr, const u32 e, const __m128 a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) float_t values[4]; + alignas(16) f32 values[4]; _mm_storeu_ps(values, a); - std::memcpy(mem_addr, values, e * sizeof(float_t)); + std::memcpy(mem_addr, values, e * sizeof(f32)); #else _mm_mask_storeu_ps(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm512_storeu_elements_pd(void *mem_addr, const uint32_t e, const __m512d a) { + static __always_inline void mm512_storeu_elements_pd(void *mem_addr, const u32 e, const __m512d a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) double_t values[8]; + alignas(64) f64 values[8]; _mm512_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); + std::memcpy(mem_addr, values, e * sizeof(f64)); #else _mm512_mask_storeu_pd(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm256_storeu_elements_pd(void *mem_addr, const uint32_t e, const __m256d a) { + static __always_inline void mm256_storeu_elements_pd(void *mem_addr, const u32 e, const __m256d a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) double_t values[4]; + alignas(32) f64 values[4]; _mm256_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); + std::memcpy(mem_addr, values, e * sizeof(f64)); #else _mm256_mask_storeu_pd(mem_addr, element_mask8(e), a); #endif } - static __always_inline void mm_storeu_elements_pd(void *mem_addr, const uint32_t e, const __m128d a) { + static __always_inline void mm_storeu_elements_pd(void *mem_addr, const u32 e, const __m128d a) { #if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) double_t values[2]; + alignas(16) f64 values[2]; _mm_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); + std::memcpy(mem_addr, values, e * sizeof(f64)); #else _mm_mask_storeu_pd(mem_addr, element_mask8(e), a); #endif diff --git a/src/internal/pernix/x86/avx512vbmi/packing.h b/src/internal/pernix/x86/avx512vbmi/packing.h index e51c6cc..d4052eb 100644 --- a/src/internal/pernix/x86/avx512vbmi/packing.h +++ b/src/internal/pernix/x86/avx512vbmi/packing.h @@ -9,14 +9,14 @@ namespace pernix::internal { /** * @brief Pack 8 16-bit values for bit widths 9 through 16 using VBMI. */ - template + template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i &input) { if constexpr (BIT_WIDTH == 16) { return input; } else { using tables = pack_tables_avx512_16; - const __m128i maskv = _mm_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i maskv = _mm_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); const __m128i masked = _mm_and_si128(input, maskv); if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { @@ -46,17 +46,17 @@ __always_inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i &input) { /** * @brief Pack 16 8-bit values for bit widths 1 through 8 using VBMI. */ - template + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) __always_inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i &input) { if constexpr (BIT_WIDTH == 8) { return input; } else { - const __m128i maskv = _mm_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i maskv = _mm_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); const __m128i masked = _mm_and_si128(input, maskv); if constexpr (BIT_WIDTH == 1) { - return _mm_set1_epi16(static_cast(_mm_cmpgt_epi8_mask(masked, _mm_setzero_si128()))); + return _mm_set1_epi16(static_cast(_mm_cmpgt_epi8_mask(masked, _mm_setzero_si128()))); } else if constexpr (BIT_WIDTH == 2) { const __m128i shifted = _mm_srli_epi16(masked, 6); const __m128i combined = _mm_or_si128(masked, shifted); @@ -91,12 +91,12 @@ __always_inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i &input) { /** * @brief Pack 4 32-bit values for bit widths 17 through 24 using VBMI. */ - template + template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) __always_inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i &input) { using tables = pack_tables_avx512_24; - const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); const __m128i masked = _mm_and_si128(input, maskv); const __m128 permuted1 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute1()); @@ -115,14 +115,14 @@ __always_inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i &input) { /** * @brief Pack 16 16-bit values for bit widths 9 through 16 using VBMI. */ - template + template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i &input) { if constexpr (BIT_WIDTH == 16) { return input; } else { using tables = pack_tables_avx512_16; - const __m256i maskv = _mm256_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i maskv = _mm256_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); const __m256i masked = _mm256_and_si256(input, maskv); if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { @@ -152,18 +152,18 @@ __always_inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i &input) /** * @brief Pack 32 8-bit values for bit widths 1 through 8 using VBMI. */ - template + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) __always_inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i &input) { if constexpr (BIT_WIDTH == 8) { return input; } else { - const __m256i maskv = _mm256_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i maskv = _mm256_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); const __m256i masked = _mm256_and_si256(input, maskv); if constexpr (BIT_WIDTH == 1) { return _mm256_set1_epi32( - static_cast(_mm256_cmpgt_epi8_mask(masked, _mm256_setzero_si256()))); + static_cast(_mm256_cmpgt_epi8_mask(masked, _mm256_setzero_si256()))); } else if constexpr (BIT_WIDTH == 2) { const __m256i shifted = _mm256_srli_epi16(masked, 6); const __m256i combined = _mm256_or_si256(masked, shifted); @@ -199,12 +199,12 @@ __always_inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i &input) { /** * @brief Pack 8 32-bit values for bit widths 17 through 24 using VBMI. */ - template + template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) __always_inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i &input) { using tables = pack_tables_avx512_24; - const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); const __m256i masked = _mm256_and_si256(input, maskv); const __m256i permuted1 = _mm256_permutexvar_epi32(tables::get_permute1(), masked); @@ -223,14 +223,14 @@ __always_inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i &input) /** * @brief Pack 32 16-bit values for bit widths 9 through 16 using VBMI. */ - template + template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i &input) { if constexpr (BIT_WIDTH == 16) { return input; } else { using tables = pack_tables_avx512_16; - const __m512i maskv = _mm512_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i maskv = _mm512_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); const __m512i masked = _mm512_and_si512(input, maskv); if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { @@ -260,18 +260,18 @@ __always_inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i &input) /** * @brief Pack 64 8-bit values for bit widths 1 through 8 using VBMI. */ - template + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) __always_inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i &input) { if constexpr (BIT_WIDTH == 8) { return input; } else { - const __m512i maskv = _mm512_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i maskv = _mm512_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); const __m512i masked = _mm512_and_si512(input, maskv); if constexpr (BIT_WIDTH == 1) { return _mm512_set1_epi64( - static_cast(_mm512_cmpgt_epi8_mask(masked, _mm512_setzero_si512()))); + static_cast(_mm512_cmpgt_epi8_mask(masked, _mm512_setzero_si512()))); } else if constexpr (BIT_WIDTH == 2) { const __m512i shifted = _mm512_srli_epi16(masked, 6); const __m512i combined = _mm512_or_si512(masked, shifted); @@ -307,12 +307,12 @@ __always_inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i &input) { /** * @brief Pack 16 32-bit values for bit widths 17 through 24 using VBMI. */ - template + template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) __always_inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i &input) { using tables = pack_tables_avx512_24; - const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); const __m512i masked = _mm512_and_si512(input, maskv); const __m512i permuted1 = _mm512_permutexvar_epi32(tables::get_permute1(), masked); diff --git a/src/internal/pernix/x86/avx512vbmi/tables.h b/src/internal/pernix/x86/avx512vbmi/tables.h index ab3c665..d63a47a 100644 --- a/src/internal/pernix/x86/avx512vbmi/tables.h +++ b/src/internal/pernix/x86/avx512vbmi/tables.h @@ -9,25 +9,25 @@ #include namespace pernix::internal { - template - static __always_inline Vec load_table(const std::array &table) { - static_assert(sizeof(table) >= sizeof(Vec), "table is smaller than requested SIMD vector"); - if constexpr (std::is_same_v) { - return _mm512_load_si512(static_cast(table.data())); - } else if constexpr (std::is_same_v) { - return _mm256_load_si256(reinterpret_cast(table.data())); - } else { - return _mm_load_si128(reinterpret_cast(table.data())); - } +template +static __always_inline Vec load_table(const std::array& table) { + static_assert(sizeof(table) >= sizeof(Vec), "table is smaller than requested SIMD vector"); + if constexpr (std::is_same_v) { + return _mm512_load_si512(static_cast(table.data())); + } else if constexpr (std::is_same_v) { + return _mm256_load_si256(reinterpret_cast(table.data())); + } else { + return _mm_load_si128(reinterpret_cast(table.data())); } +} - template - requires(N >= 9 && N <= 15) - struct pack_tables_avx512_16 { - alignas(64) inline static constexpr std::array permute1 = [] { +template + requires(N >= 9 && N <= 15) +struct pack_tables_avx512_16 { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 0, 2, 4, 6, -1, 9, 11, 13, 15, 16, 18, 20, @@ -39,7 +39,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 0, 2, -1, 5, 7, 8, 10, -1, 13, 15, 16, 18, @@ -51,7 +51,7 @@ namespace pernix::internal { 45, 47, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 0, 2, 3, 5, 6, -1, 9, -1, 12, -1, 15, 16, @@ -63,7 +63,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 12) { - return std::array{ + return std::array{ 1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15, @@ -75,7 +75,7 @@ namespace pernix::internal { 38, 39, 41, 42 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 0, -1, 3, 4, 5, -1, -1, 9, 10, -1, -1, 14, @@ -87,7 +87,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 14) { - return std::array{ + return std::array{ 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, @@ -99,7 +99,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 15) { - return std::array{ + return std::array{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, @@ -111,14 +111,14 @@ namespace pernix::internal { 30, 31, -1, -1 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 1, 3, 5, 7, 8, 10, 12, 14, -1, 17, 19, 21, @@ -130,7 +130,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 1, 3, 4, 6, -1, 9, 11, 12, 14, -1, 17, 19, @@ -142,7 +142,7 @@ namespace pernix::internal { 46, -1, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 1, -1, 4, -1, 7, 8, 10, 11, 13, 14, -1, 17, @@ -154,7 +154,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 12) { - return std::array{ + return std::array{ 0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, @@ -166,7 +166,7 @@ namespace pernix::internal { 37, 38, 40, 41 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 1, 2, -1, -1, 6, 7, 8, -1, 11, 12, 13, -1, @@ -178,7 +178,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 14) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, @@ -190,7 +190,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 15) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -202,13 +202,13 @@ namespace pernix::internal { 29, 30, -1, -1 }; } - // clang-format on - }(); + // clang-format on + }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ -1, 1, 3, 5, 7, 8, 10, 12, 14, -1, 17, 19, @@ -220,7 +220,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ -1, 1, 3, 4, 6, -1, 9, 11, 12, 14, -1, 17, @@ -232,7 +232,7 @@ namespace pernix::internal { 44, 46, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ -1, 1, 2, 4, 5, 7, 8, 10, 11, 13, 14, -1, @@ -244,7 +244,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ -1, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, @@ -256,14 +256,14 @@ namespace pernix::internal { -1, -1, -1, -1 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 0, 2, 4, 6, 0, 1, 3, 5, 7, 0, 2, 4, @@ -275,7 +275,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 0, 4, 0, 2, 6, 0, 4, 0, 2, 6, 0, 4, @@ -287,7 +287,7 @@ namespace pernix::internal { 2, 6, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 0, 6, 1, 7, 2, 0, 3, 0, 4, 0, 5, 0, @@ -299,7 +299,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 12) { - return std::array{ + return std::array{ 12, 8, 4, 12, 8, 4, 12, 8, 4, 12, 8, 4, @@ -311,7 +311,7 @@ namespace pernix::internal { 8, 4, 12, 8 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 0, 0, 7, 4, 1, 0, 0, 5, 2, 0, 0, 6, @@ -323,7 +323,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 14) { - return std::array{ + return std::array{ 14, 12, 10, 8, 6, 4, 2, 14, 12, 10, 8, 6, @@ -335,7 +335,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 15) { - return std::array{ + return std::array{ 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, @@ -347,14 +347,14 @@ namespace pernix::internal { 2, 1, -1, -1 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 9, 11, 13, 15, 8, 10, 12, 14, 8, 9, 11, 13, @@ -366,7 +366,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 10, 14, 8, 12, 0, 10, 14, 8, 12, 0, 10, 14, @@ -378,7 +378,7 @@ namespace pernix::internal { 12, 0, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 11, 8, 12, 8, 13, 8, 14, 9, 15, 10, 8, 11, @@ -390,7 +390,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 12) { - return std::array{ + return std::array{ 0, 4, 8, 0, 4, 8, 0, 4, 8, 0, 4, 8, @@ -402,7 +402,7 @@ namespace pernix::internal { 4, 8, 0, 4 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 13, 10, 0, 0, 14, 11, 8, 0, 15, 12, 9, 0, @@ -414,7 +414,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 14) { - return std::array{ + return std::array{ 0, 2, 4, 6, 8, 10, 12, 0, 2, 4, 6, 8, @@ -426,7 +426,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 15) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -438,14 +438,14 @@ namespace pernix::internal { 13, 14, -1, -1 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - alignas(64) inline static constexpr std::array shift3 = [] { + alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 0, 7, 5, 3, 1, 8, 6, 4, 2, 0, 7, 5, @@ -457,7 +457,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 0, 6, 2, 8, 4, 0, 6, 2, 8, 4, 0, 6, @@ -469,7 +469,7 @@ namespace pernix::internal { 8, 4, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 0, 5, 10, 4, 9, 3, 8, 2, 7, 1, 6, 0, @@ -481,7 +481,7 @@ namespace pernix::internal { -1, -1, -1, -1 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 0, 3, 6, 9, 12, 2, 5, 8, 11, 1, 4, 7, @@ -493,11 +493,11 @@ namespace pernix::internal { -1, -1, -1, -1 }; } - return std::array{}; - // clang-format on - }(); + return std::array{}; + // clang-format on + }(); - inline static constexpr std::tuple<__mmask32, __mmask32, __mmask32> get_permute_masks() { + inline static constexpr std::tuple<__mmask32, __mmask32, __mmask32> get_permute_masks() { // clang-format off if constexpr (N == 9) { return { @@ -525,305 +525,305 @@ namespace pernix::internal { }; } return {0, 0, 0}; - // clang-format on - } - - static __always_inline Vec get_permute1() { return load_table(permute1); } - static __always_inline Vec get_permute2() { return load_table(permute2); } - static __always_inline Vec get_permute3() { return load_table(permute3); } + // clang-format on + } - static __always_inline Vec get_shift1() { return load_table(shift1); } - static __always_inline Vec get_shift2() { return load_table(shift2); } - static __always_inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } + + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && + (std::is_same_v || std::is_same_v || std::is_same_v)) +struct pack_tables_avx512_24 { +private: + struct word_plan { + i32 left_index1 = -1; + i32 left_index2 = -1; + i32 right_index = -1; + u32 left_shift1 = 32; + u32 left_shift2 = 32; + u32 right_shift = 32; }; - template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && - (std::is_same_v || std::is_same_v || std::is_same_v)) - struct pack_tables_avx512_24 { - private: - struct word_plan { - int32_t left_index1 = -1; - int32_t left_index2 = -1; - int32_t right_index = -1; - uint32_t left_shift1 = 32; - uint32_t left_shift2 = 32; - uint32_t right_shift = 32; - }; - - static constexpr word_plan create_plan(const uint32_t idx) { - word_plan plan{}; - - const uint32_t word_start = idx * 32u; - const uint32_t word_end = word_start + 32u; - - uint32_t left_slot = 0; - for (uint32_t input_lane = 0; input_lane < 16; ++input_lane) { - const uint32_t input_start = input_lane * BIT_WIDTH; - const uint32_t input_end = input_start + BIT_WIDTH; - - const uint32_t overlap_start = std::max(word_start, input_start); - const uint32_t overlap_end = std::min(word_end, input_end); - if (overlap_start >= overlap_end) { - continue; - } + static constexpr word_plan create_plan(const u32 idx) { + word_plan plan{}; - const auto output_bit = static_cast(overlap_start - word_start); - const auto input_bit = static_cast(overlap_start - input_start); - const int32_t delta = output_bit - input_bit; - - if (delta >= 0) { - if (left_slot == 0) { - plan.left_index1 = static_cast(input_lane); - plan.left_shift1 = static_cast(delta); - ++left_slot; - } else { - plan.left_index2 = static_cast(input_lane); - plan.left_shift2 = static_cast(delta); - } - } else { - plan.right_index = static_cast(input_lane); - plan.right_shift = static_cast(-delta); - } - } + const u32 word_start = idx * 32u; + const u32 word_end = word_start + 32u; - return plan; - } + u32 left_slot = 0; + for (u32 input_lane = 0; input_lane < 16; ++input_lane) { + const u32 input_start = input_lane * BIT_WIDTH; + const u32 input_end = input_start + BIT_WIDTH; - static constexpr std::array word_plans = [] { - std::array plans{}; - for (uint32_t i = 0; i < 16; ++i) { - plans[i] = create_plan(i); - } - return plans; - }(); - - template - static __always_inline constexpr std::array make_table(Getter getter) { - std::array values{}; - for (uint32_t i = 0; i < 16; ++i) { - values[i] = getter(word_plans[i]); + const u32 overlap_start = std::max(word_start, input_start); + const u32 overlap_end = std::min(word_end, input_end); + if (overlap_start >= overlap_end) { + continue; } - return values; - } - - alignas(64) static constexpr auto permute1 = make_table([](const word_plan &p) { - return p.left_index1; - }); - - alignas(64) static constexpr auto permute2 = make_table([](const word_plan &p) { - return p.left_index2; - }); - alignas(64) static constexpr auto permute3 = make_table([](const word_plan &p) { - return p.right_index; - }); + const auto output_bit = static_cast(overlap_start - word_start); + const auto input_bit = static_cast(overlap_start - input_start); + const i32 delta = output_bit - input_bit; - alignas(64) static constexpr auto shift1 = make_table( - [](const word_plan &p) { return p.left_shift1; }); - - alignas(64) static constexpr auto shift2 = make_table( - [](const word_plan &p) { return p.left_shift2; }); - - alignas(64) static constexpr auto shift3 = make_table( - [](const word_plan &p) { return p.right_shift; }); - - public: - static __always_inline Vec get_permute1() { return load_table(permute1); } - static __always_inline Vec get_permute2() { return load_table(permute2); } - static __always_inline Vec get_permute3() { return load_table(permute3); } - - static __always_inline Vec get_shift1() { return load_table(shift1); } - static __always_inline Vec get_shift2() { return load_table(shift2); } - static __always_inline Vec get_shift3() { return load_table(shift3); } - }; - - template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && - (std::is_same_v || std::is_same_v || std::is_same_v)) - struct unpack_tables_avx512_8 { - private: - alignas(64) inline static constexpr std::array permute1 = [] { - std::array table{}; - std::ranges::fill(table, -1); - for (size_t entry = 0; entry < 64; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t first_byte = bit_start / 8; - - table[entry] = static_cast(first_byte); + if (delta >= 0) { + if (left_slot == 0) { + plan.left_index1 = static_cast(input_lane); + plan.left_shift1 = static_cast(delta); + ++left_slot; + } else { + plan.left_index2 = static_cast(input_lane); + plan.left_shift2 = static_cast(delta); + } + } else { + plan.right_index = static_cast(input_lane); + plan.right_shift = static_cast(-delta); } + } - return table; - }(); - - alignas(64) inline static constexpr std::array permute2 = [] { - std::array table{}; - std::ranges::fill(table, -1); + return plan; + } - for (size_t entry = 0; entry < 64; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t first_byte = bit_start / 8; - const size_t bit_offset = bit_start % 8; + static constexpr std::array word_plans = [] { + std::array plans{}; + for (u32 i = 0; i < 16; ++i) { + plans[i] = create_plan(i); + } + return plans; + }(); + + template + static __always_inline constexpr std::array make_table(Getter getter) { + std::array values{}; + for (u32 i = 0; i < 16; ++i) { + values[i] = getter(word_plans[i]); + } + return values; + } - if (bit_offset + BIT_WIDTH > 8) { - table[entry] = static_cast(first_byte + 1); - } - } + alignas(64) static constexpr auto permute1 = make_table([](const word_plan& p) { + return p.left_index1; + }); + + alignas(64) static constexpr auto permute2 = make_table([](const word_plan& p) { + return p.left_index2; + }); + + alignas(64) static constexpr auto permute3 = make_table([](const word_plan& p) { + return p.right_index; + }); + + alignas(64) static constexpr auto shift1 = make_table( + [](const word_plan& p) { return p.left_shift1; }); + + alignas(64) static constexpr auto shift2 = make_table( + [](const word_plan& p) { return p.left_shift2; }); + + alignas(64) static constexpr auto shift3 = make_table( + [](const word_plan& p) { return p.right_shift; }); + +public: + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } + + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } +}; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && + (std::is_same_v || std::is_same_v || std::is_same_v)) +struct unpack_tables_avx512_8 { +private: + alignas(64) inline static constexpr std::array permute1 = [] { + std::array table{}; + std::ranges::fill(table, -1); + for (size_t entry = 0; entry < 64; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t first_byte = bit_start / 8; + + table[entry] = static_cast(first_byte); + } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array shift1 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array permute2 = [] { + std::array table{}; + std::ranges::fill(table, -1); - for (size_t entry = 0; entry < 64; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_offset = bit_start % 8; + for (size_t entry = 0; entry < 64; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t first_byte = bit_start / 8; + const size_t bit_offset = bit_start % 8; - table[entry] = static_cast(bit_offset); + if (bit_offset + BIT_WIDTH > 8) { + table[entry] = static_cast(first_byte + 1); } + } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array shift2 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift1 = [] { + std::array table{}; - for (size_t entry = 0; entry < 64; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_offset = bit_start % 8u; - const size_t spill_bits = (bit_offset + BIT_WIDTH > 8u) ? (bit_offset + BIT_WIDTH - 8u) : 0u; + for (size_t entry = 0; entry < 64; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_offset = bit_start % 8; - table[entry] = spill_bits ? static_cast(8 - bit_offset) : 0; - } + table[entry] = static_cast(bit_offset); + } - return table; - }(); + return table; + }(); - public: - static __always_inline Vec get_permute1() { return load_table(permute1); } - static __always_inline Vec get_permute2() { return load_table(permute2); } + alignas(64) inline static constexpr std::array shift2 = [] { + std::array table{}; - static __always_inline Vec get_shift1() { return load_table(shift1); } - static __always_inline Vec get_shift2() { return load_table(shift2); } - }; + for (size_t entry = 0; entry < 64; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_offset = bit_start % 8u; + const size_t spill_bits = (bit_offset + BIT_WIDTH > 8u) ? (bit_offset + BIT_WIDTH - 8u) : 0u; - template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && - (std::is_same_v || std::is_same_v || std::is_same_v)) - struct unpack_tables_avx512_16 { - private: - alignas(64) inline static constexpr std::array permute1 = [] { - std::array table{}; - std::ranges::fill(table, -1); - - for (size_t entry = 0; entry < 32; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t first_byte = bit_start / 8; - const size_t base = entry * 2; - - table[base] = static_cast(first_byte); - table[base + 1] = static_cast(first_byte + 1); - } + table[entry] = spill_bits ? static_cast(8 - bit_offset) : 0; + } - return table; - }(); + return table; + }(); + +public: + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && + (std::is_same_v || std::is_same_v || std::is_same_v)) +struct unpack_tables_avx512_16 { +private: + alignas(64) inline static constexpr std::array permute1 = [] { + std::array table{}; + std::ranges::fill(table, -1); + + for (size_t entry = 0; entry < 32; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t first_byte = bit_start / 8; + const size_t base = entry * 2; + + table[base] = static_cast(first_byte); + table[base + 1] = static_cast(first_byte + 1); + } - alignas(64) inline static constexpr std::array permute2 = [] { - std::array table{}; - std::ranges::fill(table, -1); + return table; + }(); - for (size_t entry = 0; entry < 32; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t first_byte = bit_start / 8; - const size_t bit_offset = bit_start % 8; - const size_t base = entry * 2; + alignas(64) inline static constexpr std::array permute2 = [] { + std::array table{}; + std::ranges::fill(table, -1); - if (bit_offset + BIT_WIDTH > 16) { - table[base] = static_cast(first_byte + 2); - } + for (size_t entry = 0; entry < 32; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t first_byte = bit_start / 8; + const size_t bit_offset = bit_start % 8; + const size_t base = entry * 2; + + if (bit_offset + BIT_WIDTH > 16) { + table[base] = static_cast(first_byte + 2); } + } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array shift1 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift1 = [] { + std::array table{}; - for (size_t entry = 0; entry < 32; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_offset = bit_start % 8u; + for (size_t entry = 0; entry < 32; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_offset = bit_start % 8u; - // Right-shift the 16-bit chunk so the value starts at bit 0. - table[entry] = static_cast(bit_offset); - } + // Right-shift the 16-bit chunk so the value starts at bit 0. + table[entry] = static_cast(bit_offset); + } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array shift2 = [] { - std::array table{}; - for (size_t entry = 0; entry < 32; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_offset = bit_start % 8u; - const size_t spill_bits = (bit_offset + BIT_WIDTH > 16u) ? (bit_offset + BIT_WIDTH - 16u) : 0u; + alignas(64) inline static constexpr std::array shift2 = [] { + std::array table{}; + for (size_t entry = 0; entry < 32; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_offset = bit_start % 8u; + const size_t spill_bits = (bit_offset + BIT_WIDTH > 16u) ? (bit_offset + BIT_WIDTH - 16u) : 0u; - // Move spill bits from byte3 to their final bit positions before merge. - table[entry] = spill_bits ? static_cast(16u - bit_offset) : 0; - } + // Move spill bits from byte3 to their final bit positions before merge. + table[entry] = spill_bits ? static_cast(16u - bit_offset) : 0; + } - return table; - }(); + return table; + }(); - public: - static __always_inline Vec get_permute1() { return load_table(permute1); } - static __always_inline Vec get_permute2() { return load_table(permute2); } +public: + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - static __always_inline Vec get_shift1() { return load_table(shift1); } - static __always_inline Vec get_shift2() { return load_table(shift2); } - }; + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } +}; - template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && - (std::is_same_v || std::is_same_v || std::is_same_v)) - struct unpack_tables_avx512_24 { - private: - alignas(64) inline static constexpr std::array permute = [] { - std::array table{}; - std::ranges::fill(table, -1); +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && + (std::is_same_v || std::is_same_v || std::is_same_v)) +struct unpack_tables_avx512_24 { +private: + alignas(64) inline static constexpr std::array permute = [] { + std::array table{}; + std::ranges::fill(table, -1); - for (size_t entry = 0; entry < 16; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - const size_t bit_end = bit_start + BIT_WIDTH - 1; + for (size_t entry = 0; entry < 16; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + const size_t bit_end = bit_start + BIT_WIDTH - 1; - const size_t first_byte = bit_start / 8; - const size_t last_byte = bit_end / 8; + const size_t first_byte = bit_start / 8; + const size_t last_byte = bit_end / 8; - const size_t base = entry * 4; + const size_t base = entry * 4; - for (size_t byte = first_byte; byte <= last_byte; ++byte) { - table[base + (byte - first_byte)] = static_cast(byte); - } + for (size_t byte = first_byte; byte <= last_byte; ++byte) { + table[base + (byte - first_byte)] = static_cast(byte); } + } - return table; - }(); + return table; + }(); - alignas(64) inline static constexpr std::array shift = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift = [] { + std::array table{}; - for (size_t entry = 0; entry < 16; ++entry) { - const size_t bit_start = entry * BIT_WIDTH; - table[entry] = static_cast(32u - BIT_WIDTH - (bit_start % 8u)); - } + for (size_t entry = 0; entry < 16; ++entry) { + const size_t bit_start = entry * BIT_WIDTH; + table[entry] = static_cast(32u - BIT_WIDTH - (bit_start % 8u)); + } - return table; - }(); + return table; + }(); - public: - static __always_inline Vec get_permute() { return load_table(permute); } - static __always_inline Vec get_shift() { return load_table(shift); } - }; +public: + static __always_inline Vec get_permute() { return load_table(permute); } + static __always_inline Vec get_shift() { return load_table(shift); } +}; } // namespace pernix::internal #endif // PERNIX_AVX512VBMI_TABLES_H diff --git a/src/internal/pernix/x86/avx512vbmi/unpacking.h b/src/internal/pernix/x86/avx512vbmi/unpacking.h index 6799cd2..e66f9ec 100644 --- a/src/internal/pernix/x86/avx512vbmi/unpacking.h +++ b/src/internal/pernix/x86/avx512vbmi/unpacking.h @@ -22,7 +22,7 @@ __always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i coun return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); } -__always_inline static __m128i _mm_slli_epi8(const __m128i a, const int8_t imm8) { +__always_inline static __m128i _mm_slli_epi8(const __m128i a, const i8 imm8) { return _mm_sllv_epi8(a, _mm_set1_epi8(imm8)); } @@ -37,7 +37,7 @@ __always_inline static __m128i _mm_srli_epi8(const __m128i a, const int imm8) { return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); } -__always_inline static __m128i _mm_srai_epi8(const __m128i a, const int8_t imm8) { +__always_inline static __m128i _mm_srai_epi8(const __m128i a, const i8 imm8) { const __m128i lo_mask = _mm_set1_epi16(0x00ff); const __m128i hi_mask = _mm_set1_epi16(0xff00); const __m128i shift = _mm_cvtsi32_si128(imm8); @@ -50,7 +50,7 @@ __always_inline static __m128i _mm_srai_epi8(const __m128i a, const int8_t imm8) return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); } - template + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) __always_inline __m128i mm_unpack_epi8_avx512vbmi_1to8(const __m128i &input) { if constexpr (BIT_WIDTH == 8) { @@ -107,7 +107,7 @@ __always_inline __m128i mm_unpack_epi8_avx512vbmi_1to8(const __m128i &input) { const __mmask16 spill_mask = _mm_cmpneq_epi8_mask(tables::get_shift2(), _mm_setzero_si128()); __m128i combined = _mm_or_si128(shifted1, _mm_maskz_mov_epi8(spill_mask, shifted2)); - constexpr uint32_t shift = 8 - BIT_WIDTH; + constexpr u32 shift = 8 - BIT_WIDTH; combined = _mm_slli_epi8(combined, shift); if (SIGN_VALUES) { combined = _mm_srai_epi8(combined, shift); @@ -120,7 +120,7 @@ __always_inline __m128i mm_unpack_epi8_avx512vbmi_1to8(const __m128i &input) { } } - template + template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline __m128i mm_unpack_epi16_avx512vbmi_9to16(const __m128i &input) { if constexpr (BIT_WIDTH == 16) { @@ -138,7 +138,7 @@ __always_inline __m128i mm_unpack_epi16_avx512vbmi_9to16(const __m128i &input) { shifted = _mm_or_si128(shifted, shifted2); } - constexpr uint32_t shift = 16 - BIT_WIDTH; + constexpr u32 shift = 16 - BIT_WIDTH; shifted = _mm_slli_epi16(shifted, shift); if (SIGN_VALUES) { shifted = _mm_srai_epi16(shifted, shift); @@ -150,14 +150,14 @@ __always_inline __m128i mm_unpack_epi16_avx512vbmi_9to16(const __m128i &input) { } } - template + template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) __always_inline __m128i mm_unpack_epi32_avx512vbmi_17to24(const __m128i &input) { using tables = unpack_tables_avx512_24; const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute(), input); - constexpr uint32_t shift = 32 - BIT_WIDTH; + constexpr u32 shift = 32 - BIT_WIDTH; __m128i shifted = _mm_sllv_epi32(permuted, tables::get_shift()); if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { shifted = _mm_srai_epi32(shifted, shift); @@ -186,11 +186,11 @@ __always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i c return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); } -__always_inline static __m256i _mm256_slli_epi8(const __m256i a, const int8_t imm8) { +__always_inline static __m256i _mm256_slli_epi8(const __m256i a, const i8 imm8) { return _mm256_sllv_epi8(a, _mm256_set1_epi8(imm8)); } -__always_inline static __m256i _mm256_srli_epi8(const __m256i a, const int8_t imm8) { +__always_inline static __m256i _mm256_srli_epi8(const __m256i a, const i8 imm8) { const __m256i lo_mask = _mm256_set1_epi16(0x00ff); const __m256i hi_mask = _mm256_set1_epi16(0xff00); const __m128i shift = _mm_cvtsi32_si128(imm8); @@ -201,7 +201,7 @@ __always_inline static __m256i _mm256_srli_epi8(const __m256i a, const int8_t im return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); } -__always_inline static __m256i _mm256_srai_epi8(const __m256i a, const int8_t imm8) { +__always_inline static __m256i _mm256_srai_epi8(const __m256i a, const i8 imm8) { const __m256i lo_mask = _mm256_set1_epi16(0x00ff); const __m256i hi_mask = _mm256_set1_epi16(0xff00); const __m128i shift = _mm_cvtsi32_si128(imm8); @@ -214,7 +214,7 @@ __always_inline static __m256i _mm256_srai_epi8(const __m256i a, const int8_t im return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); } - template + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) __always_inline __m256i mm256_unpack_epi8_avx512vbmi_1to8(const __m256i &input) { if constexpr (BIT_WIDTH == 8) { @@ -271,7 +271,7 @@ __always_inline __m256i mm256_unpack_epi8_avx512vbmi_1to8(const __m256i &input) const __mmask32 spill_mask = _mm256_cmpneq_epi8_mask(tables::get_shift2(), _mm256_setzero_si256()); __m256i combined = _mm256_or_si256(shifted1, _mm256_maskz_mov_epi8(spill_mask, shifted2)); - constexpr uint32_t shift = 8 - BIT_WIDTH; + constexpr u32 shift = 8 - BIT_WIDTH; combined = _mm256_slli_epi8(combined, shift); if (SIGN_VALUES) { combined = _mm256_srai_epi8(combined, shift); @@ -284,7 +284,7 @@ __always_inline __m256i mm256_unpack_epi8_avx512vbmi_1to8(const __m256i &input) } } - template + template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline __m256i mm256_unpack_epi16_avx512vbmi_9to16(const __m256i &input) { if constexpr (BIT_WIDTH == 16) { @@ -302,7 +302,7 @@ __always_inline __m256i mm256_unpack_epi16_avx512vbmi_9to16(const __m256i &input shifted = _mm256_or_si256(shifted, shifted2); } - constexpr uint32_t shift = 16 - BIT_WIDTH; + constexpr u32 shift = 16 - BIT_WIDTH; shifted = _mm256_slli_epi16(shifted, shift); if (SIGN_VALUES) { shifted = _mm256_srai_epi16(shifted, shift); @@ -314,14 +314,14 @@ __always_inline __m256i mm256_unpack_epi16_avx512vbmi_9to16(const __m256i &input } } - template + template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) __always_inline __m256i mm256_unpack_epi32_avx512vbmi_17to24(const __m256i &input) { using tables = unpack_tables_avx512_24; const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute(), input); - constexpr uint32_t shift = 32 - BIT_WIDTH; + constexpr u32 shift = 32 - BIT_WIDTH; __m256i shifted = _mm256_sllv_epi32(permuted, tables::get_shift()); if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { shifted = _mm256_srai_epi32(shifted, shift); @@ -350,11 +350,11 @@ __always_inline static __m512i _mm512_sllv_epi8(const __m512i a, const __m512i c return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half); } -__always_inline static __m512i _mm512_slli_epi8(const __m512i a, const int8_t imm8) { +__always_inline static __m512i _mm512_slli_epi8(const __m512i a, const i8 imm8) { return _mm512_sllv_epi8(a, _mm512_set1_epi8(imm8)); } -__always_inline static __m512i _mm512_srli_epi8(const __m512i a, const int8_t imm8) { +__always_inline static __m512i _mm512_srli_epi8(const __m512i a, const i8 imm8) { const __m512i lo_mask = _mm512_set1_epi16(0x00ff); const __m512i hi_mask = _mm512_set1_epi16(0xff00); const __m128i shift = _mm_cvtsi32_si128(imm8); @@ -365,7 +365,7 @@ __always_inline static __m512i _mm512_srli_epi8(const __m512i a, const int8_t im return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi); } -__always_inline static __m512i _mm512_srai_epi8(const __m512i a, const int8_t imm8) { +__always_inline static __m512i _mm512_srai_epi8(const __m512i a, const i8 imm8) { const __m512i lo_mask = _mm512_set1_epi16(0x00ff); const __m512i hi_mask = _mm512_set1_epi16(0xff00); const __m128i shift = _mm_cvtsi32_si128(imm8); @@ -378,7 +378,7 @@ __always_inline static __m512i _mm512_srai_epi8(const __m512i a, const int8_t im return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi); } - template + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) __always_inline __m512i mm512_unpack_epi8_avx512vbmi_1to8(const __m512i &input) { if constexpr (BIT_WIDTH == 8) { @@ -435,7 +435,7 @@ __always_inline __m512i mm512_unpack_epi8_avx512vbmi_1to8(const __m512i &input) const __mmask64 spill_mask = _mm512_cmpneq_epi8_mask(tables::get_shift2(), _mm512_setzero_si512()); __m512i combined = _mm512_or_si512(shifted1, _mm512_maskz_mov_epi8(spill_mask, shifted2)); - constexpr uint32_t shift = 8 - BIT_WIDTH; + constexpr u32 shift = 8 - BIT_WIDTH; combined = _mm512_slli_epi8(combined, shift); if (SIGN_VALUES) { combined = _mm512_srai_epi8(combined, shift); @@ -448,7 +448,7 @@ __always_inline __m512i mm512_unpack_epi8_avx512vbmi_1to8(const __m512i &input) } } - template + template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) __always_inline __m512i mm512_unpack_epi16_avx512vbmi_9to16(const __m512i &input) { if constexpr (BIT_WIDTH == 16) { @@ -465,7 +465,7 @@ __always_inline __m512i mm512_unpack_epi16_avx512vbmi_9to16(const __m512i &input shifted = _mm512_or_si512(shifted, shifted2); } - constexpr uint32_t shift = 16 - BIT_WIDTH; + constexpr u32 shift = 16 - BIT_WIDTH; shifted = _mm512_slli_epi16(shifted, shift); if (SIGN_VALUES) { shifted = _mm512_srai_epi16(shifted, shift); @@ -477,7 +477,7 @@ __always_inline __m512i mm512_unpack_epi16_avx512vbmi_9to16(const __m512i &input } } - template + template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) __always_inline __m512i mm512_unpack_epi32_avx512vbmi_17to24(const __m512i &input) { using tables = unpack_tables_avx512_24; @@ -485,7 +485,7 @@ __always_inline __m512i mm512_unpack_epi32_avx512vbmi_17to24(const __m512i &inpu const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute(), input); __m512i shifted = _mm512_sllv_epi32(permuted, tables::get_shift()); - constexpr uint32_t shift = 32 - BIT_WIDTH; + constexpr u32 shift = 32 - BIT_WIDTH; if constexpr (SIGN_VALUES) { shifted = _mm512_srai_epi32(shifted, shift); } else { diff --git a/src/internal/pernix/x86/bmi2/bmi2_compression.h b/src/internal/pernix/x86/bmi2/bmi2_compression.h index 2f2d362..e165c3a 100644 --- a/src/internal/pernix/x86/bmi2/bmi2_compression.h +++ b/src/internal/pernix/x86/bmi2/bmi2_compression.h @@ -10,147 +10,147 @@ #include namespace pernix { -namespace internal { -/** + namespace internal { + /** * @brief Build the masks and shift constants used by the BMI2 packers. * * @tparam BIT_WIDTH bit width per packed value. -* @return std::tuple mask tuple used by the BMI2 helpers. +* @return std::tuple mask tuple used by the BMI2 helpers. */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32) -static constexpr std::tuple pack_avx2_bmi2_constants() { - uint32_t mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; - uint64_t pext_mask; - uint16_t shift1 = BIT_WIDTH * 4; - uint16_t shift2 = 64 - shift1; - - if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 8) { - pext_mask = 0x0101010101010101ULL * mask; - } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { - pext_mask = 0x0001000100010001ULL * mask; - } else { - pext_mask = 0x0000000100000001ULL * mask; - } + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32) + static constexpr std::tuple pack_avx2_bmi2_constants() { + u32 mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; + u64 pext_mask; + u16 shift1 = BIT_WIDTH * 4; + u16 shift2 = 64 - shift1; + + if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 8) { + pext_mask = 0x0101010101010101ULL * mask; + } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { + pext_mask = 0x0001000100010001ULL * mask; + } else { + pext_mask = 0x0000000100000001ULL * mask; + } - return { - mask, - pext_mask, - shift1, - shift2, - }; -} + return { + mask, + pext_mask, + shift1, + shift2, + }; + } -/** + /** * @brief Pack four 32-bit values with BMI2 extract instructions. * * @tparam BIT_WIDTH bit width per packed value. * @param input SIMD register containing four quantized values. * @return __m128i packed bitstream in the low bytes of the result. */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32) -static inline auto mm_pack_epi32_bmi2(const __m128i& input) -> __m128i { - const auto [mask, pext_mask, shift1, shift2] = pack_avx2_bmi2_constants(); - - if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 16) { - const __m128i packed = _mm_packs_epi32(input, _mm_setzero_si128()); - const uint64_t value = _pext_u64(_mm_extract_epi64(packed, 0), pext_mask); - - const __m128i result = _mm_set_epi64x(0, value); - return result; - } else { - alignas(16) uint64_t values[2]; - values[0] = _pext_u64(_mm_extract_epi64(input, 0), pext_mask); - - const uint64_t temp_combined = _pext_u64(_mm_extract_epi64(input, 1), pext_mask); - values[1] = temp_combined >> shift2; - values[0] |= (temp_combined << shift1); - - const __m128i result = _mm_set_epi64x(static_cast(values[1]), static_cast(values[0])); - return result; - } -} + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32) + static inline auto mm_pack_epi32_bmi2(const __m128i &input) -> __m128i { + const auto [mask, pext_mask, shift1, shift2] = pack_avx2_bmi2_constants(); + + if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 16) { + const __m128i packed = _mm_packs_epi32(input, _mm_setzero_si128()); + const u64 value = _pext_u64(_mm_extract_epi64(packed, 0), pext_mask); -/** + const __m128i result = _mm_set_epi64x(0, value); + return result; + } else { + alignas(16) u64 values[2]; + values[0] = _pext_u64(_mm_extract_epi64(input, 0), pext_mask); + + const u64 temp_combined = _pext_u64(_mm_extract_epi64(input, 1), pext_mask); + values[1] = temp_combined >> shift2; + values[0] |= (temp_combined << shift1); + + const __m128i result = _mm_set_epi64x(static_cast(values[1]), static_cast(values[0])); + return result; + } + } + + /** * @brief Pack eight 32-bit values with BMI2 extract instructions. * * @tparam BIT_WIDTH bit width per packed value. * @param input SIMD register containing eight quantized values. * @return __m256i packed bitstream in the low bytes of the result. */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -static inline auto mm256_pack_epi32_bmi2(const __m256i& input) -> __m256i { - const auto [mask, pext_mask, shift1, shift2] = pack_avx2_bmi2_constants(); - - if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 8) { - const __m256i packed16 = _mm256_packs_epi32(input, _mm256_setzero_si256()); - const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); - const __m256i packed8 = _mm256_packs_epi16(permuted, _mm256_setzero_si256()); - const uint64_t value = _pext_u64(_mm256_extract_epi64(packed8, 0), pext_mask); - - const __m256i result = _mm256_setr_epi64x(static_cast(value), 0, 0, 0); - return result; - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - const __m256i packed16 = _mm256_packs_epi32(input, _mm256_setzero_si256()); - alignas(16) int64_t values[2] = {}; - values[0] = _pext_u64(_mm256_extract_epi64(packed16, 0), pext_mask); - - const uint64_t temp_combined = _pext_u64(_mm256_extract_epi64(packed16, 2), pext_mask); - values[1] = temp_combined >> shift2; - if constexpr (BIT_WIDTH != 16) { - values[0] |= static_cast(temp_combined << shift1); - } - - const __m256i result = _mm256_setr_epi64x(values[0], values[1], 0, 0); - return result; - } else { - constexpr uint32_t chunk_bits = BIT_WIDTH * 2; // bits extracted per 64-bit lane - static_assert(chunk_bits < 64); - - constexpr uint64_t chunk_mask = (chunk_bits == 64) ? ~uint64_t{0} : ((uint64_t{1} << chunk_bits) - 1); - - const uint64_t x0 = _pext_u64(_mm256_extract_epi64(input, 0), pext_mask) & chunk_mask; - const uint64_t x1 = _pext_u64(_mm256_extract_epi64(input, 1), pext_mask) & chunk_mask; - const uint64_t x2 = _pext_u64(_mm256_extract_epi64(input, 2), pext_mask) & chunk_mask; - const uint64_t x3 = _pext_u64(_mm256_extract_epi64(input, 3), pext_mask) & chunk_mask; - - uint64_t out0 = 0; - uint64_t out1 = 0; - uint64_t out2 = 0; - - auto append_bits = [&](uint64_t value, uint32_t bit_offset) { - const uint32_t word = bit_offset >> 6; // / 64 - const uint32_t off = bit_offset & 63; // % 64 - - if (word == 0) { - out0 |= value << off; - if (off + chunk_bits > 64) { - out1 |= value >> (64 - off); - } - } else if (word == 1) { - out1 |= value << off; - if (off + chunk_bits > 64) { - out2 |= value >> (64 - off); + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) + static inline auto mm256_pack_epi32_bmi2(const __m256i &input) -> __m256i { + const auto [mask, pext_mask, shift1, shift2] = pack_avx2_bmi2_constants(); + + if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 8) { + const __m256i packed16 = _mm256_packs_epi32(input, _mm256_setzero_si256()); + const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i packed8 = _mm256_packs_epi16(permuted, _mm256_setzero_si256()); + const u64 value = _pext_u64(_mm256_extract_epi64(packed8, 0), pext_mask); + + const __m256i result = _mm256_setr_epi64x(static_cast(value), 0, 0, 0); + return result; + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + const __m256i packed16 = _mm256_packs_epi32(input, _mm256_setzero_si256()); + alignas(16) i64 values[2] = {}; + values[0] = _pext_u64(_mm256_extract_epi64(packed16, 0), pext_mask); + + const u64 temp_combined = _pext_u64(_mm256_extract_epi64(packed16, 2), pext_mask); + values[1] = temp_combined >> shift2; + if constexpr (BIT_WIDTH != 16) { + values[0] |= static_cast(temp_combined << shift1); } + + const __m256i result = _mm256_setr_epi64x(values[0], values[1], 0, 0); + return result; } else { - out2 |= value << off; + constexpr u32 chunk_bits = BIT_WIDTH * 2; // bits extracted per 64-bit lane + static_assert(chunk_bits < 64); + + constexpr u64 chunk_mask = (chunk_bits == 64) ? ~u64{0} : ((u64{1} << chunk_bits) - 1); + + const u64 x0 = _pext_u64(_mm256_extract_epi64(input, 0), pext_mask) & chunk_mask; + const u64 x1 = _pext_u64(_mm256_extract_epi64(input, 1), pext_mask) & chunk_mask; + const u64 x2 = _pext_u64(_mm256_extract_epi64(input, 2), pext_mask) & chunk_mask; + const u64 x3 = _pext_u64(_mm256_extract_epi64(input, 3), pext_mask) & chunk_mask; + + u64 out0 = 0; + u64 out1 = 0; + u64 out2 = 0; + + auto append_bits = [&](u64 value, u32 bit_offset) { + const u32 word = bit_offset >> 6; // / 64 + const u32 off = bit_offset & 63; // % 64 + + if (word == 0) { + out0 |= value << off; + if (off + chunk_bits > 64) { + out1 |= value >> (64 - off); + } + } else if (word == 1) { + out1 |= value << off; + if (off + chunk_bits > 64) { + out2 |= value >> (64 - off); + } + } else { + out2 |= value << off; + } + }; + + append_bits(x0, 0 * chunk_bits); + append_bits(x1, 1 * chunk_bits); + append_bits(x2, 2 * chunk_bits); + append_bits(x3, 3 * chunk_bits); + + return _mm256_setr_epi64x(static_cast(out0), static_cast(out1), + static_cast(out2), 0); } - }; - - append_bits(x0, 0 * chunk_bits); - append_bits(x1, 1 * chunk_bits); - append_bits(x2, 2 * chunk_bits); - append_bits(x3, 3 * chunk_bits); - - return _mm256_setr_epi64x(static_cast(out0), static_cast(out1), - static_cast(out2), 0); - } -} -} // namespace internal + } + } // namespace internal -/** + /** * @brief Compress a single 512-bit block using AVX2 and BMI2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -162,47 +162,47 @@ static inline auto mm256_pack_epi32_bmi2(const __m256i& input) -> __m256i { * * @note This function requires AVX2 and BMI2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_bmi2(const void* __restrict__ input_ptr, const float_t scale, - void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_compress_block_bmi2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; - std::memset(output, 0, BLOCK_SIZE); + std::memset(output, 0, BLOCK_SIZE); - const __m256 scale_v = _mm256_set1_ps(scale); + const __m256 scale_v = _mm256_set1_ps(scale); #pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256 source = _mm256_loadu_ps(input); - const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v); - const __m256i packed_input = internal::mm256_clamp_signed_epi32(quantized); - const __m256i packed = internal::mm256_pack_epi32_bmi2(packed_input); - std::memcpy(output, &packed, BIT_WIDTH); - input += 8; - output += BIT_WIDTH; - } + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256 source = _mm256_loadu_ps(input); + const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v); + const __m256i packed_input = internal::mm256_clamp_signed_epi32 < BIT_WIDTH > (quantized); + const __m256i packed = internal::mm256_pack_epi32_bmi2(packed_input); + std::memcpy(output, &packed, BIT_WIDTH); + input += 8; + output += BIT_WIDTH; + } - if constexpr (remaining) { - std::vector block_values(remaining); + if constexpr (remaining) { + std::vector block_values(remaining); #pragma GCC unroll 8 - for (uint32_t i = 0; i < remaining; i++) { - block_values[i] = - static_cast(internal::clamp_signed_quantized( - internal::quantize_ps_epi32(input[i], scale))); + for (u32 i = 0; i < remaining; i++) { + block_values[i] = + static_cast(internal::clamp_signed_quantized < BIT_WIDTH > ( + internal::quantize_ps_epi32(input[i], scale))); + } + + internal::pack_epi32_fallback < BIT_WIDTH > (block_values, output); } - internal::pack_epi32_fallback(block_values, output); + return 0; } - return 0; -} - -/** + /** * @brief Compress a single block of double values using AVX2 and BMI2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -213,50 +213,50 @@ int mm256_compress_block_bmi2(const void* __restrict__ input_ptr, const float_t * * @note This function requires AVX2 and BMI2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_bmi2(const void* __restrict__ input_ptr, const double_t scale, - void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_compress_block_bmi2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; - std::memset(output, 0, BLOCK_SIZE); + std::memset(output, 0, BLOCK_SIZE); - const __m256d scale_v = _mm256_set1_pd(scale); + const __m256d scale_v = _mm256_set1_pd(scale); #pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256d source1 = _mm256_loadu_pd(input); - const __m256d source2 = _mm256_loadu_pd(input + 4); - const __m128i quantized1 = internal::mm256_quantize_pd_epi32(source1, scale_v); - const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v); - __m256i combined = _mm256_castsi128_si256(quantized1); - combined = _mm256_inserti128_si256(combined, quantized2, 1); - const __m256i packed = internal::mm256_pack_epi32_bmi2( - internal::mm256_clamp_signed_epi32(combined)); - std::memcpy(output, &packed, BIT_WIDTH); - input += 8; - output += BIT_WIDTH; - } + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256d source1 = _mm256_loadu_pd(input); + const __m256d source2 = _mm256_loadu_pd(input + 4); + const __m128i quantized1 = internal::mm256_quantize_pd_epi32(source1, scale_v); + const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v); + __m256i combined = _mm256_castsi128_si256(quantized1); + combined = _mm256_inserti128_si256(combined, quantized2, 1); + const __m256i packed = internal::mm256_pack_epi32_bmi2( + internal::mm256_clamp_signed_epi32 < BIT_WIDTH > (combined)); + std::memcpy(output, &packed, BIT_WIDTH); + input += 8; + output += BIT_WIDTH; + } - if constexpr (remaining) { - std::vector block_values(remaining); + if constexpr (remaining) { + std::vector block_values(remaining); #pragma GCC unroll 8 - for (uint32_t i = 0; i < remaining; i++) { - block_values[i] = - static_cast(internal::clamp_signed_quantized( - internal::quantize_pd_epi64(input[i], scale))); - } + for (u32 i = 0; i < remaining; i++) { + block_values[i] = + static_cast(internal::clamp_signed_quantized < BIT_WIDTH > ( + internal::quantize_pd_epi64(input[i], scale))); + } - internal::pack_epi32_fallback(block_values, output); + internal::pack_epi32_fallback < BIT_WIDTH > (block_values, output); + } + return 0; } - return 0; -} -/** + /** * @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -269,26 +269,26 @@ int mm256_compress_block_bmi2(const void* __restrict__ input_ptr, const double_t * * @note This function requires AVX2 and BMI2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_bmi2(const void* __restrict__ input_ptr, const float_t scale, - void* __restrict__ output_ptr, const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const float_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_compress_block_bmi2(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_compress_blocks_bmi2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f32 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_compress_block_bmi2(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } - return 0; -} + return 0; + } -/** + /** * @brief Compress multiple blocks of double values using AVX2 and BMI2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -300,24 +300,24 @@ int mm256_compress_blocks_bmi2(const void* __restrict__ input_ptr, const float_t * * @note This function requires AVX2 and BMI2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_bmi2(const void* __restrict__ input_ptr, const double_t scale, - void* __restrict__ output_ptr, const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const double_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_compress_block_bmi2(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_compress_blocks_bmi2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f64 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_compress_block_bmi2(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } - return 0; -} + return 0; + } } // namespace pernix #endif // PERNIX_BMI2_COMPRESSION_H diff --git a/src/internal/pernix/x86/bmi2/bmi2_decompression.h b/src/internal/pernix/x86/bmi2/bmi2_decompression.h index 414d6e0..275a11a 100644 --- a/src/internal/pernix/x86/bmi2/bmi2_decompression.h +++ b/src/internal/pernix/x86/bmi2/bmi2_decompression.h @@ -10,48 +10,48 @@ #include namespace pernix { -namespace internal { -/** + namespace internal { + /** * @brief Sign-extend packed values after BMI2 expansion into 32-bit lanes. * * @tparam BIT_WIDTH original encoded bit width. * @param source register containing unpacked values. * @return __m128i sign-extended values. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__m128i mm_sign_extend32(__m128i source) { - if constexpr (BIT_WIDTH == 1) { - // Keep 1-bit values as 0/1 to match fallback decoding semantics. - return source; - } - - constexpr uint16_t shift = 32 - BIT_WIDTH; - source = _mm_slli_epi32(source, shift); - return _mm_srai_epi32(source, shift); -} + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __m128i mm_sign_extend32(__m128i source) { + if constexpr (BIT_WIDTH == 1) { + // Keep 1-bit values as 0/1 to match fallback decoding semantics. + return source; + } + + constexpr u16 shift = 32 - BIT_WIDTH; + source = _mm_slli_epi32(source, shift); + return _mm_srai_epi32(source, shift); + } -/** + /** * @brief Sign-extend packed values after BMI2 expansion into eight 32-bit lanes. * * @tparam BIT_WIDTH original encoded bit width. * @param source register containing unpacked values. * @return __m256i sign-extended values. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__m256i mm256_sign_extend32(__m256i source) { - if constexpr (BIT_WIDTH == 1) { - // Keep 1-bit values as 0/1 to match fallback decoding semantics. - return source; - } - - constexpr uint16_t shift = 32 - BIT_WIDTH; - source = _mm256_slli_epi32(source, shift); - return _mm256_srai_epi32(source, shift); -} + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __m256i mm256_sign_extend32(__m256i source) { + if constexpr (BIT_WIDTH == 1) { + // Keep 1-bit values as 0/1 to match fallback decoding semantics. + return source; + } + + constexpr u16 shift = 32 - BIT_WIDTH; + source = _mm256_slli_epi32(source, shift); + return _mm256_srai_epi32(source, shift); + } -/** + /** * @brief Unpack four values from a BMI2-packed input buffer. * * @tparam BIT_WIDTH bit width per packed value. @@ -59,58 +59,58 @@ __m256i mm256_sign_extend32(__m256i source) { * @param input pointer to the packed input buffer. * @return __m128i unpacked values. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH > 0 && BIT_WIDTH <= 24) -__m128i mm_unpack_epi32_bmi2(const uint8_t* __restrict__ input) { - constexpr uint32_t mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; - constexpr std::size_t packed_bytes = (4 * BIT_WIDTH + 7) / 8; - - __m128i result; - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - constexpr uint64_t pdep_mask = 0x0101010101010101ULL * mask; - - uint32_t temp_value = 0; - std::memcpy(&temp_value, input, packed_bytes); - - const int32_t value = _pdep_u32(temp_value, static_cast(pdep_mask)); - const __m128i source = _mm_insert_epi32(_mm_setzero_si128(), value, 0); - - result = _mm_cvtepi8_epi32(source); - } else if constexpr (BIT_WIDTH == 16) { - const __m128i source = _mm_loadu_si64(input); - result = _mm_cvtepi16_epi32(source); - } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { - constexpr uint64_t pdep_mask = 0x0001000100010001ULL * mask; - - uint64_t temp_value = 0; - std::memcpy(&temp_value, input, packed_bytes); - - const int64_t value = _pdep_u64(temp_value, pdep_mask); - const __m128i source = _mm_insert_epi64(_mm_setzero_si128(), value, 0); - - result = _mm_cvtepi16_epi32(source); - } else { - constexpr uint64_t pdep_mask = 0x0000000100000001ULL * mask; - constexpr uint32_t shift1 = BIT_WIDTH * 2; - constexpr uint32_t shift2 = 64 - shift1; - - alignas(16) uint64_t temp_values[2]{}; - std::memcpy(temp_values, input, packed_bytes); - - alignas(16) int64_t values[2]; - values[0] = _pdep_u64(temp_values[0], pdep_mask); - values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); - - result = _mm_set_epi64x(values[1], values[0]); - } - - if constexpr (SIGN_VALUES) { - result = internal::mm_sign_extend32(result); - } - return result; -} + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH > 0 && BIT_WIDTH <= 24) + __m128i mm_unpack_epi32_bmi2(const u8 * __restrict__ input) { + constexpr u32 mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; + constexpr std::size_t packed_bytes = (4 * BIT_WIDTH + 7) / 8; + + __m128i result; + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + constexpr u64 pdep_mask = 0x0101010101010101ULL * mask; + + u32 temp_value = 0; + std::memcpy(&temp_value, input, packed_bytes); + + const i32 value = _pdep_u32(temp_value, static_cast(pdep_mask)); + const __m128i source = _mm_insert_epi32(_mm_setzero_si128(), value, 0); + + result = _mm_cvtepi8_epi32(source); + } else if constexpr (BIT_WIDTH == 16) { + const __m128i source = _mm_loadu_si64(input); + result = _mm_cvtepi16_epi32(source); + } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { + constexpr u64 pdep_mask = 0x0001000100010001ULL * mask; + + u64 temp_value = 0; + std::memcpy(&temp_value, input, packed_bytes); + + const i64 value = _pdep_u64(temp_value, pdep_mask); + const __m128i source = _mm_insert_epi64(_mm_setzero_si128(), value, 0); + + result = _mm_cvtepi16_epi32(source); + } else { + constexpr u64 pdep_mask = 0x0000000100000001ULL * mask; + constexpr u32 shift1 = BIT_WIDTH * 2; + constexpr u32 shift2 = 64 - shift1; + + alignas(16) u64 temp_values[2]{}; + std::memcpy(temp_values, input, packed_bytes); + + alignas(16) i64 values[2]; + values[0] = _pdep_u64(temp_values[0], pdep_mask); + values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); + + result = _mm_set_epi64x(values[1], values[0]); + } + + if constexpr (SIGN_VALUES) { + result = internal::mm_sign_extend32(result); + } + return result; + } -/** + /** * @brief Unpack eight values from a BMI2-packed input buffer. * * @tparam BIT_WIDTH bit width per packed value. @@ -118,80 +118,82 @@ __m128i mm_unpack_epi32_bmi2(const uint8_t* __restrict__ input) { * @param input pointer to the packed input buffer. * @return __m256i unpacked values. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__m256i mm256_unpack_epi32_bmi2(const uint8_t* __restrict__ input) { - constexpr uint32_t mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; - constexpr std::size_t packed_bytes = BIT_WIDTH; - - __m256i result; - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - constexpr uint64_t pdep_mask = 0x0101010101010101ULL * mask; - - uint64_t temp_value = 0; - std::memcpy(&temp_value, input, packed_bytes); - - const int64_t value = _pdep_u64(temp_value, pdep_mask); - const __m128i source = _mm_insert_epi64(_mm_setzero_si128(), value, 0); - - result = _mm256_cvtepi8_epi32(source); - } else if constexpr (BIT_WIDTH == 16) { - const __m128i source = _mm_loadu_si128(reinterpret_cast(input)); - result = _mm256_cvtepi16_epi32(source); - } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { - constexpr uint64_t pdep_mask = 0x0001000100010001ULL * mask; - constexpr uint64_t shift1 = BIT_WIDTH * 4; - constexpr uint64_t shift2 = 64 - shift1; - - alignas(16) uint64_t temp_values[2]{}; - std::memcpy(temp_values, input, packed_bytes); - - alignas(16) int64_t values[2]; - values[0] = _pdep_u64(temp_values[0], pdep_mask); - values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); - - const __m128i source = _mm_set_epi64x(values[1], values[0]); - result = _mm256_cvtepi16_epi32(source); - } else { - constexpr uint64_t pdep_mask = 0x0000000100000001ULL * mask; - constexpr uint32_t shift1 = BIT_WIDTH * 2; - constexpr uint32_t shift2 = 64 - shift1; - - alignas(16) uint64_t temp_values[4]{}; - std::memcpy(temp_values, input, packed_bytes); - - if constexpr ((BIT_WIDTH % 2) == 0) { - std::memcpy(temp_values + 2, input + BIT_WIDTH / 2, packed_bytes - BIT_WIDTH / 2); - } else { - constexpr uint32_t second_group_bit_offset = BIT_WIDTH * 4; - constexpr uint32_t second_group_byte_offset = second_group_bit_offset / 8; - constexpr uint32_t second_group_shift = second_group_bit_offset % 8; - - alignas(16) uint64_t raw_values[2]{}; - std::memcpy(raw_values, input + second_group_byte_offset, packed_bytes - second_group_byte_offset); - - temp_values[2] = (raw_values[0] >> second_group_shift) | (raw_values[1] << (64 - second_group_shift)); - temp_values[3] = raw_values[1] >> second_group_shift; + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __m256i mm256_unpack_epi32_bmi2(const u8 * __restrict__ input) { + constexpr u32 mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; + constexpr std::size_t packed_bytes = BIT_WIDTH; + + __m256i result; + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + constexpr u64 pdep_mask = 0x0101010101010101ULL * mask; + + u64 temp_value = 0; + std::memcpy(&temp_value, input, packed_bytes); + + const i64 value = _pdep_u64(temp_value, pdep_mask); + const __m128i source = _mm_insert_epi64(_mm_setzero_si128(), value, 0); + + result = _mm256_cvtepi8_epi32(source); + } else if constexpr (BIT_WIDTH == 16) { + const __m128i source = _mm_loadu_si128(reinterpret_cast(input)); + result = _mm256_cvtepi16_epi32(source); + } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { + constexpr u64 pdep_mask = 0x0001000100010001ULL * mask; + constexpr u64 shift1 = BIT_WIDTH * 4; + constexpr u64 shift2 = 64 - shift1; + + alignas(16) u64 temp_values[2]{}; + std::memcpy(temp_values, input, packed_bytes); + + alignas(16) i64 values[2]; + values[0] = _pdep_u64(temp_values[0], pdep_mask); + values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); + + const __m128i source = _mm_set_epi64x(values[1], values[0]); + result = _mm256_cvtepi16_epi32(source); + } else { + constexpr u64 pdep_mask = 0x0000000100000001ULL * mask; + constexpr u32 shift1 = BIT_WIDTH * 2; + constexpr u32 shift2 = 64 - shift1; + + alignas(16) u64 temp_values[4]{}; + std::memcpy(temp_values, input, packed_bytes); + + if constexpr ((BIT_WIDTH % 2) == 0) { + std::memcpy(temp_values + 2, input + BIT_WIDTH / 2, packed_bytes - BIT_WIDTH / 2); + } else { + constexpr u32 second_group_bit_offset = BIT_WIDTH * 4; + constexpr u32 second_group_byte_offset = second_group_bit_offset / 8; + constexpr u32 second_group_shift = second_group_bit_offset % 8; + + alignas(16) u64 raw_values[2]{}; + std::memcpy(raw_values, input + second_group_byte_offset, packed_bytes - second_group_byte_offset); + + temp_values[2] = (raw_values[0] >> second_group_shift) | ( + raw_values[1] << (64 - second_group_shift)); + temp_values[3] = raw_values[1] >> second_group_shift; + } + + alignas(16) u64 values[4]; + values[0] = _pdep_u64((temp_values[0]), pdep_mask); + values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); + values[2] = _pdep_u64((temp_values[2]), pdep_mask); + values[3] = _pdep_u64((temp_values[2] >> shift1) | (temp_values[3] << shift2), pdep_mask); + + result = _mm256_set_epi64x(static_cast(values[3]), static_cast(values[2]), + static_cast(values[1]), + static_cast(values[0])); + } + + if constexpr (SIGN_VALUES) { + result = internal::mm256_sign_extend32(result); + } + return result; } + } // namespace internal - alignas(16) uint64_t values[4]; - values[0] = _pdep_u64((temp_values[0]), pdep_mask); - values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); - values[2] = _pdep_u64((temp_values[2]), pdep_mask); - values[3] = _pdep_u64((temp_values[2] >> shift1) | (temp_values[3] << shift2), pdep_mask); - - result = _mm256_set_epi64x(static_cast(values[3]), static_cast(values[2]), static_cast(values[1]), - static_cast(values[0])); - } - - if constexpr (SIGN_VALUES) { - result = internal::mm256_sign_extend32(result); - } - return result; -} -} // namespace internal - -/** + /** * @brief Decompress a single 512\-bit block using AVX2 and BMI2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 16). @@ -205,37 +207,40 @@ __m256i mm256_unpack_epi32_bmi2(const uint8_t* __restrict__ input) { * * @note This function requires AVX2 and BMI2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_bmi2(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - - const __m256 scale_v = _mm256_set1_ps(scale); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_block_bmi2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; + + const __m256 scale_v = _mm256_set1_ps(scale); #pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256i unpacked = internal::mm256_unpack_epi32_bmi2(input); - const __m256 dequantized = internal::mm256_dequantize_epi32(unpacked, scale_v); - _mm256_storeu_ps(output, dequantized); - input += BIT_WIDTH; - output += 8; - } + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256i unpacked = internal::mm256_unpack_epi32_bmi2(input); + const __m256 dequantized = internal::mm256_dequantize_epi32(unpacked, scale_v); + _mm256_storeu_ps(output, dequantized); + input += BIT_WIDTH; + output += 8; + } - if constexpr (remaining > 0) { - const std::vector tail_values = internal::unpack_epi32_fallback(input, remaining); - for (uint32_t i = 0; i < remaining; i++) { - output[i] = internal::dequantize_epi32(tail_values[i], scale); + if constexpr (remaining > 0) { + const std::vector tail_values = internal::unpack_epi32_fallback < BIT_WIDTH, SIGN_VALUES + > + (input, remaining); + for (u32 i = 0; i < remaining; i++) { + output[i] = internal::dequantize_epi32(tail_values[i], scale); + } } - } - return 0; -} + return 0; + } -/** + /** * @brief Decompress a single block to double values using AVX2 and BMI2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 16). @@ -249,43 +254,46 @@ int mm256_decompress_block_bmi2(const void* __restrict__ input_ptr, const float_ * * @note This function requires AVX2 and BMI2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_bmi2(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - const __m256d scale_v = _mm256_set1_pd(scale); + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_block_bmi2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; + const __m256d scale_v = _mm256_set1_pd(scale); #pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256i unpacked = internal::mm256_unpack_epi32_bmi2(input); - const __m256i extend1 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(unpacked)); - const __m256i extend2 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(unpacked, 1)); + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256i unpacked = internal::mm256_unpack_epi32_bmi2(input); + const __m256i extend1 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(unpacked)); + const __m256i extend2 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(unpacked, 1)); - const __m256d dequantized1 = internal::mm256_dequantize_epi64_pd(extend1, scale_v); - const __m256d dequantized2 = internal::mm256_dequantize_epi64_pd(extend2, scale_v); + const __m256d dequantized1 = internal::mm256_dequantize_epi64_pd(extend1, scale_v); + const __m256d dequantized2 = internal::mm256_dequantize_epi64_pd(extend2, scale_v); - _mm256_storeu_pd(output, dequantized1); - _mm256_storeu_pd(output + 4, dequantized2); + _mm256_storeu_pd(output, dequantized1); + _mm256_storeu_pd(output + 4, dequantized2); - input += BIT_WIDTH; - output += 8; - } + input += BIT_WIDTH; + output += 8; + } - if constexpr (remaining > 0) { - const std::vector tail_values = internal::unpack_epi32_fallback(input, remaining); - for (uint32_t i = 0; i < remaining; i++) { - output[i] = internal::dequantize_epi64(tail_values[i], scale); + if constexpr (remaining > 0) { + const std::vector tail_values = internal::unpack_epi32_fallback < BIT_WIDTH, SIGN_VALUES + > + (input, remaining); + for (u32 i = 0; i < remaining; i++) { + output[i] = internal::dequantize_epi64(tail_values[i], scale); + } } - } - return 0; -} + return 0; + } -/** + /** * @brief Decompress multiple 512\-bit blocks using AVX2 and BMI2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -300,26 +308,27 @@ int mm256_decompress_block_bmi2(const void* __restrict__ input_ptr, const double * * @note This function requires AVX2 and BMI2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_bmi2(const void* __restrict__ input_ptr, const float_t scale, void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_decompress_block_bmi2(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_blocks_bmi2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f32 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_decompress_block_bmi2(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } - return 0; -} + return 0; + } -/** + /** * @brief Decompress multiple blocks to double values using AVX2 and BMI2 instructions. * * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). @@ -334,24 +343,25 @@ int mm256_decompress_blocks_bmi2(const void* __restrict__ input_ptr, const float * * @note This function requires AVX2 and BMI2 support. */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_bmi2(const void* __restrict__ input_ptr, const double_t scale, void* __restrict__ output_ptr, - const uint32_t blocks) { - const auto* input = static_cast(input_ptr); - auto* output = static_cast(output_ptr); - - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_decompress_block_bmi2(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_blocks_bmi2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f64 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_decompress_block_bmi2(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } - return 0; -} + return 0; + } } // namespace pernix #endif // PERNIX_BMI2_DECOMPRESSION_H diff --git a/src/internal/pernix/x86/utils.h b/src/internal/pernix/x86/utils.h index 185e0a2..42aad76 100644 --- a/src/internal/pernix/x86/utils.h +++ b/src/internal/pernix/x86/utils.h @@ -4,13 +4,11 @@ #include namespace pernix::x86::internal { - -static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - return tail_bytes; -} - -} // namespace pernix::x86::internal + static constexpr u32 tail_bytes(const u8 bit_width, const u32 remaining_elements) { + const u32 tail_bits = remaining_elements * bit_width; + const u32 tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; + } +} // namespace pernix::x86::internal #endif // PERNIX_X86_UTILS_H diff --git a/src/pernix.cpp b/src/pernix.cpp index 8b806a6..a57a2e5 100644 --- a/src/pernix.cpp +++ b/src/pernix.cpp @@ -2,14 +2,14 @@ #include namespace { -bool is_valid_block_size(uint32_t block_size) { - return block_size == 64 || block_size == 128 || block_size == 256 || block_size == 512 || block_size == 1024; -} + bool is_valid_block_size(u32 block_size) { + return block_size == 64 || block_size == 128 || block_size == 256 || block_size == 512 || block_size == 1024; + } } extern "C" { -pernix_status pernix_compress_block_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - float scale, void* output) { +pernix_status pernix_compress_block_f32(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + float scale, void *output) { if (input == nullptr || output == nullptr) { return PERNIX_STATUS_INVALID_ARGUMENT; } @@ -17,7 +17,8 @@ pernix_status pernix_compress_block_f32(pernix_backend backend, uint8_t bit_widt return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; } - const auto kernel = pernix::internal::select_compress_block_f32(static_cast(backend), bit_width, block_size); + const auto kernel = pernix::internal::select_compress_block_f32(static_cast(backend), bit_width, + block_size); if (!kernel) { return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; @@ -26,8 +27,8 @@ pernix_status pernix_compress_block_f32(pernix_backend backend, uint8_t bit_widt return static_cast(kernel.func(input, scale, output)); } -pernix_status pernix_compress_blocks_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - float scale, void* output, uint32_t blocks) { +pernix_status pernix_compress_blocks_f32(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + float scale, void *output, u32 blocks) { if (input == nullptr || output == nullptr) { return PERNIX_STATUS_INVALID_ARGUMENT; } @@ -36,7 +37,8 @@ pernix_status pernix_compress_blocks_f32(pernix_backend backend, uint8_t bit_wid return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; } - const auto kernel = pernix::internal::select_compress_blocks_f32(static_cast(backend), bit_width, block_size); + const auto kernel = pernix::internal::select_compress_blocks_f32(static_cast(backend), bit_width, + block_size); if (!kernel) { return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; @@ -45,8 +47,8 @@ pernix_status pernix_compress_blocks_f32(pernix_backend backend, uint8_t bit_wid return static_cast(kernel.func(input, scale, output, blocks)); } -pernix_status pernix_decompress_block_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - float scale, void* output, bool sign_values) { +pernix_status pernix_decompress_block_f32(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + float scale, void *output, bool sign_values) { if (input == nullptr || output == nullptr) { return PERNIX_STATUS_INVALID_ARGUMENT; } @@ -55,8 +57,9 @@ pernix_status pernix_decompress_block_f32(pernix_backend backend, uint8_t bit_wi return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; } - const auto kernel = pernix::internal::select_decompress_block_f32(static_cast(backend), bit_width, block_size, - sign_values); + const auto kernel = pernix::internal::select_decompress_block_f32(static_cast(backend), bit_width, + block_size, + sign_values); if (!kernel) { return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; @@ -65,8 +68,8 @@ pernix_status pernix_decompress_block_f32(pernix_backend backend, uint8_t bit_wi return static_cast(kernel.func(input, scale, output)); } -pernix_status pernix_decompress_blocks_f32(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - float scale, void* output, uint32_t blocks, bool sign_values) { +pernix_status pernix_decompress_blocks_f32(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + float scale, void *output, u32 blocks, bool sign_values) { if (input == nullptr || output == nullptr) { return PERNIX_STATUS_INVALID_ARGUMENT; } @@ -75,8 +78,9 @@ pernix_status pernix_decompress_blocks_f32(pernix_backend backend, uint8_t bit_w return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; } - const auto kernel = pernix::internal::select_decompress_blocks_f32(static_cast(backend), bit_width, block_size, - sign_values); + const auto kernel = pernix::internal::select_decompress_blocks_f32(static_cast(backend), bit_width, + block_size, + sign_values); if (!kernel) { return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; @@ -85,8 +89,8 @@ pernix_status pernix_decompress_blocks_f32(pernix_backend backend, uint8_t bit_w return static_cast(kernel.func(input, scale, output, blocks)); } -pernix_status pernix_compress_block_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - double scale, void* output) { +pernix_status pernix_compress_block_f64(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + double scale, void *output) { if (input == nullptr || output == nullptr) { return PERNIX_STATUS_INVALID_ARGUMENT; } @@ -95,7 +99,8 @@ pernix_status pernix_compress_block_f64(pernix_backend backend, uint8_t bit_widt return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; } - const auto kernel = pernix::internal::select_compress_block_f64(static_cast(backend), bit_width, block_size); + const auto kernel = pernix::internal::select_compress_block_f64(static_cast(backend), bit_width, + block_size); if (!kernel) { return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; @@ -104,8 +109,8 @@ pernix_status pernix_compress_block_f64(pernix_backend backend, uint8_t bit_widt return static_cast(kernel.func(input, scale, output)); } -pernix_status pernix_compress_blocks_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - double scale, void* output, uint32_t blocks) { +pernix_status pernix_compress_blocks_f64(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + double scale, void *output, u32 blocks) { if (input == nullptr || output == nullptr) { return PERNIX_STATUS_INVALID_ARGUMENT; } @@ -114,7 +119,8 @@ pernix_status pernix_compress_blocks_f64(pernix_backend backend, uint8_t bit_wid return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; } - const auto kernel = pernix::internal::select_compress_blocks_f64(static_cast(backend), bit_width, block_size); + const auto kernel = pernix::internal::select_compress_blocks_f64(static_cast(backend), bit_width, + block_size); if (!kernel) { return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; @@ -123,8 +129,8 @@ pernix_status pernix_compress_blocks_f64(pernix_backend backend, uint8_t bit_wid return static_cast(kernel.func(input, scale, output, blocks)); } -pernix_status pernix_decompress_block_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - double scale, void* output, bool sign_values) { +pernix_status pernix_decompress_block_f64(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + double scale, void *output, bool sign_values) { if (input == nullptr || output == nullptr) { return PERNIX_STATUS_INVALID_ARGUMENT; } @@ -133,8 +139,9 @@ pernix_status pernix_decompress_block_f64(pernix_backend backend, uint8_t bit_wi return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; } - const auto kernel = pernix::internal::select_decompress_block_f64(static_cast(backend), bit_width, block_size, - sign_values); + const auto kernel = pernix::internal::select_decompress_block_f64(static_cast(backend), bit_width, + block_size, + sign_values); if (!kernel) { return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; @@ -143,8 +150,8 @@ pernix_status pernix_decompress_block_f64(pernix_backend backend, uint8_t bit_wi return static_cast(kernel.func(input, scale, output)); } -pernix_status pernix_decompress_blocks_f64(pernix_backend backend, uint8_t bit_width, uint32_t block_size, const void* input, - double scale, void* output, uint32_t blocks, bool sign_values) { +pernix_status pernix_decompress_blocks_f64(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + double scale, void *output, u32 blocks, bool sign_values) { if (input == nullptr || output == nullptr) { return PERNIX_STATUS_INVALID_ARGUMENT; } @@ -153,8 +160,9 @@ pernix_status pernix_decompress_blocks_f64(pernix_backend backend, uint8_t bit_w return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; } - const auto kernel = pernix::internal::select_decompress_blocks_f64(static_cast(backend), bit_width, block_size, - sign_values); + const auto kernel = pernix::internal::select_decompress_blocks_f64(static_cast(backend), bit_width, + block_size, + sign_values); if (!kernel) { return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; diff --git a/src/x86/avx2/avx2_compression.cpp b/src/x86/avx2/avx2_compression.cpp index 1a7f2c0..548989b 100644 --- a/src/x86/avx2/avx2_compression.cpp +++ b/src/x86/avx2/avx2_compression.cpp @@ -134,53 +134,53 @@ case N: return Kernel("avx2", &mm256_compress_blocks_avx2 select_avx2_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); - default: - return {"avx2", nullptr}; + Kernel select_avx2_compress_block_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"avx2", nullptr}; + } } -} -Kernel select_avx2_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); - default: - return {"avx2", nullptr}; + Kernel select_avx2_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"avx2", nullptr}; + } } -} -Kernel select_avx2_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); - default: - return {"avx2", nullptr}; + Kernel select_avx2_compress_block_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"avx2", nullptr}; + } } -} -Kernel select_avx2_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); - default: - return {"avx2", nullptr}; + Kernel select_avx2_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"avx2", nullptr}; + } } -} #undef PERNIX_CASE_COMPRESS_BLOCK_32 #undef PERNIX_CASE_COMPRESS_BLOCKS_32 diff --git a/src/x86/avx2/avx2_decompression.cpp b/src/x86/avx2/avx2_decompression.cpp index 43558ae..cf8a53c 100644 --- a/src/x86/avx2/avx2_decompression.cpp +++ b/src/x86/avx2/avx2_decompression.cpp @@ -146,49 +146,53 @@ case N: \ } \ break -Kernel select_avx2_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); - default: return {"avx2", nullptr}; + Kernel select_avx2_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"avx2", nullptr}; + } } -} -Kernel select_avx2_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); - default: return {"avx2", nullptr}; + Kernel select_avx2_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"avx2", nullptr}; + } } -} -Kernel select_avx2_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); - default: return {"avx2", nullptr}; + Kernel select_avx2_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"avx2", nullptr}; + } } -} -Kernel select_avx2_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); - default: return {"avx2", nullptr}; + Kernel select_avx2_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"avx2", nullptr}; + } } -} #undef PERNIX_CASE_DECOMPRESS_BLOCK_32 #undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 diff --git a/src/x86/avx512vbmi/avx512vbmi_compression.cpp b/src/x86/avx512vbmi/avx512vbmi_compression.cpp index f1c36a1..8eb25c5 100644 --- a/src/x86/avx512vbmi/avx512vbmi_compression.cpp +++ b/src/x86/avx512vbmi/avx512vbmi_compression.cpp @@ -134,53 +134,53 @@ case N: return Kernel("avx512vbmi", &mm512_compress_blocks_ default: return {"avx512vbmi", nullptr}; \ } -Kernel select_avx512vbmi_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); - default: - return {"avx512vbmi", nullptr}; + Kernel select_avx512vbmi_compress_block_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"avx512vbmi", nullptr}; + } } -} -Kernel select_avx512vbmi_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); - default: - return {"avx512vbmi", nullptr}; + Kernel select_avx512vbmi_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"avx512vbmi", nullptr}; + } } -} -Kernel select_avx512vbmi_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); - default: - return {"avx512vbmi", nullptr}; + Kernel select_avx512vbmi_compress_block_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"avx512vbmi", nullptr}; + } } -} -Kernel select_avx512vbmi_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); - default: - return {"avx512vbmi", nullptr}; + Kernel select_avx512vbmi_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"avx512vbmi", nullptr}; + } } -} #undef PERNIX_CASE_COMPRESS_BLOCK_32 #undef PERNIX_CASE_COMPRESS_BLOCKS_32 diff --git a/src/x86/avx512vbmi/avx512vbmi_decompression.cpp b/src/x86/avx512vbmi/avx512vbmi_decompression.cpp index 43f45fc..ac2cbcd 100644 --- a/src/x86/avx512vbmi/avx512vbmi_decompression.cpp +++ b/src/x86/avx512vbmi/avx512vbmi_decompression.cpp @@ -146,49 +146,53 @@ case N: \ } \ break -Kernel select_avx512vbmi_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); - default: return {"avx512vbmi", nullptr}; + Kernel select_avx512vbmi_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"avx512vbmi", nullptr}; + } } -} -Kernel select_avx512vbmi_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); - default: return {"avx512vbmi", nullptr}; + Kernel select_avx512vbmi_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"avx512vbmi", nullptr}; + } } -} -Kernel select_avx512vbmi_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); - default: return {"avx512vbmi", nullptr}; + Kernel select_avx512vbmi_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"avx512vbmi", nullptr}; + } } -} -Kernel select_avx512vbmi_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); - default: return {"avx512vbmi", nullptr}; + Kernel select_avx512vbmi_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"avx512vbmi", nullptr}; + } } -} #undef PERNIX_CASE_DECOMPRESS_BLOCK_32 #undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 diff --git a/src/x86/bmi2/bmi2_compression.cpp b/src/x86/bmi2/bmi2_compression.cpp index ea01e04..e978ba2 100644 --- a/src/x86/bmi2/bmi2_compression.cpp +++ b/src/x86/bmi2/bmi2_compression.cpp @@ -134,53 +134,53 @@ case N: return Kernel("bmi2", &mm256_compress_blocks_bmi2 select_bmi2_compress_block_f32(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); - default: - return {"bmi2", nullptr}; + Kernel select_bmi2_compress_block_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"bmi2", nullptr}; + } } -} -Kernel select_bmi2_compress_blocks_f32(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); - default: - return {"bmi2", nullptr}; + Kernel select_bmi2_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"bmi2", nullptr}; + } } -} -Kernel select_bmi2_compress_block_f64(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); - default: - return {"bmi2", nullptr}; + Kernel select_bmi2_compress_block_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"bmi2", nullptr}; + } } -} -Kernel select_bmi2_compress_blocks_f64(const uint8_t bit_width, const uint32_t block_size) { - switch (block_size) { - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); - default: - return {"bmi2", nullptr}; + Kernel select_bmi2_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"bmi2", nullptr}; + } } -} #undef PERNIX_CASE_COMPRESS_BLOCK_32 #undef PERNIX_CASE_COMPRESS_BLOCKS_32 diff --git a/src/x86/bmi2/bmi2_decompression.cpp b/src/x86/bmi2/bmi2_decompression.cpp index e829c24..84deda8 100644 --- a/src/x86/bmi2/bmi2_decompression.cpp +++ b/src/x86/bmi2/bmi2_decompression.cpp @@ -146,49 +146,53 @@ case N: \ } \ break -Kernel select_bmi2_decompress_block_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); - default: return {"bmi2", nullptr}; + Kernel select_bmi2_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"bmi2", nullptr}; + } } -} -Kernel select_bmi2_decompress_blocks_f32(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); - default: return {"bmi2", nullptr}; + Kernel select_bmi2_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"bmi2", nullptr}; + } } -} -Kernel select_bmi2_decompress_block_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); - default: return {"bmi2", nullptr}; + Kernel select_bmi2_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"bmi2", nullptr}; + } } -} -Kernel select_bmi2_decompress_blocks_f64(const uint8_t bit_width, const uint32_t block_size, bool sign_values) { - switch (block_size) { - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); - PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); - default: return {"bmi2", nullptr}; + Kernel select_bmi2_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"bmi2", nullptr}; + } } -} #undef PERNIX_CASE_DECOMPRESS_BLOCK_32 #undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 diff --git a/tests/fallback_tests.cpp b/tests/fallback_tests.cpp index 5e2b4ef..abb502b 100644 --- a/tests/fallback_tests.cpp +++ b/tests/fallback_tests.cpp @@ -12,9 +12,9 @@ // --------------------------------------------------------------------------- TYPED_TEST(CompressionTest, FallbackCompressBlock) { - std::vector> compressed(this->testSet.numberOfBlocks); + std::vector > compressed(this->testSet.numberOfBlocks); - for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { compressed[b].resize(TestFixture::BlockSize); const auto status = pernix_compress_block_f32( @@ -24,15 +24,15 @@ TYPED_TEST(CompressionTest, FallbackCompressBlock) { ASSERT_EQ(status, PERNIX_STATUS_OK); } - for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { expectCompressedBlockEqualsReference(*this, compressed[b], b); } } TYPED_TEST(CompressionTest64, FallbackCompressBlock) { - std::vector> compressed(this->testSet.numberOfBlocks); + std::vector > compressed(this->testSet.numberOfBlocks); - for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { compressed[b].resize(TestFixture::BlockSize); const auto status = pernix_compress_block_f64( @@ -42,7 +42,7 @@ TYPED_TEST(CompressionTest64, FallbackCompressBlock) { ASSERT_EQ(status, PERNIX_STATUS_OK); } - for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { expectCompressedBlockEqualsReference(*this, compressed[b], b); } } @@ -52,9 +52,9 @@ TYPED_TEST(CompressionTest64, FallbackCompressBlock) { // --------------------------------------------------------------------------- TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { - std::vector> decompressed(this->testSet.numberOfBlocks); + std::vector > decompressed(this->testSet.numberOfBlocks); - for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { decompressed[b].resize(this->testSet.elementsPerBlock); const auto status = pernix_decompress_block_f32( @@ -64,15 +64,15 @@ TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { ASSERT_EQ(status, PERNIX_STATUS_OK); } - for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { expectDecompressedBlockNearSource(*this, decompressed[b], b); } } TYPED_TEST(DecompressionTest64, FallbackDecompressBlock) { - std::vector> decompressed(this->testSet.numberOfBlocks); + std::vector > decompressed(this->testSet.numberOfBlocks); - for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { decompressed[b].resize(this->testSet.elementsPerBlock); const auto status = pernix_decompress_block_f64( @@ -82,7 +82,7 @@ TYPED_TEST(DecompressionTest64, FallbackDecompressBlock) { ASSERT_EQ(status, PERNIX_STATUS_OK); } - for (uint32_t b = 0; b < this->testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { expectDecompressedBlockNearSource(*this, decompressed[b], b); } } @@ -92,77 +92,77 @@ TYPED_TEST(DecompressionTest64, FallbackDecompressBlock) { // --------------------------------------------------------------------------- TYPED_TEST(CompressionTest, FallbackCompressBlocksRoundtrip) { - const uint32_t nb = this->testSet.numberOfBlocks; - const uint32_t epb = this->testSet.elementsPerBlock; - const uint32_t total = nb * epb; + const u32 nb = this->testSet.numberOfBlocks; + const u32 epb = this->testSet.elementsPerBlock; + const u32 total = nb * epb; - std::vector flat(total); - for (uint32_t b = 0; b < nb; b++) { + std::vector flat(total); + for (u32 b = 0; b < nb; b++) { std::copy_n(this->testSet.getDecompressedData()[b].data(), epb, flat.data() + b * epb); } // Compute a single scale that covers all blocks float max_abs = 0.0f; - for (uint32_t i = 0; i < total; i++) { + for (u32 i = 0; i < total; i++) { max_abs = std::max(max_abs, std::abs(flat[i])); } const float q = static_cast(decltype(this->testSet)::quantization_levels); - const float scale = (max_abs > 0.0f && q > 0.0f) ? (max_abs / q) : std::numeric_limits::epsilon(); + const float scale = (max_abs > 0.0f && q > 0.0f) ? (max_abs / q) : std::numeric_limits::epsilon(); const float scale_inv = 1.0f / scale; - std::vector compressed(nb * TestFixture::BlockSize); + std::vector compressed(nb * TestFixture::BlockSize); auto status = pernix_compress_blocks_f32( PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, flat.data(), scale_inv, compressed.data(), nb); ASSERT_EQ(status, PERNIX_STATUS_OK); - std::vector restored(total); + std::vector restored(total); status = pernix_decompress_blocks_f32( PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, compressed.data(), scale, restored.data(), nb, true); ASSERT_EQ(status, PERNIX_STATUS_OK); - const float tol = (std::abs(scale) * 0.5f) + (std::numeric_limits::epsilon() * 16.0f); - for (uint32_t i = 0; i < total; i++) { + const float tol = (std::abs(scale) * 0.5f) + (std::numeric_limits::epsilon() * 16.0f); + for (u32 i = 0; i < total; i++) { EXPECT_NEAR(restored[i], flat[i], tol); } } TYPED_TEST(CompressionTest64, FallbackCompressBlocksRoundtrip) { - const uint32_t nb = this->testSet.numberOfBlocks; - const uint32_t epb = this->testSet.elementsPerBlock; - const uint32_t total = nb * epb; + const u32 nb = this->testSet.numberOfBlocks; + const u32 epb = this->testSet.elementsPerBlock; + const u32 total = nb * epb; - std::vector flat(total); - for (uint32_t b = 0; b < nb; b++) { + std::vector flat(total); + for (u32 b = 0; b < nb; b++) { std::copy_n(this->testSet.getDecompressedData()[b].data(), epb, flat.data() + b * epb); } // Compute a single scale that covers all blocks double max_abs = 0.0; - for (uint32_t i = 0; i < total; i++) { + for (u32 i = 0; i < total; i++) { max_abs = std::max(max_abs, std::abs(flat[i])); } const double q = static_cast(decltype(this->testSet)::quantization_levels); - const double scale = (max_abs > 0.0 && q > 0.0) ? (max_abs / q) : std::numeric_limits::epsilon(); + const double scale = (max_abs > 0.0 && q > 0.0) ? (max_abs / q) : std::numeric_limits::epsilon(); const double scale_inv = 1.0 / scale; - std::vector compressed(nb * TestFixture::BlockSize); + std::vector compressed(nb * TestFixture::BlockSize); auto status = pernix_compress_blocks_f64( PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, flat.data(), scale_inv, compressed.data(), nb); ASSERT_EQ(status, PERNIX_STATUS_OK); - std::vector restored(total); + std::vector restored(total); status = pernix_decompress_blocks_f64( PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, compressed.data(), scale, restored.data(), nb, true); ASSERT_EQ(status, PERNIX_STATUS_OK); - const double tol = (std::abs(scale) * 0.5) + (std::numeric_limits::epsilon() * 16.0); - for (uint32_t i = 0; i < total; i++) { + const double tol = (std::abs(scale) * 0.5) + (std::numeric_limits::epsilon() * 16.0); + for (u32 i = 0; i < total; i++) { EXPECT_NEAR(restored[i], flat[i], tol); } } @@ -172,11 +172,11 @@ TYPED_TEST(CompressionTest64, FallbackCompressBlocksRoundtrip) { // --------------------------------------------------------------------------- TYPED_TEST(CompressionTest, SingleBlockCompressBlocksMatchesBlock) { - const auto& src = this->testSet.getDecompressedData()[0]; + const auto &src = this->testSet.getDecompressedData()[0]; const float scale_inv = 1.0f / this->testSet.getScales()[0]; - std::vector blockOut(TestFixture::BlockSize); - std::vector blocksOut(TestFixture::BlockSize); + std::vector blockOut(TestFixture::BlockSize); + std::vector blocksOut(TestFixture::BlockSize); auto s1 = pernix_compress_block_f32( PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, @@ -188,18 +188,18 @@ TYPED_TEST(CompressionTest, SingleBlockCompressBlocksMatchesBlock) { src.data(), scale_inv, blocksOut.data(), 1); ASSERT_EQ(s2, PERNIX_STATUS_OK); - for (uint32_t i = 0; i < TestFixture::BlockSize; i++) { + for (u32 i = 0; i < TestFixture::BlockSize; i++) { EXPECT_EQ(blockOut[i], blocksOut[i]) << "byte " << i; } } TYPED_TEST(DecompressionTest, SingleBlockDecompressBlocksMatchesBlock) { - const auto& compressed = this->testSet.getCompressedData()[0]; + const auto &compressed = this->testSet.getCompressedData()[0]; const float scale = this->testSet.getScales()[0]; - const uint32_t epb = this->testSet.elementsPerBlock; + const u32 epb = this->testSet.elementsPerBlock; - std::vector blockOut(epb); - std::vector blocksOut(epb); + std::vector blockOut(epb); + std::vector blocksOut(epb); auto s1 = pernix_decompress_block_f32( PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, @@ -211,7 +211,7 @@ TYPED_TEST(DecompressionTest, SingleBlockDecompressBlocksMatchesBlock) { compressed.data(), scale, blocksOut.data(), 1, true); ASSERT_EQ(s2, PERNIX_STATUS_OK); - for (uint32_t i = 0; i < epb; i++) { + for (u32 i = 0; i < epb; i++) { EXPECT_FLOAT_EQ(blockOut[i], blocksOut[i]) << "element " << i; } } @@ -221,11 +221,11 @@ TYPED_TEST(DecompressionTest, SingleBlockDecompressBlocksMatchesBlock) { // --------------------------------------------------------------------------- TEST(FallbackEdgeTest, SignExtensionIsWellDefinedForNegativeValues) { - constexpr uint32_t BS = 64; - const std::array input{0x08}; + constexpr u32 BS = 64; + const std::array input{0x08}; pernix_status st; - std::array output{}; + std::array output{}; st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, 4, BS, input.data(), 1.0f, output.data(), true); ASSERT_EQ(st, PERNIX_STATUS_OK); @@ -233,12 +233,12 @@ TEST(FallbackEdgeTest, SignExtensionIsWellDefinedForNegativeValues) { } TEST(FallbackEdgeTest, ClearsUnusedPaddingBytes) { - constexpr uint32_t BS = 64; - constexpr uint32_t BW = 24; - constexpr uint32_t EPB = (BS * 8) / BW; + constexpr u32 BS = 64; + constexpr u32 BW = 24; + constexpr u32 EPB = (BS * 8) / BW; - std::array input{}; - std::array output{}; + std::array input{}; + std::array output{}; output.fill(0xA5); auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BW, BS, input.data(), 1.0f, output.data()); @@ -247,17 +247,17 @@ TEST(FallbackEdgeTest, ClearsUnusedPaddingBytes) { } TEST(FallbackEdgeTest, ClampsNonFiniteAndOutOfRangeBeforeNarrowing) { - constexpr uint32_t BS = 64; - constexpr uint32_t BW = 4; - constexpr uint32_t EPB = (BS * 8) / BW; + constexpr u32 BS = 64; + constexpr u32 BW = 4; + constexpr u32 EPB = (BS * 8) / BW; - std::array input{}; - input[0] = std::numeric_limits::infinity(); - input[1] = -std::numeric_limits::infinity(); - input[2] = std::numeric_limits::quiet_NaN(); + std::array input{}; + input[0] = std::numeric_limits::infinity(); + input[1] = -std::numeric_limits::infinity(); + input[2] = std::numeric_limits::quiet_NaN(); - std::array compressed{}; - std::array restored{}; + std::array compressed{}; + std::array restored{}; auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BW, BS, input.data(), 1.0f, compressed.data()); ASSERT_EQ(st, PERNIX_STATUS_OK); @@ -275,9 +275,9 @@ TEST(FallbackEdgeTest, ClampsNonFiniteAndOutOfRangeBeforeNarrowing) { // --------------------------------------------------------------------------- TEST(ErrorCodeTest, UnsupportedBlockSizeReturnsError) { - constexpr uint32_t BS = 32; - float_t src[32] = {}; - uint8_t dst[32] = {}; + constexpr u32 BS = 32; + f32 src[32] = {}; + u8 dst[32] = {}; auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 8, BS, src, 1.0f, dst); EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); @@ -293,9 +293,9 @@ TEST(ErrorCodeTest, UnsupportedBlockSizeReturnsError) { } TEST(ErrorCodeTest, UnsupportedBitWidthReturnsError) { - constexpr uint32_t BS = 64; - float_t src[256] = {}; - uint8_t dst[64] = {}; + constexpr u32 BS = 64; + f32 src[256] = {}; + u8 dst[64] = {}; auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 0, BS, src, 1.0f, dst); EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH); @@ -305,8 +305,8 @@ TEST(ErrorCodeTest, UnsupportedBitWidthReturnsError) { } TEST(ErrorCodeTest, NullPointerReturnsError) { - float_t src[64] = {}; - uint8_t dst[64] = {}; + f32 src[64] = {}; + u8 dst[64] = {}; auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 8, 64, nullptr, 1.0f, dst); EXPECT_EQ(st, PERNIX_STATUS_INVALID_ARGUMENT); diff --git a/tests/include/testset.h b/tests/include/testset.h index 81acb32..b73fc71 100644 --- a/tests/include/testset.h +++ b/tests/include/testset.h @@ -21,11 +21,11 @@ static_assert(PERNIX_TEST_BLOCK_SIZE % 32 == 0, "PERNIX_TEST_BLOCK_SIZE must be dividable by 32 bytes"); -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && std::is_floating_point_v class TestSet { - using ValueType = uint8_t; - using SeedType = std::mt19937::result_type; + using ValueType = u8; + using SeedType = std::mt19937::result_type; alignas(64) std::vector > compressedData; alignas(64) std::vector > sourceData; @@ -36,21 +36,24 @@ class TestSet { std::uniform_real_distribution dis{}; public: - static constexpr uint32_t elementsPerBlock = (BLOCK_SIZE * 8) / BIT_WIDTH; - static constexpr SeedType defaultSeed = 0x5eed1234u; + static constexpr u32 elementsPerBlock = (BLOCK_SIZE * 8) / BIT_WIDTH; + static constexpr SeedType defaultSeed = 0x5eed1234u; static constexpr T quantization_levels = - SIGN_VALUES ? static_cast(BIT_WIDTH == 1 ? 1u : ((1u << (BIT_WIDTH - 1u)) - 1u)) : static_cast((1u << BIT_WIDTH) - 1u); + SIGN_VALUES + ? static_cast(BIT_WIDTH == 1 ? 1u : ((1u << (BIT_WIDTH - 1u)) - 1u)) + : static_cast((1u << BIT_WIDTH) - 1u); - uint32_t numberOfBlocks; + u32 numberOfBlocks; - [[nodiscard]] constexpr uint32_t totalElements() const { return numberOfBlocks * elementsPerBlock; } + [[nodiscard]] constexpr u32 totalElements() const { return numberOfBlocks * elementsPerBlock; } - [[nodiscard]] T blockTolerance(const uint32_t block) const { - return (std::abs(scalesData[block]) * static_cast(0.5)) + (std::numeric_limits::epsilon() * static_cast(16)); + [[nodiscard]] T blockTolerance(const u32 block) const { + return (std::abs(scalesData[block]) * static_cast(0.5)) + ( + std::numeric_limits::epsilon() * static_cast(16)); } - explicit TestSet(const uint32_t number_of_blocks, const SeedType initial_seed = testSeed()) + explicit TestSet(const u32 number_of_blocks, const SeedType initial_seed = testSeed()) : seed(initial_seed), gen(seed), numberOfBlocks(number_of_blocks) { compressedData.resize(numberOfBlocks); sourceData.resize(number_of_blocks); @@ -59,32 +62,32 @@ class TestSet { generateData(); } - [[nodiscard]] const std::vector& getScales() const { return scalesData; } + [[nodiscard]] const std::vector &getScales() const { return scalesData; } - [[nodiscard]] const std::vector >& getCompressedData() const { return compressedData; } + [[nodiscard]] const std::vector > &getCompressedData() const { return compressedData; } - [[nodiscard]] const std::vector >& getDecompressedData() const { return sourceData; } + [[nodiscard]] const std::vector > &getDecompressedData() const { return sourceData; } [[nodiscard]] SeedType getSeed() const { return seed; } [[nodiscard]] static SeedType testSeed() { - const char* env_seed = std::getenv("PERNIX_TEST_SEED"); + const char *env_seed = std::getenv("PERNIX_TEST_SEED"); if (env_seed == nullptr || *env_seed == '\0') { return defaultSeed; } - char* end = nullptr; + char *end = nullptr; const unsigned long value = std::strtoul(env_seed, &end, 0); return (end != env_seed && *end == '\0') ? static_cast(value) : defaultSeed; } private: void generateData() { - for (uint32_t i = 0; i < numberOfBlocks; i++) { + for (u32 i = 0; i < numberOfBlocks; i++) { compressedData[i].resize(BLOCK_SIZE); sourceData[i].resize(elementsPerBlock); - for (uint32_t j = 0; j < elementsPerBlock; j++) { + for (u32 j = 0; j < elementsPerBlock; j++) { sourceData[i][j] = dis(gen); } @@ -97,133 +100,138 @@ class TestSet { if constexpr (std::is_same_v) { pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BIT_WIDTH, BLOCK_SIZE, - sourceData[i].data(), 1.0f / scalesData[i], compressedData[i].data()); + sourceData[i].data(), 1.0f / scalesData[i], compressedData[i].data()); } else { pernix_compress_block_f64(PERNIX_BACKEND_FALLBACK, BIT_WIDTH, BLOCK_SIZE, - sourceData[i].data(), 1.0 / scalesData[i], compressedData[i].data()); + sourceData[i].data(), 1.0 / scalesData[i], compressedData[i].data()); } } } }; -template +template struct BitWidthBlockSize { - static constexpr uint8_t bit_width = BIT_WIDTH; - static constexpr uint32_t block_size = BLOCK_SIZE; + static constexpr u8 bit_width = BIT_WIDTH; + static constexpr u32 block_size = BLOCK_SIZE; }; using testing::Types; using BitWidthBlockSizeTypes = Types, BitWidthBlockSize<2, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<3, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<4, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<5, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<6, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<7, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<8, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<9, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<10, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<11, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<12, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<13, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<14, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<15, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<16, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<17, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<18, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<19, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<20, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<21, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<22, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<23, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<24, PERNIX_TEST_BLOCK_SIZE> >; + BitWidthBlockSize<3, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<4, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<5, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<6, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<7, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<8, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<9, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<10, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<11, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<12, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<13, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<14, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<15, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<16, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<17, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<18, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<19, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<20, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<21, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<22, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<23, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<24, PERNIX_TEST_BLOCK_SIZE> >; struct BitWidthBlockSizeName { - template + template static std::string GetName(int) { std::ostringstream name; - name << "BitWidth" << static_cast(TestConfigT::bit_width) << "BlockSize" << TestConfigT::block_size; + name << "BitWidth" << static_cast(TestConfigT::bit_width) << "BlockSize" << TestConfigT::block_size; return name.str(); } }; -template +template class CompressionTest : public ::testing::Test { public: - static constexpr uint8_t BitWidth = TestConfigT::bit_width; - static constexpr uint32_t BlockSize = TestConfigT::block_size; + static constexpr u8 BitWidth = TestConfigT::bit_width; + static constexpr u32 BlockSize = TestConfigT::block_size; - TestSet testSet; + TestSet testSet; CompressionTest() : testSet(1u << 10) { } }; -template +template class DecompressionTest : public ::testing::Test { public: - static constexpr uint8_t BitWidth = TestConfigT::bit_width; - static constexpr uint32_t BlockSize = TestConfigT::block_size; + static constexpr u8 BitWidth = TestConfigT::bit_width; + static constexpr u32 BlockSize = TestConfigT::block_size; - TestSet testSet; + TestSet testSet; DecompressionTest() : testSet(1u << 10) { } }; -template +template class CompressionTest64 : public ::testing::Test { public: - static constexpr uint8_t BitWidth = TestConfigT::bit_width; - static constexpr uint32_t BlockSize = TestConfigT::block_size; + static constexpr u8 BitWidth = TestConfigT::bit_width; + static constexpr u32 BlockSize = TestConfigT::block_size; - TestSet testSet; + TestSet testSet; CompressionTest64() : testSet(1u << 10) { } }; -template +template class DecompressionTest64 : public ::testing::Test { public: - static constexpr uint8_t BitWidth = TestConfigT::bit_width; - static constexpr uint32_t BlockSize = TestConfigT::block_size; + static constexpr u8 BitWidth = TestConfigT::bit_width; + static constexpr u32 BlockSize = TestConfigT::block_size; - TestSet testSet; + TestSet testSet; DecompressionTest64() : testSet(1u << 10) { } }; -template -[[nodiscard]] std::string testContext(const FixtureT& fixture, const uint32_t block) { +template +[[nodiscard]] std::string testContext(const FixtureT &fixture, const u32 block) { std::ostringstream message; - message << "bit_width=" << static_cast(FixtureT::BitWidth) << ", block_size=" << FixtureT::BlockSize - << ", block=" << block << ", scale=" << fixture.testSet.getScales()[block] - << ", tolerance=" << fixture.testSet.blockTolerance(block) << ", seed=" << fixture.testSet.getSeed(); + message << "bit_width=" << static_cast(FixtureT::BitWidth) << ", block_size=" << FixtureT::BlockSize + << ", block=" << block << ", scale=" << fixture.testSet.getScales()[block] + << ", tolerance=" << fixture.testSet.blockTolerance(block) << ", seed=" << fixture.testSet.getSeed(); return message.str(); } -template -void expectCompressedBlockEqualsReference(const FixtureT& fixture, const std::vector& actual, const uint32_t block) { +template +void expectCompressedBlockEqualsReference(const FixtureT &fixture, const std::vector &actual, + const u32 block) { SCOPED_TRACE(testContext(fixture, block)); - const auto& expected = fixture.testSet.getCompressedData()[block]; + const auto &expected = fixture.testSet.getCompressedData()[block]; ASSERT_EQ(actual.size(), expected.size()) << "Compressed block byte count differs from reference"; - for (uint32_t byte = 0; byte < actual.size(); byte++) { + for (u32 byte = 0; byte < actual.size(); byte++) { ASSERT_EQ(actual[byte], expected[byte]) - << "Compressed byte mismatch at byte=" << byte << ", actual=" << static_cast(actual[byte]) - << ", expected=" << static_cast(expected[byte]); + << "Compressed byte mismatch at byte=" << byte << ", actual=" << static_cast(actual[byte]) + << ", expected=" << static_cast(expected[byte]); } } -template -void expectDecompressedBlockNearSource(const FixtureT& fixture, const std::vector& actual, const uint32_t block) { +template +void expectDecompressedBlockNearSource(const FixtureT &fixture, const std::vector &actual, const u32 block) { SCOPED_TRACE(testContext(fixture, block)); - const auto& expected = fixture.testSet.getDecompressedData()[block]; + const auto &expected = fixture.testSet.getDecompressedData()[block]; ASSERT_EQ(actual.size(), expected.size()) << "Decompressed block element count differs from source"; - for (uint32_t element = 0; element < actual.size(); element++) { + for (u32 element = 0; element < actual.size(); element++) { ASSERT_NEAR(actual[element], expected[element], fixture.testSet.blockTolerance(block)) - << "Decompressed element mismatch at element=" << element << ", actual=" << actual[element] - << ", expected=" << expected[element] << ", absolute_error=" << std::abs(actual[element] - expected[element]); + << "Decompressed element mismatch at element=" << element << ", actual=" << actual[element] + << ", expected=" << expected[element] << ", absolute_error=" << std::abs( + actual[element] - expected[element]); } } TYPED_TEST_SUITE(CompressionTest, BitWidthBlockSizeTypes, BitWidthBlockSizeName); + TYPED_TEST_SUITE(DecompressionTest, BitWidthBlockSizeTypes, BitWidthBlockSizeName); + TYPED_TEST_SUITE(CompressionTest64, BitWidthBlockSizeTypes, BitWidthBlockSizeName); + TYPED_TEST_SUITE(DecompressionTest64, BitWidthBlockSizeTypes, BitWidthBlockSizeName); #endif // PERNIX_TESTSET_H diff --git a/tests/simd_tests.cpp b/tests/simd_tests.cpp index 6f9bc50..019aea3 100644 --- a/tests/simd_tests.cpp +++ b/tests/simd_tests.cpp @@ -7,50 +7,50 @@ // SIMD compress: compress via backend, decompress via fallback, compare source // --------------------------------------------------------------------------- -template -void testBackendCompressBlock(FixtureT& fixture, pernix_backend backend) { +template +void testBackendCompressBlock(FixtureT &fixture, pernix_backend backend) { using T = std::remove_cvref_t; { - std::vector probe(FixtureT::BlockSize); + std::vector probe(FixtureT::BlockSize); pernix_status st; if constexpr (std::is_same_v) { st = pernix_compress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, - fixture.testSet.getDecompressedData()[0].data(), - 1.0f / fixture.testSet.getScales()[0], probe.data()); + fixture.testSet.getDecompressedData()[0].data(), + 1.0f / fixture.testSet.getScales()[0], probe.data()); } else { st = pernix_compress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, - fixture.testSet.getDecompressedData()[0].data(), - 1.0 / fixture.testSet.getScales()[0], probe.data()); + fixture.testSet.getDecompressedData()[0].data(), + 1.0 / fixture.testSet.getScales()[0], probe.data()); } if (st != PERNIX_STATUS_OK) { - SUCCEED(); + GTEST_SKIP(); return; } } - for (uint32_t b = 0; b < fixture.testSet.numberOfBlocks; b++) { - std::vector compressed(FixtureT::BlockSize); + for (u32 b = 0; b < fixture.testSet.numberOfBlocks; b++) { + std::vector compressed(FixtureT::BlockSize); pernix_status st; if constexpr (std::is_same_v) { st = pernix_compress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, - fixture.testSet.getDecompressedData()[b].data(), - 1.0f / fixture.testSet.getScales()[b], compressed.data()); + fixture.testSet.getDecompressedData()[b].data(), + 1.0f / fixture.testSet.getScales()[b], compressed.data()); } else { st = pernix_compress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, - fixture.testSet.getDecompressedData()[b].data(), - 1.0 / fixture.testSet.getScales()[b], compressed.data()); + fixture.testSet.getDecompressedData()[b].data(), + 1.0 / fixture.testSet.getScales()[b], compressed.data()); } ASSERT_EQ(st, PERNIX_STATUS_OK); std::vector restored(fixture.testSet.elementsPerBlock); if constexpr (std::is_same_v) { st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, FixtureT::BitWidth, FixtureT::BlockSize, - compressed.data(), fixture.testSet.getScales()[b], restored.data(), true); + compressed.data(), fixture.testSet.getScales()[b], restored.data(), true); } else { st = pernix_decompress_block_f64(PERNIX_BACKEND_FALLBACK, FixtureT::BitWidth, FixtureT::BlockSize, - compressed.data(), fixture.testSet.getScales()[b], restored.data(), true); + compressed.data(), fixture.testSet.getScales()[b], restored.data(), true); } ASSERT_EQ(st, PERNIX_STATUS_OK); @@ -62,8 +62,8 @@ void testBackendCompressBlock(FixtureT& fixture, pernix_backend backend) { // SIMD decompress: decompress fallback-compressed data via backend, compare source // --------------------------------------------------------------------------- -template -void testBackendDecompressBlock(FixtureT& fixture, pernix_backend backend) { +template +void testBackendDecompressBlock(FixtureT &fixture, pernix_backend backend) { using T = std::remove_cvref_t; { @@ -71,31 +71,31 @@ void testBackendDecompressBlock(FixtureT& fixture, pernix_backend backend) { pernix_status st; if constexpr (std::is_same_v) { st = pernix_decompress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, - fixture.testSet.getCompressedData()[0].data(), - fixture.testSet.getScales()[0], probe.data(), true); + fixture.testSet.getCompressedData()[0].data(), + fixture.testSet.getScales()[0], probe.data(), true); } else { st = pernix_decompress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, - fixture.testSet.getCompressedData()[0].data(), - fixture.testSet.getScales()[0], probe.data(), true); + fixture.testSet.getCompressedData()[0].data(), + fixture.testSet.getScales()[0], probe.data(), true); } if (st != PERNIX_STATUS_OK) { - SUCCEED(); + GTEST_SKIP(); return; } } - for (uint32_t b = 0; b < fixture.testSet.numberOfBlocks; b++) { + for (u32 b = 0; b < fixture.testSet.numberOfBlocks; b++) { std::vector decompressed(fixture.testSet.elementsPerBlock); pernix_status st; if constexpr (std::is_same_v) { st = pernix_decompress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, - fixture.testSet.getCompressedData()[b].data(), - fixture.testSet.getScales()[b], decompressed.data(), true); + fixture.testSet.getCompressedData()[b].data(), + fixture.testSet.getScales()[b], decompressed.data(), true); } else { st = pernix_decompress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, - fixture.testSet.getCompressedData()[b].data(), - fixture.testSet.getScales()[b], decompressed.data(), true); + fixture.testSet.getCompressedData()[b].data(), + fixture.testSet.getScales()[b], decompressed.data(), true); } ASSERT_EQ(st, PERNIX_STATUS_OK);