diff --git a/.clangd b/.clangd new file mode 100644 index 0000000..e69de29 diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index ae538ec..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "external/simde"] - path = external/simde - url = https://github.com/simd-everywhere/simde diff --git a/CMakeLists.txt b/CMakeLists.txt index edf6c78..c6af942 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.30) +cmake_minimum_required(VERSION 3.24) project(pernix VERSION 0.1.0 LANGUAGES CXX C) option(PERNIX_ENABLE_TESTS "Enable tests for pernix" on) @@ -8,9 +8,11 @@ option(PERNIX_ENABLE_INSTALL "Enable installation of pernix" on) option(PERNIX_ENABLE_DOXYGEN "Build documentation with Doxygen." off) option(PERNIX_INSTALL_DOCS "Install documentation with pernix" off) -option(PERNIX_DISABLE_BMI2 "Disable BMI2 optimizations" off) -option(PERNIX_DISABLE_AVX2 "Disable AVX2 optimizations" off) -option(PERNIX_DISABLE_AVX512 "Disable AVX512 optimizations" off) +option(PERNIX_ENABLE_X86_BMI2 "Build x86 BMI2 backend" ON) +option(PERNIX_ENABLE_X86_AVX2 "Build x86 AVX2 backend" ON) +option(PERNIX_ENABLE_X86_AVX512VBMI "Build x86 AVX512-VBMI backend" ON) +option(PERNIX_ENABLE_ARM64_NEON "Build arm64 NEON backend" ON) +option(PERNIX_ENABLE_ARM64_SVE2 "Build arm64 SVE2 backend" ON) option(PERNIX_USE_SIMDE "Use SIMDe library for portable SIMD support" off) @@ -18,9 +20,37 @@ option(PERNIX_ENABLE_FORTRAN_BINDINGS "Build Fortran bindings for pernix" off) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") +#[===[ +set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL OFF) if (PERNIX_USE_SIMDE) - add_subdirectory(external/simde EXCLUDE_FROM_ALL) + if (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "PACKAGE") + find_package(simde CONFIG QUIET) + endif () + + if (NOT TARGET simde::simde AND (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "FETCH")) + include(FetchContent) 
+ set(SIMDE_TEST_CMAKE_PACKAGING OFF CACHE BOOL "Test SIMDe CMake packaging" FORCE) + FetchContent_Declare( + simde + GIT_REPOSITORY https://github.com/simd-everywhere/simde.git + GIT_TAG f3e8262173b7089db9a9d57a9ecef8dd07ad9c97 + GIT_PROGRESS TRUE + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(simde) + set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL ON) + endif () + + if (NOT TARGET simde::simde AND DEFINED simde_SOURCE_DIR AND EXISTS "${simde_SOURCE_DIR}/simde") + add_library(simde::simde INTERFACE IMPORTED GLOBAL) + target_include_directories(simde::simde INTERFACE "${simde_SOURCE_DIR}") + endif () + + if (NOT TARGET simde::simde) + message(FATAL_ERROR "PERNIX_USE_SIMDE is enabled, but simde::simde was not found. Set PERNIX_SIMDE_PROVIDER=FETCH or install SIMDe's CMake package.") + endif () endif () +]===] include(CTest) include(GitVersion) @@ -40,51 +70,73 @@ else () endif () message(STATUS "Pernix version: ${VERSION}, normalized to ${NORMALIZED_VERSION}") -set(BENCHMARK_CXX_STANDARD 20) +if (NOT CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|Intel") + message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") +endif () -set(CMAKE_CXX_STANDARD ${BENCHMARK_CXX_STANDARD}) -set(CMAKE_CXX_STANDARD_REQUIRED YES) -set(CMAKE_CXX_EXTENSIONS OFF) +include(CheckCXXCompilerFlag) +set(PERNIX_PRIVATE_COMPILE_OPTIONS) +foreach (PERNIX_CXX_FLAG + -Wall + -Wextra + -Wshadow + -Wfloat-equal + -Wold-style-cast + -Wconversion + -fstrict-aliasing + -Wno-ignored-attributes +) + string(MAKE_C_IDENTIFIER "PERNIX_HAS_CXX_FLAG_${PERNIX_CXX_FLAG}" PERNIX_CXX_FLAG_VARIABLE) + check_cxx_compiler_flag("${PERNIX_CXX_FLAG}" "${PERNIX_CXX_FLAG_VARIABLE}") + if (${PERNIX_CXX_FLAG_VARIABLE}) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "${PERNIX_CXX_FLAG}") + else () + message(STATUS "Compiler flag not supported: ${PERNIX_CXX_FLAG}") + endif () +endforeach () -include(AddCXXCompilerFlag) -if (MSVC) - message(FATAL_ERROR "MSVC compiler is not supported") -else () - add_cxx_compiler_flag(-Wall) - 
add_cxx_compiler_flag(-Wextra) - add_cxx_compiler_flag(-Wshadow) - add_cxx_compiler_flag(-Wfloat-equal) - add_cxx_compiler_flag(-Wold-style-cast) - add_cxx_compiler_flag(-Wconversion) - add_cxx_compiler_flag(-fstrict-aliasing) - add_cxx_compiler_flag(-Wno-ignored-attributes) - - if (PERNIX_ENABLE_LTO) - add_cxx_compiler_flag(-flto=auto) - add_cxx_compiler_flag(-Wno-lto-type-mismatch) - if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - find_program(GCC_AR gcc-ar) - if (GCC_AR) - set(CMAKE_AR ${GCC_AR}) - endif () - find_program(GCC_RANLIB gcc-ranlib) - if (GCC_RANLIB) - set(CMAKE_RANLIB ${GCC_RANLIB}) - endif () - elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") - find_program(LLVM_AR llvm-ar) - if (LLVM_AR) - set(CMAKE_AR ${LLVM_AR}) - endif () - find_program(LLVM_RANLIB llvm-ranlib) - if (LLVM_RANLIB) - set(CMAKE_RANLIB ${LLVM_RANLIB}) - endif () +if (PERNIX_ENABLE_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT PERNIX_IPO_SUPPORTED OUTPUT PERNIX_IPO_ERROR) + if (NOT PERNIX_IPO_SUPPORTED) + message(FATAL_ERROR "PERNIX_ENABLE_LTO is enabled, but IPO/LTO is not supported: ${PERNIX_IPO_ERROR}") + endif () + + check_cxx_compiler_flag("-Wno-lto-type-mismatch" PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + if (PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "-Wno-lto-type-mismatch") + endif () + + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + find_program(GCC_AR gcc-ar) + if (GCC_AR) + set(CMAKE_AR ${GCC_AR}) + endif () + find_program(GCC_RANLIB gcc-ranlib) + if (GCC_RANLIB) + set(CMAKE_RANLIB ${GCC_RANLIB}) + endif () + elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + find_program(LLVM_AR llvm-ar) + if (LLVM_AR) + set(CMAKE_AR ${LLVM_AR}) + endif () + find_program(LLVM_RANLIB llvm-ranlib) + if (LLVM_RANLIB) + set(CMAKE_RANLIB ${LLVM_RANLIB}) endif () endif () endif () -include_directories(${PROJECT_SOURCE_DIR}/include) +set(PERNIX_TARGET_IS_X86 OFF) +if (CMAKE_SYSTEM_PROCESSOR MATCHES 
"^(x86_64|AMD64|i[3-6]86|i686)$") + set(PERNIX_TARGET_IS_X86 ON) +endif () + +set(PERNIX_TARGET_IS_ARM64 OFF) +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(PERNIX_TARGET_IS_ARM64 ON) +endif () add_subdirectory(src) @@ -97,4 +149,4 @@ endif () if (PERNIX_ENABLE_TESTS) enable_testing() add_subdirectory(tests) -endif () \ No newline at end of file +endif () diff --git a/cmake/pernixConfig.cmake.in b/cmake/pernixConfig.cmake.in new file mode 100644 index 0000000..fe3eb88 --- /dev/null +++ b/cmake/pernixConfig.cmake.in @@ -0,0 +1,18 @@ +@PACKAGE_INIT@ + +if (@PERNIX_USE_SIMDE@) + find_package(simde CONFIG QUIET) + if (NOT TARGET simde::simde AND @PERNIX_BUNDLE_SIMDE_FOR_INSTALL@) + add_library(simde::simde INTERFACE IMPORTED) + target_include_directories(simde::simde INTERFACE "${PACKAGE_PREFIX_DIR}/include") + endif () + if (NOT TARGET simde::simde) + set(pernix_FOUND FALSE) + set(pernix_NOT_FOUND_MESSAGE "pernix was built with SIMDe support, but simde::simde was not found. 
Install SIMDe's CMake package or use a Pernix package built with bundled SIMDe headers.") + return() + endif () +endif () + +include("${CMAKE_CURRENT_LIST_DIR}/pernixTargets.cmake") + +check_required_components(pernix) diff --git a/external/simde b/external/simde deleted file mode 160000 index 1747b24..0000000 --- a/external/simde +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1747b2482589fe894d49989159421da08c2a8bcd diff --git a/include/pernix/compat.h b/include/pernix/compat.h new file mode 100644 index 0000000..498c83a --- /dev/null +++ b/include/pernix/compat.h @@ -0,0 +1,45 @@ +#ifndef PERNIX_COMPAT_H +#define PERNIX_COMPAT_H + +#include +#include +#include + +#ifndef __always_inline +#if defined(__GNUC__) || defined(__clang__) +#define __always_inline inline __attribute__((always_inline)) +#elif defined(_MSC_VER) +#define __always_inline __forceinline +#else +#define __always_inline inline +#endif +#endif + +#if defined(_WIN32) && defined(PERNIX_SHARED) +#if defined(PERNIX_BUILD_LIB) +#define PERNIX_API __declspec(dllexport) +#else +#define PERNIX_API __declspec(dllimport) +#endif +#else +#define PERNIX_API +#endif + +// Convenient type declarations +typedef uint8_t u8; +typedef uint16_t u16; +typedef uint32_t u32; +typedef uint64_t u64; +typedef uintptr_t uptr; + +typedef int8_t i8; +typedef int16_t i16; +typedef int32_t i32; +typedef int64_t i64; + +typedef float_t f32; +typedef double_t f64; + +typedef size_t usize; + +#endif //PERNIX_COMPAT_H diff --git a/include/pernix/detection.h b/include/pernix/detection.h deleted file mode 100644 index edecb6c..0000000 --- a/include/pernix/detection.h +++ /dev/null @@ -1,70 +0,0 @@ -#ifndef PERNIX_HELPER_H -#define PERNIX_HELPER_H - -#include - -// Internal capability tiers used to derive the public PERNIX_* feature macros. 
-#define PERNIX_MACHINE_ID_GENERIC 0 -#define PERNIX_MACHINE_ID_V2 1 -#define PERNIX_MACHINE_ID_V3 2 -#define PERNIX_MACHINE_ID_V4 3 -#define PERNIX_MACHINE_ID_V4_VBMI 4 - -// Map the compiler's enabled ISA set to the highest supported Pernix target level. -#if (__SSE3__ && __SSE4_1__ && __SSE4_2__) -#if (__AVX__ && __AVX2__ && __FMA__ && __BMI__ && __BMI2__) -#if (__AVX512BW__ && __AVX512CD__ && __AVX512DQ__ && __AVX512F__ && __AVX512VL__) -#if (__AVX512VBMI__) -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_V4_VBMI -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_V4 -#endif - -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_V3 -#endif - -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_V2 -#endif - -#else -#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC -#endif - -// Feature-selection macros consumed by the public headers. -#if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V2) -#define PERNIX_SSE_ENABLED -#endif -#if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V3) -#define PERNIX_AVX2_ENABLED -#define PERNIX_BMI2_ENABLED -#endif -#if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V4) -#define PERNIX_AVX512_ENABLED -#endif -#if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V4_VBMI) -#define PERNIX_AVX512_VBMI_ENABLED -#endif - -#ifdef PERNIX_USE_SIMDE -#define PERNIX_SSE_ENABLED -#define PERNIX_AVX2_ENABLED -#define PERNIX_BMI2_ENABLED -#define PERNIX_AVX512_ENABLED -#define PERNIX_AVX512_VBMI_ENABLED -#endif - -// Allow build systems or tests to force lower-tier implementations. 
-#ifdef PERNIX_DISABLE_AVX512 -#undef PERNIX_AVX512_ENABLED -#undef PERNIX_AVX512_VBMI_ENABLED -#endif -#ifdef PERNIX_DISABLE_AVX2 -#undef PERNIX_AVX2_ENABLED -#endif -#ifdef PERNIX_DISABLE_BMI2 -#undef PERNIX_BMI2_ENABLED -#endif - -#endif // PERNIX_HELPER_H diff --git a/include/pernix/fallback/compression.h b/include/pernix/fallback/compression.h deleted file mode 100644 index 5b2780b..0000000 --- a/include/pernix/fallback/compression.h +++ /dev/null @@ -1,306 +0,0 @@ -#ifndef PERNIX_FALLBACK_COMPRESSION_H -#define PERNIX_FALLBACK_COMPRESSION_H - -#include - -#include -#include -#include -#include -#include -#include - -namespace pernix { -namespace internal { -/** - * @brief Quantize a single float value to int32_t using the provided scale. - * - * @param input input float value to be quantized. - * @param scale scaling factor used during quantization. - * @return int32_t quantized integer value. - */ -__always_inline int32_t quantize_ps_epi32(const float input, const float scale) { - return static_cast(std::lroundf(input * scale)); -} - -/** - * @brief Quantize a single double value to int64_t using the provided scale. - * - * @param input input double value to be quantized. - * @param scale scaling factor used during quantization. - * @return int64_t quantized integer value. - */ -__always_inline int64_t quantize_pd_epi64(const double_t input, const double_t scale) { - return std::llround(input * scale); -} - -/** - * @brief Quantize and clamp without narrowing through an out-of-range integer type. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_floating_point_v) -__always_inline int32_t quantize_clamped(const T input, const T scale) { - constexpr int64_t min_value = BIT_WIDTH == 1 ? 0 : -(int64_t{1} << (BIT_WIDTH - 1)); - constexpr int64_t max_value = BIT_WIDTH == 1 ? 
1 : ((int64_t{1} << (BIT_WIDTH - 1)) - 1); - - const long double scaled = static_cast(input) * static_cast(scale); - if (std::isnan(scaled)) { - return 0; - } - if (scaled <= static_cast(min_value)) { - return static_cast(min_value); - } - if (scaled >= static_cast(max_value)) { - return static_cast(max_value); - } - - return static_cast(std::llround(scaled)); -} - -/** - * @brief Clamp a signed quantized value to the representable range of BIT_WIDTH bits. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__always_inline int32_t clamp_signed_quantized(const int64_t value) { - if constexpr (BIT_WIDTH == 1) { - // 1-bit fallback is treated as binary quantization (0/1). - return static_cast(std::clamp(value, 0, 1)); - } - - constexpr int32_t min_value = -(1 << (BIT_WIDTH - 1)); - constexpr int32_t max_value = (1 << (BIT_WIDTH - 1)) - 1; - return static_cast(std::clamp(value, min_value, max_value)); -} - -/** - * @brief Append packed scalar values into an output buffer using the selected - * storage width. - * - * @tparam T unsigned integer type used as the packing word. - * @tparam BIT_WIDTH bit width per value in the packed representation. - * @param input vector of quantized values to pack. - * @param bit_offset starting bit offset in the destination buffer. - * @param destination pointer to the output buffer. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) -void pack_epi32_fallback_inner(const std::vector& input, const uint8_t bit_offset, uint8_t* __restrict__ destination) { - constexpr uint32_t bits_in_type = sizeof(T) * 8; - constexpr uint32_t bitmask = BIT_WIDTH == bits_in_type ? std::numeric_limits::max() : (1U << BIT_WIDTH) - 1U; - - std::size_t idx = 0; - std::size_t bits_in_buffer = bit_offset; - uint64_t buffer = bit_offset ? 
static_cast(destination[0] & ((1U << bit_offset) - 1U)) : 0; - -#pragma GCC unroll 64 - for (uint32_t raw_value : input) { - const uint32_t next_value = raw_value & bitmask; - - buffer |= static_cast(next_value) << bits_in_buffer; - bits_in_buffer += BIT_WIDTH; - - while (bits_in_buffer >= 8) { - destination[idx++] = static_cast(buffer & 0xFFU); - buffer >>= 8; - bits_in_buffer -= 8; - } - } - - if (bits_in_buffer > 0) { - destination[idx] = static_cast(buffer & 0xFFU); - } -} - -/** - * @brief Pack a vector of uint32_t values into a compact byte representation using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). - * - * @param input vector of uint32_t values to be packed. - * @param destination pointer to the output buffer where packed bytes will be stored. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -void pack_epi32_fallback(const std::vector& input, uint8_t* __restrict__ destination) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::pack_epi32_fallback_inner(input, 0, destination); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::pack_epi32_fallback_inner(input, 0, destination); - } else { - return internal::pack_epi32_fallback_inner(input, 0, destination); - } -} -} // namespace internal - -/** - * @brief Compress a single 512-bit block using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_block_fallback(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - std::memset(output, 0, BLOCK_SIZE); - - std::vector block_values(elements_per_block); -#pragma GCC unroll 64 - for (uint32_t i = 0; i < elements_per_block; i++) { - const int32_t quantized = internal::quantize_clamped(input[i], scale); - block_values[i] = static_cast(quantized); - } - - internal::pack_epi32_fallback(block_values, output); - return 0; -} - -/** - * @brief Compress a single block of double values using the fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_block_fallback(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - std::memset(output, 0, BLOCK_SIZE); - - std::vector block_values(elements_per_block); -#pragma GCC unroll 32 - for (uint32_t i = 0; i < elements_per_block; i++) { - const int32_t quantized = internal::quantize_clamped(input[i], scale); - block_values[i] = static_cast(quantized); - } - - internal::pack_epi32_fallback(block_values, output); - return 0; -} - -/** - * @brief Compress multiple 512-bit blocks using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). 
- * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_blocks_fallback(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - const float_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - compress_block_fallback(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } - - return 0; -} - -/** - * @brief Compress multiple blocks of double values using the fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of blocks to compress. - * @return int status code (0 for success). 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int compress_blocks_fallback(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - const double_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - compress_block_fallback(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Compress a single 512-bit block using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - */ -int compress_block_fallback(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single 512-bit block using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - */ -int compress_block_fallback_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress multiple 512-bit blocks using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input float values. 
- * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - */ -int compress_blocks_fallback(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Compress multiple 512-bit blocks using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - */ -int compress_blocks_fallback_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_FALLBACK_COMPRESSION_H diff --git a/include/pernix/fallback/decompression.h b/include/pernix/fallback/decompression.h deleted file mode 100644 index c4714e3..0000000 --- a/include/pernix/fallback/decompression.h +++ /dev/null @@ -1,286 +0,0 @@ -#ifndef PERNIX_FALLBACK_DECOMPRESSION_H -#define PERNIX_FALLBACK_DECOMPRESSION_H - -#include - -#include -#include -#include -#include - -namespace pernix { -namespace internal { -/** - * @brief Dequantize a single int32_t value to float using the provided scale. - * - * @param input input int32_t value to be dequantized. - * @param scale scaling factor used during quantization. - * @return float dequantized float value. 
- */ -__always_inline float dequantize_epi32(const int32_t input, const float scale) { - return static_cast(input) * scale; -} - -/** - * @brief Dequantize a single int64_t value to double using the provided scale. - * - * @param input input int64_t value to be dequantized. - * @param scale scaling factor used during quantization. - * @return double_t dequantized double value. - */ -__always_inline double_t dequantize_epi64(const int64_t input, const double_t scale) { - return static_cast(input) * scale; -} - -/** - * @brief Sign-extend a packed integer value stored in the low bits of a 32-bit word. - * - * @tparam BIT_WIDTH number of significant bits in the encoded value. - * @param value unsigned packed value. - * @return int32_t sign-extended value. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__always_inline auto sign_extend(const uint32_t value) -> int32_t { - if constexpr (BIT_WIDTH == 1) { - return static_cast(value & 1U); - } - - constexpr uint32_t sign_bit = uint32_t{1} << (BIT_WIDTH - 1); - constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1; - const uint32_t masked = value & mask; - return static_cast((static_cast(masked ^ sign_bit)) - static_cast(sign_bit)); -} - -/** - * @brief Unpack bit-packed values from a typed input span into signed 32-bit integers. - * - * @tparam T unsigned integer type used to read the source buffer. - * @tparam BIT_WIDTH bit width per packed value. - * @tparam SIGN_VALUES whether to sign-extend unpacked values. - * @param input pointer to the typed packed input buffer. - * @param bit_offset starting bit offset in the first input word. - * @param elements number of values to unpack. - * @return std::vector unpacked values. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) -__always_inline auto unpack_epi32_fallback_inner(const uint8_t* __restrict__ input, const uint8_t bit_offset, const std::size_t elements) - -> std::vector { - constexpr uint32_t bits_in_type = sizeof(T) * 8; - constexpr uint32_t bitmask = BIT_WIDTH == bits_in_type ? std::numeric_limits::max() : (1U << BIT_WIDTH) - 1U; - - std::vector output(elements); - - std::size_t idx = 0; - uint8_t bits_in_buffer = 8 - bit_offset; - uint64_t buffer = static_cast(input[idx++]) >> bit_offset; - -#pragma GCC unroll 64 - for (uint32_t i = 0; i < elements; i++) { - while (BIT_WIDTH > bits_in_buffer) { - const auto next_value = static_cast(input[idx++]) << bits_in_buffer; - buffer |= next_value; - bits_in_buffer += 8; - } - - const uint32_t raw_value = static_cast(buffer & bitmask); - if constexpr (SIGN_VALUES) { - output[i] = sign_extend(raw_value); - } else { - output[i] = static_cast(raw_value); - } - - buffer >>= BIT_WIDTH; - bits_in_buffer -= BIT_WIDTH; - } - - return output; -} - -/** - * @brief Unpack packed int32_t values from the input buffer using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the packed data. - * @param elements number of elements to unpack. - * @return std::vector unpacked int32_t values. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__always_inline auto unpack_epi32_fallback(const uint8_t* __restrict__ input, const std::size_t elements) -> std::vector { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return unpack_epi32_fallback_inner(input, 0, elements); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return unpack_epi32_fallback_inner(input, 0, elements); - } else { - return unpack_epi32_fallback_inner(input, 0, elements); - } -} -} // namespace internal - -/** - * @brief Decompress a single 512\-bit block using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_block_fallback(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - const std::vector block_values = internal::unpack_epi32_fallback(input, elements_per_block); - -#pragma GCC unroll 512 - for (uint32_t i = 0; i < elements_per_block; i++) { - output[i] = internal::dequantize_epi32(block_values[i], scale); - } - - return 0; -} - -/** - * @brief Decompress a single block to double values using the fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. 
- * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_block_fallback(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - const std::vector block_values = internal::unpack_epi32_fallback(input, elements_per_block); - -#pragma GCC unroll 512 - for (uint32_t i = 0; i < elements_per_block; i++) { - output[i] = internal::dequantize_epi64(block_values[i], scale); - } - - return 0; -} - -/** - * @brief Decompress multiple 512\-bit blocks using fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -int decompress_blocks_fallback(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - decompress_block_fallback(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} - -/** - * @brief Decompress multiple blocks to double values using the fallback scalar implementation. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. 
- * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @param blocks number of blocks to decompress. - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -int decompress_blocks_fallback(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - decompress_block_fallback(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Decompress a single 512-bit block using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - */ -int decompress_block_fallback(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single 512-bit block using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). 
- */ -int decompress_block_fallback_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief Decompress multiple 512-bit blocks using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - */ -int decompress_blocks_fallback(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Decompress multiple 512-bit blocks using fallback scalar implementation. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). 
- */ -int decompress_blocks_fallback_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_FALLBACK_DECOMPRESSION_H diff --git a/include/pernix/pernix.h b/include/pernix/pernix.h index 55133bb..74d586b 100644 --- a/include/pernix/pernix.h +++ b/include/pernix/pernix.h @@ -1,423 +1,64 @@ #ifndef PERNIX_H #define PERNIX_H -#include +#include -// Include architecture-specific headers based on detected capabilities -// AVX2 -#ifdef PERNIX_AVX2_ENABLED -#include -#include - -// BMI2: Needs AVX2 as well -#ifdef PERNIX_BMI2_ENABLED -#include -#include -#endif // PERNIX_BMI2_ENABLED - -// AVX512 VBMI: Needs AVX2 as well -#ifdef PERNIX_AVX512_VBMI_ENABLED -#include -#include -#endif // PERNIX_AVX512_VBMI_ENABLED - -#endif // PERNIX_AVX2_ENABLED - -// Fallback (non-SIMD) implementations -#include -#include - -namespace pernix { -/** - * @brief Compress a single block of floating-point data into a bit-packed format using the specified bit width and scale. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * - * @return int status code (0 for success). - * - * @note This function will dispatch to the best available implementation based on detected CPU features at compile time. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single block of double-precision values into a bit-packed representation. 
- * - * @tparam BIT_WIDTH bit width per quantized value (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the input block. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress multiple blocks of single-precision values. - * - * @tparam BIT_WIDTH bit width per quantized value (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (must be a multiple of 32). - * - * @param input pointer to the first input value. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @param blocks number of consecutive blocks to process. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, uint32_t blocks); - -/** - * @brief Compress multiple blocks of double-precision values. - * - * @tparam BIT_WIDTH bit width per quantized value (1 to 24). - * @tparam BLOCK_SIZE size of each block in bytes (must be a multiple of 32). - * - * @param input pointer to the first input value. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @param blocks number of consecutive blocks to process. - * - * @return int status code (0 for success). 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, uint32_t blocks); - -/** - * @brief Decompress a single block of packed values into single-precision values. - * - * @tparam BIT_WIDTH bit width per packed value (1 to 24). - * @tparam SIGN_VALUES true for signed values, false for unsigned values. - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the packed input block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single block of packed values into double-precision values. - * - * @tparam BIT_WIDTH bit width per packed value (1 to 24). - * @tparam SIGN_VALUES true for signed values, false for unsigned values. - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the packed input block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief Decompress multiple blocks of packed values into single-precision values. - * - * @tparam BIT_WIDTH bit width per packed value (1 to 24). - * @tparam SIGN_VALUES true for signed values, false for unsigned values. 
- * @tparam BLOCK_SIZE size of each block in bytes (must be a multiple of 32). - * - * @param input pointer to the first packed block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @param blocks number of consecutive blocks to process. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); - -/** - * @brief Decompress multiple blocks of packed values into double-precision values. - * - * @tparam BIT_WIDTH bit width per packed value (1 to 24). - * @tparam SIGN_VALUES true for signed values, false for unsigned values. - * @tparam BLOCK_SIZE size of each block in bytes (must be a multiple of 32). - * - * @param input pointer to the first packed block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @param blocks number of consecutive blocks to process. - * - * @return int status code (0 for success). - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, uint32_t blocks); - -// Use the best available implementation based on detected CPU features at compile time. 
-#ifdef PERNIX_AVX2_ENABLED -#ifdef PERNIX_AVX512_VBMI_ENABLED -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return mm512_compress_block_avx512vbmi(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return mm512_compress_block_avx512vbmi(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return mm512_compress_blocks_avx512vbmi(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return mm512_compress_blocks_avx512vbmi(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return mm512_decompress_block_avx512vbmi(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return mm512_decompress_block_avx512vbmi(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return mm512_decompress_blocks_avx512vbmi(input, scale, output, blocks); -} - 
-template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return mm512_decompress_blocks_avx512vbmi(input, scale, output, blocks); -} -#else -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return mm256_compress_block_avx2(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return mm256_compress_block_avx2(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return mm256_compress_blocks_avx2(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return mm256_compress_blocks_avx2(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return mm256_decompress_block_avx2(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return mm256_decompress_block_avx2(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int 
decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return mm256_decompress_blocks_avx2(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return mm256_decompress_blocks_avx2(input, scale, output, blocks); -} -#endif -#else -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return compress_block_fallback(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return compress_block_fallback(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return compress_blocks_fallback(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { - return compress_blocks_fallback(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return decompress_block_fallback(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_block(const uint8_t* __restrict__ input, const double_t scale, 
double_t* __restrict__ output) { - return decompress_block_fallback(input, scale, output); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - return decompress_blocks_fallback(input, scale, output, blocks); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - return decompress_blocks_fallback(input, scale, output, blocks); -} -#endif -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { +#if defined(__cplusplus) extern "C" { #endif -/** - * @brief C ABI wrapper for compressing one single-precision block. - * - * @param bit_width bit width per quantized value (8 to 16). - * @param input pointer to the input block. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief C ABI wrapper for compressing one double-precision block. - * - * @param bit_width bit width per quantized value (8 to 16). - * @param input pointer to the input block. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief C ABI wrapper for compressing multiple single-precision blocks. 
- * - * @param bit_width bit width per quantized value (8 to 16). - * @param input pointer to the first input value. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @param blocks number of consecutive blocks to process. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, uint32_t blocks); - -/** - * @brief C ABI wrapper for compressing multiple double-precision blocks. - * - * @param bit_width bit width per quantized value (8 to 16). - * @param input pointer to the first input value. - * @param scale scaling factor used during quantization. - * @param output pointer to the destination compressed bytes. - * @param blocks number of consecutive blocks to process. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief C ABI wrapper for decompressing one single-precision block. - * - * @param bit_width bit width per packed value (1 to 24). - * @param input pointer to the packed input block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief C ABI wrapper for decompressing one double-precision block. - * - * @param bit_width bit width per packed value (1 to 24). - * @param input pointer to the packed input block. 
- * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief C ABI wrapper for decompressing multiple single-precision blocks. - * - * @param bit_width bit width per packed value (1 to 24). - * @param input pointer to the first packed block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @param blocks number of consecutive blocks to process. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). - */ -int decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); - -/** - * @brief C ABI wrapper for decompressing multiple double-precision blocks. - * - * @param bit_width bit width per packed value (1 to 24). - * @param input pointer to the first packed block. - * @param scale scaling factor used to reconstruct floating-point values. - * @param output pointer to the destination decompressed values. - * @param blocks number of consecutive blocks to process. - * @return int status code (0 for success, non-zero for invalid arguments or unsupported bit width). 
- */ -int decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus +typedef enum pernix_status { + PERNIX_STATUS_OK = 0, + PERNIX_STATUS_INVALID_ARGUMENT = -1, + PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH = -2, + PERNIX_STATUS_UNSUPPORTED_BACKEND = -3, + PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE = -4 +} pernix_status; + +typedef enum pernix_backend { + PERNIX_BACKEND_AUTO = 0, + PERNIX_BACKEND_FALLBACK = 1, + PERNIX_BACKEND_X86_AVX2 = 2, + PERNIX_BACKEND_X86_BMI2 = 3, + PERNIX_BACKEND_X86_AVX512_VBMI = 4, + PERNIX_BACKEND_ARM64_NEON = 5, + PERNIX_BACKEND_ARM64_SVE = 6 +} pernix_backend; + +PERNIX_API pernix_status pernix_compress_block_f32(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + float scale, void* output); + +PERNIX_API pernix_status pernix_compress_blocks_f32(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + float scale, void* output, u32 blocks); + +PERNIX_API pernix_status pernix_decompress_block_f32(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + float scale, void* output, bool sign_values); + +PERNIX_API pernix_status pernix_decompress_blocks_f32(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + float scale, void* output, u32 blocks, bool sign_values); + +PERNIX_API pernix_status pernix_compress_block_f64(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + double scale, void* output); + +PERNIX_API pernix_status pernix_compress_blocks_f64(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + double scale, void* output, u32 blocks); + +PERNIX_API pernix_status pernix_decompress_block_f64(pernix_backend backend, u8 bit_width, u32 block_size, + const void* input, + double scale, void* output, bool sign_values); + +PERNIX_API pernix_status pernix_decompress_blocks_f64(pernix_backend backend, u8 bit_width, 
u32 block_size, + const void* input, + double scale, void* output, u32 blocks, bool sign_values); + +#if defined(__cplusplus) } -} // namespace pernix #endif -#endif // PERNIX_H \ No newline at end of file +#endif //PERNIX_H diff --git a/include/pernix/pernix.hpp b/include/pernix/pernix.hpp new file mode 100644 index 0000000..d197f72 --- /dev/null +++ b/include/pernix/pernix.hpp @@ -0,0 +1,166 @@ +#ifndef PERNIX_HPP +#define PERNIX_HPP + +#include +#include + +namespace pernix { +enum class Backend { + Auto = PERNIX_BACKEND_AUTO, + Fallback = PERNIX_BACKEND_FALLBACK, + X86Avx2 = PERNIX_BACKEND_X86_AVX2, + X86Bmi2 = PERNIX_BACKEND_X86_BMI2, + X86Avx512Vbmi = PERNIX_BACKEND_X86_AVX512_VBMI, + Arm64Neon = PERNIX_BACKEND_ARM64_NEON, + Arm64Sve = PERNIX_BACKEND_ARM64_SVE +}; + +__always_inline int compress_block(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const float scale, std::span output) { + return pernix_compress_block_f32(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data()); +} + +__always_inline int compress_block(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const double scale, std::span output) { + return pernix_compress_block_f64(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data()); +} + +__always_inline int decompress_block(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const float scale, std::span output, + const bool sign_values = true) { + return pernix_decompress_block_f32(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), + sign_values); +} + +__always_inline int decompress_block(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const double scale, std::span output, + const bool sign_values = true) { + return pernix_decompress_block_f64(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), 
+ sign_values); +} + +__always_inline int compress_blocks(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const float scale, std::span output, + const u32 blocks) { + return pernix_compress_blocks_f32(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), + blocks); +} + +__always_inline int compress_blocks(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const double scale, std::span output, + const u32 blocks) { + return pernix_compress_blocks_f64(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), + blocks); +} + +__always_inline int decompress_blocks(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const float scale, std::span output, + const u32 blocks, const bool sign_values = true) { + return pernix_decompress_blocks_f32(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), + blocks, sign_values); +} + +__always_inline int decompress_blocks(Backend backend, const u8 bit_width, const u32 block_size, + const std::span input, const double scale, std::span output, + const u32 blocks, const bool sign_values = true) { + return pernix_decompress_blocks_f64(static_cast(backend), bit_width, block_size, input.data(), + scale, output.data(), + blocks, sign_values); +} + +// convenience overloads without backend (defaults to Auto) +__always_inline int compress_block(const u8 bit_width, const u32 block_size, const std::span input, + const float scale, const std::span output) { + return compress_block(Backend::Auto, bit_width, block_size, input, scale, output); +} + +__always_inline int compress_block(const u8 bit_width, const u32 block_size, const std::span input, + const double scale, const std::span output) { + return compress_block(Backend::Auto, bit_width, block_size, input, scale, output); +} + +__always_inline int decompress_block(const u8 bit_width, const u32 block_size, 
const std::span input, + const float scale, const std::span output, const bool sign_values = true) { + return decompress_block(Backend::Auto, bit_width, block_size, input, scale, output, sign_values); +} + +__always_inline int decompress_block(const u8 bit_width, const u32 block_size, const std::span input, + const double scale, const std::span output, + const bool sign_values = true) { + return decompress_block(Backend::Auto, bit_width, block_size, input, scale, output, sign_values); +} + +__always_inline int compress_blocks(const u8 bit_width, const u32 block_size, const std::span input, + const float scale, const std::span output, const u32 blocks) { + return compress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks); +} + +__always_inline int compress_blocks(const u8 bit_width, const u32 block_size, const std::span input, + const double scale, const std::span output, const u32 blocks) { + return compress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks); +} + +__always_inline int decompress_blocks(const u8 bit_width, const u32 block_size, const std::span input, + const float scale, const std::span output, const u32 blocks, + const bool sign_values = true) { + return decompress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks, sign_values); +} + +__always_inline int decompress_blocks(const u8 bit_width, const u32 block_size, const std::span input, + const double scale, const std::span output, const u32 blocks, + const bool sign_values = true) { + return decompress_blocks(Backend::Auto, bit_width, block_size, input, scale, output, blocks, sign_values); +} + +// convenience overloads without backend and without block_size (defaults to 64) +__always_inline int compress_block(const u8 bit_width, const std::span input, const float scale, + const std::span output) { + return compress_block(Backend::Auto, bit_width, 64, input, scale, output); +} + +__always_inline int compress_block(const u8 
bit_width, const std::span input, const double scale, + const std::span output) { + return compress_block(Backend::Auto, bit_width, 64, input, scale, output); +} + +__always_inline int decompress_block(const u8 bit_width, const std::span input, const float scale, + const std::span output, const bool sign_values = true) { + return decompress_block(Backend::Auto, bit_width, 64, input, scale, output, sign_values); +} + +__always_inline int decompress_block(const u8 bit_width, const std::span input, const double scale, + const std::span output, const bool sign_values = true) { + return decompress_block(Backend::Auto, bit_width, 64, input, scale, output, sign_values); +} + +__always_inline int compress_blocks(const u8 bit_width, const std::span input, const float scale, + const std::span output, const u32 blocks) { + return compress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks); +} + +__always_inline int compress_blocks(const u8 bit_width, const std::span input, const double scale, + const std::span output, const u32 blocks) { + return compress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks); +} + +__always_inline int decompress_blocks(const u8 bit_width, const std::span input, const float scale, + const std::span output, const u32 blocks, + const bool sign_values = true) { + return decompress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks, sign_values); +} + +__always_inline int decompress_blocks(const u8 bit_width, const std::span input, const double scale, + const std::span output, const u32 blocks, + const bool sign_values = true) { + return decompress_blocks(Backend::Auto, bit_width, 64, input, scale, output, blocks, sign_values); +} +} + +#endif //PERNIX_HPP diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h deleted file mode 100644 index c95c1ee..0000000 --- a/include/pernix/simd_compat.h +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef PERNIX_SIMD_COMPAT_H -#define PERNIX_SIMD_COMPAT_H - 
-#include -#include - -#if defined(PERNIX_USE_SIMDE) -#define SIMDE_ENABLE_NATIVE_ALIASES -#undef SIMDE_X86_AVX512FP16_NATIVE -// #define SIMDE_NO_NATIVE -#include -#include -#include - -// #ifndef __mmask8 -// typedef uint8_t __mmask8; -// #endif -// #ifndef __mmask16 -// typedef uint16_t __mmask16; -// #endif -// #ifndef __mmask32 -// typedef uint32_t __mmask32; -// #endif -// #ifndef __mmask64 -// typedef uint64_t __mmask64; -// #endif - -#elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) -#include -#endif - -#ifndef __always_inline -#if defined(__GNUC__) || defined(__clang__) -#define __always_inline inline __attribute__((always_inline)) -#elif defined(_MSC_VER) -#define __always_inline __forceinline -#else -#define __always_inline inline -#endif -#endif - -template - requires(std::is_integral_v && sizeof(T) <= 8) -static constexpr T tail_mask(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - if (tail_bytes == 0u) { - return static_cast(0); - } - if (tail_bytes >= 64u) { - return static_cast(~uint64_t{0}); - } - const uint64_t mask = (uint64_t{1} << tail_bytes) - 1u; - return static_cast(mask); -} - -static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - return tail_bytes; -} - -#endif // PERNIX_SIMD_COMPAT_H diff --git a/include/pernix/x86/avx2/compression.h b/include/pernix/x86/avx2/compression.h deleted file mode 100644 index 357947e..0000000 --- a/include/pernix/x86/avx2/compression.h +++ /dev/null @@ -1,647 +0,0 @@ -#ifndef PERNIX_AVX2_COMPRESSION_H -#define PERNIX_AVX2_COMPRESSION_H - -#include -#include -#include - -#include -#include -#include -#include - -namespace pernix { -namespace internal { -template - requires(BIT_WIDTH >= 1 && 
BIT_WIDTH <= 24) -__always_inline __m256i mm256_clamp_signed_epi32(__m256i input) { - constexpr int32_t min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); - constexpr int32_t max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); - return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), _mm256_set1_epi32(max_value)); -} - -/** - * @brief Quantize four float values into signed 32-bit integers. - * - * @param input source float lane values. - * @param scale per-lane scale factor. - * @return __m128i quantized values. - */ -__always_inline __m128i mm_quantize_ps_epi32(const __m128& input, const __m128& scale) { - const __m128 scaled = _mm_mul_ps(input, scale); - // const __m128 rounded = _mm_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - return _mm_cvtps_epi32(scaled); -} - -/** - * @brief Quantize two double values into a partially filled 128-bit integer register. - * - * @param input source double lane values. - * @param scale per-lane scale factor. - * @return __m128i quantized values in the low lanes. - */ -__always_inline __m128i mm_quantize_pd_epi32(const __m128d& input, const __m128d& scale) { - const __m128d scaled = _mm_mul_pd(input, scale); - return _mm_cvtpd_epi32(scaled); -} - -/** - * @brief Quantize eight float values into signed 32-bit integers. - * - * @param input source float lane values. - * @param scale per-lane scale factor. - * @return __m256i quantized values. - */ -__always_inline __m256i mm256_quantize_ps_epi32(const __m256& input, const __m256& scale) { - const __m256 scaled = _mm256_mul_ps(input, scale); - // const __m256 rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - return _mm256_cvtps_epi32(scaled); -} - -/** - * @brief Quantize four double values into signed 32-bit integers. - * - * @param input source double lane values. - * @param scale per-lane scale factor. - * @return __m128i quantized values in the low lanes. 
- */ -__always_inline __m128i mm256_quantize_pd_epi32(const __m256d& input, const __m256d& scale) { - const __m256d scaled = _mm256_mul_pd(input, scale); - return _mm256_cvtpd_epi32(scaled); -} - -#ifndef PERNIX_USE_SIMDE -/** - * @brief Emulate per-16-bit left shifts on AVX2. - * - * @param a source values. - * @param count per-lane shift amounts. - * @return __m128i shifted values. - */ -__always_inline static __m128i _mm_sllv_epi16(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi32(0xffff0000); - const __m128i low_half = _mm_sllv_epi32(a, _mm_andnot_si128(mask, count)); - const __m128i high_half = _mm_sllv_epi32(_mm_and_si128(mask, a), _mm_srli_epi32(count, 16)); - return _mm_blend_epi16(low_half, high_half, 0xaa); -} - -/** - * @brief Emulate per-16-bit right shifts on AVX2. - */ -__always_inline static __m128i _mm_srlv_epi16(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi32(0x0000ffff); - const __m128i low_half = _mm_srlv_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); - const __m128i high_half = _mm_srlv_epi32(a, _mm_srli_epi32(count, 16)); - return _mm_blend_epi16(low_half, high_half, 0xaa); -} - -/** - * @brief Emulate per-16-bit left shifts on 256-bit AVX2 vectors. - */ -__always_inline static __m256i _mm256_sllv_epi16(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi32(0xffff0000); - const __m256i low_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count)); - const __m256i high_half = _mm256_sllv_epi32(_mm256_and_si256(mask, a), _mm256_srli_epi32(count, 16)); - return _mm256_blend_epi16(low_half, high_half, 0xaa); -} - -/** - * @brief Emulate per-16-bit right shifts on 256-bit AVX2 vectors. 
- */ -__always_inline static __m256i _mm256_srlv_epi16(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi32(0x0000ffff); - const __m256i low_half = _mm256_srlv_epi32(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); - const __m256i high_half = _mm256_srlv_epi32(a, _mm256_srli_epi32(count, 16)); - return _mm256_blend_epi16(low_half, high_half, 0xaa); -} - -/** - * @brief Blend 8-bit lanes by expanding a scalar mask value. - */ -__always_inline static __m128i mm_blend_epi8(const __m128i X, const __m128i Y, const int8_t M) { - return _mm_blendv_epi8(X, Y, _mm_set1_epi8(M)); -} - -/** - * @brief Blend 8-bit lanes in 256-bit vectors by expanding a scalar mask value. - */ -__always_inline static __m256i mm256_blend_epi8(const __m256i X, const __m256i Y, const int8_t M) { - return _mm256_blendv_epi8(X, Y, _mm256_set1_epi8(M)); -} - -/** - * @brief Emulate per-byte left shifts on 128-bit vectors. - */ -__always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi16(0xff00); - const __m128i low_half = _mm_sllv_epi16(a, _mm_andnot_si128(mask, count)); - const __m128i high_half = _mm_sllv_epi16(_mm_and_si128(mask, a), _mm_srli_epi16(count, 8)); - return mm_blend_epi8(low_half, high_half, 0xaa); -} - -/** - * @brief Emulate per-byte right shifts on 128-bit vectors. - */ -__always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi16(0x00ff); - const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); - const __m128i high_half = _mm_srlv_epi16(a, _mm_srli_epi16(count, 8)); - return mm_blend_epi8(low_half, high_half, 0xaa); -} - -/** - * @brief Emulate per-byte left shifts on 256-bit vectors. 
- */ -__always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi16(0xff00); - const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count)); - const __m256i high_half = _mm256_sllv_epi16(_mm256_and_si256(mask, a), _mm256_srli_epi16(count, 8)); - return mm256_blend_epi8(low_half, high_half, 0xaa); -} - -/** - * @brief Emulate per-byte right shifts on 256-bit vectors. - */ -__always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi16(0x00ff); - const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); - const __m256i high_half = _mm256_srlv_epi16(a, _mm256_srli_epi16(count, 8)); - return mm256_blend_epi8(low_half, high_half, 0xaa); -} -#endif - -/** - * @brief Pack four 32-bit values for bit widths 1 through 3. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3) -__always_inline auto mm_pack_epi32_avx2_1to3(__m128i& input) -> __m128i { - constexpr uint32_t bitmask = (1U << BIT_WIDTH) - 1U; - const __m128i masked = _mm_and_si128(input, _mm_set1_epi32(static_cast(bitmask))); - - alignas(16) uint32_t lanes[4]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(lanes), masked); - - const uint32_t packed = (lanes[0] & bitmask) | ((lanes[1] & bitmask) << BIT_WIDTH) | ((lanes[2] & bitmask) << (2 * BIT_WIDTH)) | - ((lanes[3] & bitmask) << (3 * BIT_WIDTH)); - - return _mm_cvtsi32_si128(static_cast(packed)); -} - -/** - * @brief Pack eight 32-bit values for bit widths 1 through 3. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3) -__always_inline __m256i mm256_pack_epi32_avx2_1to3(const __m256i& input) { - constexpr uint32_t bitmask = (1u << BIT_WIDTH) - 1u; - - const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(static_cast(bitmask))); - - const __m256i shifts = _mm256_setr_epi32(0 * BIT_WIDTH, 1 * BIT_WIDTH, 2 * BIT_WIDTH, 3 * BIT_WIDTH, 4 * BIT_WIDTH, 5 * BIT_WIDTH, - 6 * BIT_WIDTH, 7 * BIT_WIDTH); - - const __m256i shifted = _mm256_sllv_epi32(masked, shifts); - - __m128i x = _mm_or_si128(_mm256_castsi256_si128(shifted), _mm256_extracti128_si256(shifted, 1)); - - x = _mm_or_si128(x, _mm_srli_si128(x, 8)); - x = _mm_or_si128(x, _mm_srli_si128(x, 4)); - - return _mm256_castsi128_si256(x); -} - -__always_inline __m256i mm256_pack_epi32_avx2_4(const __m256i& input) { - const __m256i zero = _mm256_setzero_si256(); - - const __m256i packed16 = _mm256_packus_epi32(input, zero); - const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); - const __m256i packed8 = _mm256_packus_epi16(permuted, zero); - - const __m256i combined = _mm256_or_si256(packed8, _mm256_srli_epi16(packed8, 4)); - - const __m256i shuffled = _mm256_shuffle_epi8(combined, _mm256_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1, 0, 2, - 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1)); - - return shuffled; -} - -/** - * @brief Pack four 32-bit values for bit widths 9 through 16. 
- */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline auto mm_pack_epi32_avx2_9to16(__m128i& input) -> __m128i { - using tables = pack_tables_avx2_16; - constexpr uint16_t bitmask = (1 << BIT_WIDTH) - 1; - const __m128i masked = _mm_and_si128(input, _mm_set1_epi16(bitmask)); - __m128i combined; - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); - const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); - - const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); - const __m128i shifted2 = _mm_srlv_epi16(shuffled2, tables::get_shift2()); - - combined = _mm_or_si128(shifted1, shifted2); - } else { - const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); - const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); - const __m128i shuffled3 = _mm_shuffle_epi8(masked, tables::get_permute3()); - - const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi16(shuffled2, tables::get_shift2()); - const __m128i shifted3 = _mm_srlv_epi16(shuffled3, tables::get_shift3()); - - combined = _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); - } - return combined; -} - -/** - * @brief Pack eight 32-bit values for bit widths 9 through 16. 
- */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline auto mm256_pack_epi32_avx2_9to16(const __m256i& input) -> __m256i { - using tables = pack_tables_avx2_16; - constexpr uint16_t bitmask = (1 << BIT_WIDTH) - 1; - const __m128i packed = _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1)); - const __m128i masked = _mm_and_si128(packed, _mm_set1_epi16(bitmask)); - __m128i combined; - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15 || BIT_WIDTH == 16) { - const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); - const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); - - const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); - const __m128i shifted2 = _mm_srlv_epi16(shuffled2, tables::get_shift2()); - - combined = _mm_or_si128(shifted1, shifted2); - } else { - const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1()); - const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2()); - const __m128i shuffled3 = _mm_shuffle_epi8(masked, tables::get_permute3()); - - const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi16(shuffled2, tables::get_shift2()); - const __m128i shifted3 = _mm_srlv_epi16(shuffled3, tables::get_shift3()); - - combined = _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); - } - return _mm256_castsi128_si256(combined); -} - -template - requires(BIT_WIDTH >= 5 && BIT_WIDTH <= 7) -__always_inline auto mm256_pack_epi32_avx2_5to7(const __m256i& input) -> __m256i { - const __m256i zero = _mm256_setzero_si256(); - - const __m256i packed16 = _mm256_packus_epi32(input, zero); - const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); - const __m256i packed8 = _mm256_packus_epi16(permuted, zero); - - const __m256i even = _mm256_and_si256(packed8, _mm256_set1_epi16(0x00FF)); - const __m256i odd = 
_mm256_and_si256(packed8, _mm256_set1_epi16(0xFF00)); - - const __m256i pair16 = _mm256_or_si256(even, _mm256_srli_epi16(odd, 8 - BIT_WIDTH)); - const __m256i extended = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(pair16)); - - return mm256_pack_epi32_avx2_9to16<2 * BIT_WIDTH>(extended); -} - -/** - * @brief Pack eight 32-bit values for bit widths 17 through 24. - */ -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline auto mm256_pack_epi32_avx2_17to24(const __m256i& input) -> __m256i { - using tables = pack_tables_avx2_24; - constexpr uint32_t bitmask = (1 << BIT_WIDTH) - 1; - const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(bitmask)); - __m256i combined; - - if constexpr (BIT_WIDTH == 24) { - const __m256i shuffled1 = _mm256_permutevar8x32_epi32(masked, tables::get_permute1()); - const __m256i shuffled2 = _mm256_permutevar8x32_epi32(masked, tables::get_permute2()); - - const __m256i shifted1 = _mm256_sllv_epi32(shuffled1, tables::get_shift1()); - const __m256i shifted2 = _mm256_srlv_epi32(shuffled2, tables::get_shift2()); - - combined = _mm256_or_si256(shifted1, shifted2); - } else { - const __m256i shuffled1 = _mm256_permutevar8x32_epi32(masked, tables::get_permute1()); - const __m256i shuffled2 = _mm256_permutevar8x32_epi32(masked, tables::get_permute2()); - const __m256i shuffled3 = _mm256_permutevar8x32_epi32(masked, tables::get_permute3()); - - const __m256i shifted1 = _mm256_sllv_epi32(shuffled1, tables::get_shift1()); - const __m256i shifted2 = _mm256_sllv_epi32(shuffled2, tables::get_shift2()); - const __m256i shifted3 = _mm256_srlv_epi32(shuffled3, tables::get_shift3()); - - combined = _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); - } - - return combined; -} - -/** - * @brief Pack aligned 8-bit or 16-bit values from four 32-bit lanes. 
- */ -template - requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) -auto mm_pack_aligned_epi32_avx2(__m128i& input) -> __m128i { - if constexpr (BIT_WIDTH == 8) { - return _mm_packus_epi16(_mm_packs_epi32(input, _mm_setzero_si128()), _mm_setzero_si128()); - } else { - return _mm_packs_epi32(input, _mm_setzero_si128()); - } -} - -/** - * @brief Dispatch to the appropriate 128-bit AVX2 packer for the selected bit width. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 16) -auto mm_pack_epi32_avx2(__m128i& input) -> __m128i { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 3) { - return internal::mm_pack_epi32_avx2_1to3(input); - } else if constexpr (BIT_WIDTH >= 4 && BIT_WIDTH <= 8) { - // TODO: implementation for 4-8 bits - return _mm_setzero_si128(); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::mm_pack_epi32_avx2_9to16(input); - } else { - return _mm_setzero_si128(); - } -} - -/** - * @brief Pack aligned 8-bit or 16-bit values from eight 32-bit lanes. - */ -template - requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) -__m256i mm256_pack_aligned_epi32_avx2(const __m256i& input) { - if constexpr (BIT_WIDTH == 8) { - const __m128i packed16 = _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1)); - const __m128i packed8 = _mm_packs_epi16(packed16, _mm_setzero_si128()); - return _mm256_castsi128_si256(packed8); - } else { - return _mm256_castsi128_si256(_mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1))); - } -} - -/** - * @brief Dispatch to the appropriate 256-bit AVX2 packer for the selected bit width. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__m256i mm256_pack_epi32_avx2(const __m256i& input) { - if constexpr (BIT_WIDTH == 8 || BIT_WIDTH == 16) { - return internal::mm256_pack_aligned_epi32_avx2(input); - } else { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 3) { - return internal::mm256_pack_epi32_avx2_1to3(input); - } else if constexpr (BIT_WIDTH == 4) { - return mm256_pack_epi32_avx2_4(input); - } else if constexpr (BIT_WIDTH >= 5 && BIT_WIDTH <= 7) { - return internal::mm256_pack_epi32_avx2_5to7(input); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 15) { - return internal::mm256_pack_epi32_avx2_9to16(input); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::mm256_pack_epi32_avx2_17to24(input); - } - } - return _mm256_setzero_si256(); -} -} // namespace internal - -/** - * @brief Compress a single block of float using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_avx2(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - - std::memset(output, 0, BLOCK_SIZE); - - const __m256 scale_v = _mm256_set1_ps(scale); -#pragma GCC unroll 8 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256 source = _mm256_loadu_ps(input); - const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v); - const __m256i packed_input = internal::mm256_clamp_signed_epi32(quantized); - const __m256i packed = internal::mm256_pack_epi32_avx2(packed_input); - std::memcpy(output, &packed, BIT_WIDTH); - - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining) { - std::vector block_values(remaining); -#pragma GCC unroll 8 - for (uint32_t i = 0; i < remaining; i++) { - block_values[i] = - static_cast(internal::clamp_signed_quantized(internal::quantize_ps_epi32(input[i], scale))); - } - - internal::pack_epi32_fallback(block_values, output); - } - - return 0; -} - -/** - * @brief Compress a single block of double using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_avx2(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - - std::memset(output, 0, BLOCK_SIZE); - - const __m256d scale_v = _mm256_set1_pd(scale); -#pragma GCC unroll 8 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256d source1 = _mm256_loadu_pd(input); - const __m256d source2 = _mm256_loadu_pd(input + 4); - const __m128i quantized1 = internal::mm256_quantize_pd_epi32(source1, scale_v); - const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v); - __m256i combined = _mm256_castsi128_si256(quantized1); - combined = _mm256_inserti128_si256(combined, quantized2, 1); - const __m256i packed = internal::mm256_pack_epi32_avx2(internal::mm256_clamp_signed_epi32(combined)); - // _mm_storeu_si128(reinterpret_cast<__m128i*>(output), _mm256_castsi256_si128(packed)); - std::memcpy(output, &packed, BIT_WIDTH); - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining) { - std::vector block_values(remaining); -#pragma GCC unroll 8 - for (uint32_t i = 0; i < remaining; i++) { - block_values[i] = - static_cast(internal::clamp_signed_quantized(internal::quantize_pd_epi64(input[i], scale))); - } - - internal::pack_epi32_fallback(block_values, output); - } - - return 0; -} - -/** - * @brief Compress multiple blocks using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. 
- * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_avx2(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - const float_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_compress_block_avx2(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } - - return 0; -} - -/** - * @brief Compress multiple blocks using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_avx2(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - const double_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_compress_block_avx2(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } - - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Compress a single 512-bit block using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_compress_block_avx2(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single 512-bit block using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_compress_block_f64_avx2(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress multiple 512-bit blocks using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). 
- * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_blocks_avx2(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Compress multiple 512-bit blocks using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_blocks_f64_avx2(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_AVX2_COMPRESSION_H diff --git a/include/pernix/x86/avx2/decompression.h b/include/pernix/x86/avx2/decompression.h deleted file mode 100644 index 93cbce5..0000000 --- a/include/pernix/x86/avx2/decompression.h +++ /dev/null @@ -1,392 +0,0 @@ -#ifndef PERNIX_AVX2_DECOMPRESSION_H -#define PERNIX_AVX2_DECOMPRESSION_H - -#include -#include -#include - -#include -#include -#include -#include - -namespace pernix { -namespace internal { -/** - * @brief Convert an 8-lane mask to the lane representation used by AVX2 float masked stores. - */ -__always_inline __m256i mm256_convert_vmask_epi32(const __mmask8 mask8) { - return _mm256_setr_epi32((mask8 & 0x1) ? -1 : 0, (mask8 & 0x2) ? -1 : 0, (mask8 & 0x4) ? 
-1 : 0, (mask8 & 0x8) ? -1 : 0, - (mask8 & 0x10) ? -1 : 0, (mask8 & 0x20) ? -1 : 0, (mask8 & 0x40) ? -1 : 0, (mask8 & 0x80) ? -1 : 0); -} - -/** - * @brief Convert a 4-lane mask to the lane representation used by AVX2 double masked stores. - */ -__always_inline __m256i mm256_convert_vmask_epi64(const __mmask8 mask8) { - return _mm256_setr_epi64x((mask8 & 0x1) ? -1 : 0, (mask8 & 0x2) ? -1 : 0, (mask8 & 0x4) ? -1 : 0, (mask8 & 0x8) ? -1 : 0); -} - -/** - * @brief Dequantize four 32-bit integers to floats. - */ -__always_inline __m128 mm_dequantize_epi32(const __m128i& input, const __m128& scale) { - const __m128 converted = _mm_cvtepi32_ps(input); - return _mm_mul_ps(converted, scale); -} - -/* https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx */ -__always_inline __m128d convert_epi64_pd(const __m128i v) { - __m128i xH = _mm_srai_epi32(v, 16); - xH = _mm_blend_epi16(xH, _mm_setzero_si128(), 0x33); - xH = _mm_add_epi64(xH, _mm_castpd_si128(_mm_set1_pd(442721857769029238784.))); // 3*2^67 - const __m128i xL = _mm_blend_epi16(v, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0x88); // 2^52 - const __m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(442726361368656609280.)); // 3*2^67 + 2^52 - return _mm_add_pd(f, _mm_castsi128_pd(xL)); -} - -/** - * @brief Dequantize two 64-bit integers to doubles. - */ -__always_inline __m128d mm_dequantize_epi64_pd(const __m128i& input, const __m128d& scale) { - const __m128d converted = convert_epi64_pd(input); - return _mm_mul_pd(converted, scale); -} - -/** - * @brief Dequantize eight 32-bit integers to floats. 
- */ -__always_inline __m256 mm256_dequantize_epi32(const __m256i& input, const __m256& scale) { - const __m256 converted = _mm256_cvtepi32_ps(input); - return _mm256_mul_ps(converted, scale); -} - -/* https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx */ -__always_inline __m256d convert_epi64_pd(__m256i v) { - __m256i xH = _mm256_srai_epi32(v, 16); - xH = _mm256_blend_epi16(xH, _mm256_setzero_si256(), 0x33); - xH = _mm256_add_epi64(xH, _mm256_castpd_si256(_mm256_set1_pd(442721857769029238784.))); - const __m256i xL = _mm256_blend_epi16(v, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)), 0x88); - const __m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(442726361368656609280.)); - return _mm256_add_pd(f, _mm256_castsi256_pd(xL)); -} - -/** - * @brief Dequantize four 64-bit integers to doubles. - */ -__always_inline __m256d mm256_dequantize_epi64_pd(const __m256i& input, const __m256d& scale) { - const __m256d converted = convert_epi64_pd(input); - return _mm256_mul_pd(converted, scale); -} - -/** - * @brief Unpack four aligned 8-bit or 16-bit values directly from the input buffer. - */ -template - requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) -__m128i mm_unpack_aligned_epi32_avx2(const uint8_t* __restrict__ input) { - if constexpr (BIT_WIDTH == 8) { - const __m128i source = _mm_loadu_si32(input); - if constexpr (SIGN_VALUES) { - return _mm_cvtepi8_epi32(source); - } else { - return _mm_cvtepu8_epi32(source); - } - } else { - const __m128i source = _mm_loadu_si64(input); - if constexpr (SIGN_VALUES) { - return _mm_cvtepi16_epi32(source); - } else { - return _mm_cvtepu16_epi32(source); - } - } -} - -/** - * @brief Unpack four values using the table-driven AVX2 shuffle path. 
- */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -__m128i mm_unpack_epi32_avx2(const uint8_t* __restrict__ input) { - using unpack_table = unpack_tables_avx2; - constexpr std::size_t packed_bytes = (4 * BIT_WIDTH + 7) / 8; - - __m128i source = _mm_setzero_si128(); - std::memcpy(&source, input, packed_bytes); - - const __m128i shuffled = _mm_shuffle_epi8(source, unpack_table::get_shuffle()); - - constexpr uint16_t shift = 32 - BIT_WIDTH; - __m128i shifted = _mm_sllv_epi32(shuffled, unpack_table::get_shift()); - if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { - shifted = _mm_srai_epi32(shifted, shift); - } else { - shifted = _mm_srli_epi32(shifted, shift); - } - - return shifted; -} - -/** - * @brief Unpack eight aligned 8-bit or 16-bit values directly from the input buffer. - */ -template - requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) -__m256i mm256_unpack_aligned_epi32_avx2(const uint8_t* __restrict__ input) { - if constexpr (BIT_WIDTH == 8) { - const __m128i source = _mm_loadu_si64(input); - if constexpr (SIGN_VALUES) { - return _mm256_cvtepi8_epi32(source); - } else { - return _mm256_cvtepu8_epi32(source); - } - } else { - const __m128i source = _mm_loadu_si128(reinterpret_cast(input)); - if constexpr (SIGN_VALUES) { - return _mm256_cvtepi16_epi32(source); - } else { - return _mm256_cvtepu16_epi32(source); - } - } -} - -/** - * @brief Unpack eight values using the table-driven AVX2 shuffle path. 
- */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -__m256i mm256_unpack_epi32_avx2(const uint8_t* __restrict__ input) { - using unpack_table = unpack_tables_avx2; - constexpr std::size_t packed_bytes = BIT_WIDTH; - - __m256i source = _mm256_setzero_si256(); - std::memcpy(&source, input, packed_bytes); - - const __m256i permuted = _mm256_permutevar8x32_epi32(source, unpack_table::get_permute()); - const __m256i shuffled = _mm256_shuffle_epi8(permuted, unpack_table::get_shuffle()); - - constexpr uint16_t shift = 32 - BIT_WIDTH; - __m256i shifted = _mm256_sllv_epi32(shuffled, unpack_table::get_shift()); - if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { - shifted = _mm256_srai_epi32(shifted, shift); - } else { - shifted = _mm256_srli_epi32(shifted, shift); - } - - return shifted; -} -} // namespace internal - -/** - * @brief Decompress a single block to float using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_avx2(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - - const __m256 scale_v = _mm256_set1_ps(scale); -#pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256i unpacked = internal::mm256_unpack_epi32_avx2(input); - const __m256 dequantized = internal::mm256_dequantize_epi32(unpacked, scale_v); - _mm256_storeu_ps(output, dequantized); - input += BIT_WIDTH; - output += 8; - } - - if constexpr (remaining > 0) { - const std::vector tail_values = internal::unpack_epi32_fallback(input, remaining); - for (uint32_t i = 0; i < remaining; i++) { - output[i] = internal::dequantize_epi32(tail_values[i], scale); - } - } - - return 0; -} - -/** - * @brief Decompress a single block to double using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_avx2(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - const __m256d scale_v = _mm256_set1_pd(scale); -#pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256i unpacked = internal::mm256_unpack_epi32_avx2(input); - const __m256i extend1 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(unpacked)); - const __m256i extend2 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(unpacked, 1)); - - const __m256d dequantized1 = internal::mm256_dequantize_epi64_pd(extend1, scale_v); - const __m256d dequantized2 = internal::mm256_dequantize_epi64_pd(extend2, scale_v); - - _mm256_storeu_pd(output, dequantized1); - _mm256_storeu_pd(output + 4, dequantized2); - - input += BIT_WIDTH; - output += 8; - } - - if constexpr (remaining > 0) { - const std::vector tail_values = internal::unpack_epi32_fallback(input, remaining); - for (uint32_t i = 0; i < remaining; i++) { - output[i] = internal::dequantize_epi64(tail_values[i], scale); - } - } - return 0; -} - -/** - * @brief Decompress multiple blocks to float using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_avx2(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_decompress_block_avx2(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} - -/** - * @brief Decompress multiple blocks to double using AVX2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_avx2(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_decompress_block_avx2(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Decompress a single 512-bit block to float using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. 
- * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_decompress_block_avx2(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single 512-bit block to double using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_decompress_block_f64_avx2(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief Decompress multiple 512-bit blocks using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_decompress_blocks_avx2(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Decompress multiple 512-bit blocks to double using AVX2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. 
- * @param output pointer to the output buffer where decompressed double values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 support. - */ -int mm256_decompress_blocks_f64_avx2(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_AVX2_DECOMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/compat.h b/include/pernix/x86/avx512vbmi/compat.h deleted file mode 100644 index bbc7ad4..0000000 --- a/include/pernix/x86/avx512vbmi/compat.h +++ /dev/null @@ -1,387 +0,0 @@ -#ifndef PERNIX_AVX512_COMPAT_H -#define PERNIX_AVX512_COMPAT_H - -#include -#include -#include -#include - -namespace pernix::internal { -static __always_inline __mmask8 element_mask8(const uint32_t e) { - return static_cast<__mmask8>(e >= 8 ? 0xFFu : ((1u << e) - 1u)); -} - -static __always_inline __mmask16 element_mask16(const uint32_t e) { - return static_cast<__mmask16>(e >= 16 ? 0xFFFFu : ((1u << e) - 1u)); -} - -static __always_inline __mmask32 element_mask32(const uint32_t e) { - return e >= 32 ? 0xFFFFFFFFu : (1u << e) - 1u; -} - -static __always_inline __mmask64 element_mask64(const uint32_t e) { - return e >= 64 ? 
0xFFFFFFFFFFFFFFFFull : (1ull << e) - 1ull; -} - -static __always_inline __m512i mm512_loadu_elements_epi64(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); - return a; -#else - return _mm512_maskz_loadu_epi64(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m256i mm256_loadu_elements_epi64(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); - return a; -#else - return _mm256_maskz_loadu_epi64(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m128i mm_loadu_elements_epi64(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int64_t)); - return a; -#else - return _mm_maskz_loadu_epi64(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m512i mm512_loadu_elements_epi32(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); - return a; -#else - return _mm512_maskz_loadu_epi32(element_mask16(e), mem_addr); -#endif -} - -static __always_inline __m256i mm256_loadu_elements_epi32(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); - return a; -#else - return _mm256_maskz_loadu_epi32(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m128i mm_loadu_elements_epi32(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128i a = 
_mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int32_t)); - return a; -#else - return _mm_maskz_loadu_epi32(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m512i mm512_loadu_elements_epi16(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); - return a; -#else - return _mm512_maskz_loadu_epi16(element_mask32(e), mem_addr); -#endif -} - -static __always_inline __m256i mm256_loadu_elements_epi16(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); - return a; -#else - return _mm256_maskz_loadu_epi16(element_mask16(e), mem_addr); -#endif -} - -static __always_inline __m128i mm_loadu_elements_epi16(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int16_t)); - return a; -#else - return _mm_maskz_loadu_epi16(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m512i mm512_loadu_elements_epi8(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512i a = _mm512_setzero_si512(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); - return a; -#else - return _mm512_maskz_loadu_epi8(element_mask64(e), mem_addr); -#endif -} - -static __always_inline __m256i mm256_loadu_elements_epi8(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256i a = _mm256_setzero_si256(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); - return a; -#else - return _mm256_maskz_loadu_epi8(element_mask32(e), mem_addr); -#endif -} - -static __always_inline __m128i mm_loadu_elements_epi8(const uint32_t e, const 
void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128i a = _mm_setzero_si128(); - std::memcpy(&a, mem_addr, e * sizeof(int8_t)); - return a; -#else - return _mm_maskz_loadu_epi8(element_mask16(e), mem_addr); -#endif -} - -static __always_inline void mm512_storeu_elements_epi64(void* mem_addr, const uint32_t e, const __m512i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; - _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); -#else - _mm512_mask_storeu_epi64(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_epi64(void* mem_addr, const uint32_t e, const __m256i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); -#else - _mm256_mask_storeu_epi64(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_epi64(void* mem_addr, const uint32_t e, const __m128i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int64_t)); -#else - _mm_mask_storeu_epi64(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm512_storeu_elements_epi32(void* mem_addr, const uint32_t e, const __m512i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; - _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); -#else - _mm512_mask_storeu_epi32(mem_addr, element_mask16(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_epi32(void* mem_addr, const uint32_t e, const __m256i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - 
alignas(32) uint8_t bytes[32]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); -#else - _mm256_mask_storeu_epi32(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_epi32(void* mem_addr, const uint32_t e, const __m128i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int32_t)); -#else - _mm_mask_storeu_epi32(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm512_storeu_elements_epi16(void* mem_addr, const uint32_t e, const __m512i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t bytes[64]; - _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); -#else - _mm512_mask_storeu_epi16(mem_addr, element_mask32(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_epi16(void* mem_addr, const uint32_t e, const __m256i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); -#else - _mm256_mask_storeu_epi16(mem_addr, element_mask16(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_epi16(void* mem_addr, const uint32_t e, const __m128i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int16_t)); -#else - _mm_mask_storeu_epi16(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm512_storeu_elements_epi8(void* mem_addr, const uint32_t e, const __m512i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) uint8_t 
bytes[64]; - _mm512_storeu_si512(bytes, a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); -#else - _mm512_mask_storeu_epi8(mem_addr, element_mask64(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_epi8(void* mem_addr, const uint32_t e, const __m256i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) uint8_t bytes[32]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); -#else - _mm256_mask_storeu_epi8(mem_addr, element_mask32(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_epi8(void* mem_addr, const uint32_t e, const __m128i a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) uint8_t bytes[16]; - _mm_storeu_si128(reinterpret_cast<__m128i*>(bytes), a); - std::memcpy(mem_addr, bytes, e * sizeof(int8_t)); -#else - _mm_mask_storeu_epi8(mem_addr, element_mask16(e), a); -#endif -} - -static __always_inline __m512 mm512_loadu_elements_ps(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512 a = _mm512_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); - return a; -#else - return _mm512_maskz_loadu_ps(element_mask16(e), mem_addr); -#endif -} - -static __always_inline __m256 mm256_loadu_elements_ps(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256 a = _mm256_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); - return a; -#else - return _mm256_maskz_loadu_ps(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m128 mm_loadu_elements_ps(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128 a = _mm_setzero_ps(); - std::memcpy(&a, mem_addr, e * sizeof(float_t)); - return a; -#else - return _mm_maskz_loadu_ps(element_mask8(e), mem_addr); -#endif -} - -static 
__always_inline __m512d mm512_loadu_elements_pd(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m512d a = _mm512_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); - return a; -#else - return _mm512_maskz_loadu_pd(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m256d mm256_loadu_elements_pd(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m256d a = _mm256_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); - return a; -#else - return _mm256_maskz_loadu_pd(element_mask8(e), mem_addr); -#endif -} - -static __always_inline __m128d mm_loadu_elements_pd(const uint32_t e, const void* mem_addr) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - __m128d a = _mm_setzero_pd(); - std::memcpy(&a, mem_addr, e * sizeof(double_t)); - return a; -#else - return _mm_maskz_loadu_pd(element_mask8(e), mem_addr); -#endif -} - -static __always_inline void mm512_storeu_elements_ps(void* mem_addr, const uint32_t e, const __m512 a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) float_t values[16]; - _mm512_storeu_ps(values, a); - std::memcpy(mem_addr, values, e * sizeof(float_t)); -#else - _mm512_mask_storeu_ps(mem_addr, element_mask16(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_ps(void* mem_addr, const uint32_t e, const __m256 a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) float_t values[8]; - _mm256_storeu_ps(values, a); - std::memcpy(mem_addr, values, e * sizeof(float_t)); -#else - _mm256_mask_storeu_ps(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_ps(void* mem_addr, const uint32_t e, const __m128 a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) float_t values[4]; - _mm_storeu_ps(values, a); - 
std::memcpy(mem_addr, values, e * sizeof(float_t)); -#else - _mm_mask_storeu_ps(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm512_storeu_elements_pd(void* mem_addr, const uint32_t e, const __m512d a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(64) double_t values[8]; - _mm512_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); -#else - _mm512_mask_storeu_pd(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm256_storeu_elements_pd(void* mem_addr, const uint32_t e, const __m256d a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(32) double_t values[4]; - _mm256_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); -#else - _mm256_mask_storeu_pd(mem_addr, element_mask8(e), a); -#endif -} - -static __always_inline void mm_storeu_elements_pd(void* mem_addr, const uint32_t e, const __m128d a) { -#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) - alignas(16) double_t values[2]; - _mm_storeu_pd(values, a); - std::memcpy(mem_addr, values, e * sizeof(double_t)); -#else - _mm_mask_storeu_pd(mem_addr, element_mask8(e), a); -#endif -} -} - -#endif //PERNIX_AVX512_COMPAT_H diff --git a/include/pernix/x86/avx512vbmi/compression.h b/include/pernix/x86/avx512vbmi/compression.h deleted file mode 100644 index bc5a375..0000000 --- a/include/pernix/x86/avx512vbmi/compression.h +++ /dev/null @@ -1,722 +0,0 @@ -#ifndef PERNIX_AVX512VBMI_COMPRESSION_H -#define PERNIX_AVX512VBMI_COMPRESSION_H - -#include -#include -#include -#include - -#include - -namespace pernix { -namespace internal { -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -static __always_inline __m512i mm512_clamp_signed_epi32(__m512i input) { - constexpr int32_t min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); - constexpr int32_t max_value = BIT_WIDTH == 1 ? 
1 : ((1 << (BIT_WIDTH - 1)) - 1); - return _mm512_min_epi32(_mm512_max_epi32(input, _mm512_set1_epi32(min_value)), _mm512_set1_epi32(max_value)); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -static __always_inline __m256i mm256_clamp_signed_epi32_avx512(__m256i input) { - constexpr int32_t min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); - constexpr int32_t max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); - return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), _mm256_set1_epi32(max_value)); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -static __always_inline __m512i mm512_clamp_signed_epi64(__m512i input) { - constexpr int64_t min_value = BIT_WIDTH == 1 ? 0 : -(int64_t{1} << (BIT_WIDTH - 1)); - constexpr int64_t max_value = BIT_WIDTH == 1 ? 1 : ((int64_t{1} << (BIT_WIDTH - 1)) - 1); - return _mm512_min_epi64(_mm512_max_epi64(input, _mm512_set1_epi64(min_value)), _mm512_set1_epi64(max_value)); -} - -/** - * @brief Quantize sixteen float values to 32-bit integers. 
- */ - -static __always_inline __m512i mm512_quantize_ps_epi32(const __m512& input, const __m512& scale) { - const __m512 scaled = _mm512_mul_ps(input, scale); - return _mm512_cvtps_epi32(scaled); -} - -static __always_inline __m512i mm512_quantize_pd_epi64(const __m512d& input, const __m512d& scale) { - const __m512d scaled = _mm512_mul_pd(input, scale); - return _mm512_cvtpd_epi64(scaled); -} - -static __always_inline __m256i mm512_quantize_pd_epi32(const __m512d& input, const __m512d& scale) { - const __m512d scaled = _mm512_mul_pd(input, scale); - return _mm512_cvtpd_epi32(scaled); -} - -static __always_inline __m512i make_m512i_from_2x256(const __m256i a, const __m256i b) { - __m512i result = _mm512_castsi256_si512(a); - result = _mm512_inserti64x4(result, b, 1); - return result; -} - -static __always_inline __m512i make_m512i_from_4x128(const __m128i a, const __m128i b, const __m128i c, const __m128i d) { - __m512i result = _mm512_castsi128_si512(a); - result = _mm512_inserti64x2(result, b, 1); - result = _mm512_inserti64x2(result, c, 2); - result = _mm512_inserti64x2(result, d, 3); - return result; -} - -static __always_inline __m512i make_m512i_from_8x64(const __m128i a, const __m128i b, const __m128i c, const __m128i d, const __m128i e, - const __m128i f, const __m128i g, const __m128i h) { - const __m128i ab = _mm_unpacklo_epi64(a, b); - const __m128i cd = _mm_unpacklo_epi64(c, d); - const __m128i ef = _mm_unpacklo_epi64(e, f); - const __m128i gh = _mm_unpacklo_epi64(g, h); - - __m512i x = _mm512_castsi128_si512(ab); - x = _mm512_inserti32x4(x, cd, 1); - x = _mm512_inserti32x4(x, ef, 2); - x = _mm512_inserti32x4(x, gh, 3); - return x; -} - -static __always_inline __m256i make_m256i_from_2x128(const __m128i a, const __m128i b) { - __m256i result = _mm256_castsi128_si256(a); - result = _mm256_inserti128_si256(result, b, 1); - return result; -} - -static __always_inline __m256i make_m256i_from_4x64(const __m128i a, const __m128i b, const __m128i c, const 
__m128i d) { - const __m128i ab = _mm_unpacklo_epi64(a, b); - const __m128i cd = _mm_unpacklo_epi64(c, d); - - __m256i x = _mm256_castsi128_si256(ab); - x = _mm256_inserti128_si256(x, cd, 1); - return x; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_64 = elements_per_block / 64; - constexpr uint32_t iterations_32 = (elements_per_block % 64) / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - - const __m512 scale_v = _mm512_set1_ps(scale); - - if constexpr (iterations_64 > 0) { -#pragma GCC unroll 8 - for (uint32_t iter = 0; iter < iterations_64; ++iter) { - const __m512 source1 = _mm512_loadu_ps(input); - const __m512 source2 = _mm512_loadu_ps(input + 16); - const __m512 source3 = _mm512_loadu_ps(input + 32); - const __m512 source4 = _mm512_loadu_ps(input + 48); - - const __m512i quantized1 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source2, scale_v)); - const __m512i quantized3 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source3, scale_v)); - const __m512i quantized4 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source4, scale_v)); - - const __m128i converted1 = _mm512_cvtepi32_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi32_epi8(quantized2); - const __m128i converted3 = _mm512_cvtepi32_epi8(quantized3); - const __m128i converted4 = _mm512_cvtepi32_epi8(quantized4); - - const __m512i packed = - m512::mm512_pack_epi8_avx512vbmi_1to8(make_m512i_from_4x128(converted1, converted2, converted3, 
converted4)); - - mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - - input += 64; - output += 8 * BIT_WIDTH; - } - } - - if constexpr (iterations_32 > 0) { - const __m512 source1 = _mm512_loadu_ps(input); - const __m512 source2 = _mm512_loadu_ps(input + 16); - - const __m512i quantized1 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source2, scale_v)); - - const __m128i converted1 = _mm512_cvtepi32_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi32_epi8(quantized2); - - const __m256i packed = m256::mm256_pack_epi8_avx512vbmi_1to8(make_m256i_from_2x128(converted1, converted2)); - mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - - input += 32; - output += 4 * BIT_WIDTH; - } - - if constexpr (iterations_16 > 0) { - const __m512 source = _mm512_loadu_ps(input); - const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); - const __m128i converted = _mm512_cvtepi32_epi8(quantized); - - const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(converted); - mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - - input += 16; - output += 2 * BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m512 source = mm512_loadu_elements_ps(remaining_elements, input); - const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); - const __m128i converted = _mm512_cvtepi32_epi8(quantized); - const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(converted); - - mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / 
BIT_WIDTH; - - constexpr uint32_t iterations_32 = elements_per_block / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - - const __m512 scale_v = _mm512_set1_ps(scale); - const __m256 scale_v256 = _mm256_set1_ps(scale); - - if constexpr (iterations_32 > 0) { -#pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_32; ++iter) { - const __m512 source1 = _mm512_loadu_ps(input); - const __m512 source2 = _mm512_loadu_ps(input + 16); - - const __m512i quantized1 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source2, scale_v)); - - const __m256i converted1 = _mm512_cvtepi32_epi16(quantized1); - const __m256i converted2 = _mm512_cvtepi32_epi16(quantized2); - - const __m512i packed = m512::mm512_pack_epi16_avx512vbmi_9to16(make_m512i_from_2x256(converted1, converted2)); - mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - - input += 32; - output += 4 * BIT_WIDTH; - } - } - - if constexpr (iterations_16 > 0) { - const __m512 source = _mm512_loadu_ps(input); - const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); - const __m256i converted = _mm512_cvtepi32_epi16(quantized); - - const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16(converted); - mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - - input += 16; - output += 2 * BIT_WIDTH; - } - - if constexpr (iterations_8 > 0) { - const __m256 source = _mm256_loadu_ps(input); - const __m256i quantized = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m128i converted = _mm256_cvtepi32_epi16(quantized); - - const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); - 
mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); - const __m256i quantized = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m128i converted = _mm256_cvtepi32_epi16(quantized); - const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); - - mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_16 = elements_per_block / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - - const __m512 scale_v = _mm512_set1_ps(scale); - const __m256 scale_v256 = _mm256_set1_ps(scale); - - if constexpr (iterations_16 > 0) { -#pragma GCC unroll 2 - for (uint32_t i = 0; i < iterations_16; ++i) { - const __m512 source = _mm512_loadu_ps(input); - const __m512i packed_input = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); - - const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(packed_input); - mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; - output += 2 * BIT_WIDTH; - } - } - - if constexpr (iterations_8 > 0) { - const __m256 source = _mm256_loadu_ps(input); - const __m256i packed_input = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); - mm256_storeu_elements_epi8(output, BIT_WIDTH, 
packed); - - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); - const __m256i packed_input = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); - - mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_64 = elements_per_block / 64; - constexpr uint32_t iterations_32 = (elements_per_block % 64) / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - - const __m512d scale_v = _mm512_set1_pd(scale); - - if constexpr (iterations_64 > 0) { -#pragma GCC unroll 8 - for (uint32_t iter = 0; iter < iterations_64; ++iter) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512d source3 = _mm512_loadu_pd(input + 16); - const __m512d source4 = _mm512_loadu_pd(input + 24); - const __m512d source5 = _mm512_loadu_pd(input + 32); - const __m512d source6 = _mm512_loadu_pd(input + 40); - const __m512d source7 = _mm512_loadu_pd(input + 48); - const __m512d source8 = _mm512_loadu_pd(input + 56); - - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - const __m512i quantized3 = 
mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source3, scale_v)); - const __m512i quantized4 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source4, scale_v)); - const __m512i quantized5 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source5, scale_v)); - const __m512i quantized6 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source6, scale_v)); - const __m512i quantized7 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source7, scale_v)); - const __m512i quantized8 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source8, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); - const __m128i converted3 = _mm512_cvtepi64_epi8(quantized3); - const __m128i converted4 = _mm512_cvtepi64_epi8(quantized4); - const __m128i converted5 = _mm512_cvtepi64_epi8(quantized5); - const __m128i converted6 = _mm512_cvtepi64_epi8(quantized6); - const __m128i converted7 = _mm512_cvtepi64_epi8(quantized7); - const __m128i converted8 = _mm512_cvtepi64_epi8(quantized8); - - const __m512i packed = m512::mm512_pack_epi8_avx512vbmi_1to8( - make_m512i_from_8x64(converted1, converted2, converted3, converted4, converted5, converted6, converted7, converted8)); - - mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - - input += 64; - output += 8 * BIT_WIDTH; - } - } - - if constexpr (iterations_32 > 0) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512d source3 = _mm512_loadu_pd(input + 16); - const __m512d source4 = _mm512_loadu_pd(input + 24); - - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - const __m512i quantized3 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source3, scale_v)); - const __m512i quantized4 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source4, 
scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); - const __m128i converted3 = _mm512_cvtepi64_epi8(quantized3); - const __m128i converted4 = _mm512_cvtepi64_epi8(quantized4); - - const __m256i packed = - m256::mm256_pack_epi8_avx512vbmi_1to8(make_m256i_from_4x64(converted1, converted2, converted3, converted4)); - - mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - - input += 32; - output += 4 * BIT_WIDTH; - } - - if constexpr (iterations_16 > 0) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); - - const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(_mm_unpacklo_epi64(converted1, converted2)); - - mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - - input += 16; - output += 2 * BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - constexpr uint32_t source1_elements = remaining_elements > 8 ? 8 : remaining_elements; - constexpr uint32_t source2_elements = remaining_elements > 8 ? remaining_elements - 8 : 0; - - const __m512d source1 = mm512_loadu_elements_pd(source1_elements, input); - const __m512d source2 = source2_elements > 0 ? 
mm512_loadu_elements_pd(source2_elements, input + 8) : _mm512_setzero_pd(); - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); - - const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(_mm_unpacklo_epi64(converted1, converted2)); - - mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_32 = elements_per_block / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - - const __m512d scale_v = _mm512_set1_pd(scale); - - if constexpr (iterations_32 > 0) { -#pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_32; ++iter) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512d source3 = _mm512_loadu_pd(input + 16); - const __m512d source4 = _mm512_loadu_pd(input + 24); - - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - const __m512i quantized3 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source3, scale_v)); - const __m512i quantized4 = 
mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source4, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi16(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi16(quantized2); - const __m128i converted3 = _mm512_cvtepi64_epi16(quantized3); - const __m128i converted4 = _mm512_cvtepi64_epi16(quantized4); - - const __m512i packed = - m512::mm512_pack_epi16_avx512vbmi_9to16(make_m512i_from_4x128(converted1, converted2, converted3, converted4)); - - mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - - input += 32; - output += 4 * BIT_WIDTH; - } - } - - if constexpr (iterations_16 > 0) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - const __m512i quantized1 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source1, scale_v)); - const __m512i quantized2 = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source2, scale_v)); - - const __m128i converted1 = _mm512_cvtepi64_epi16(quantized1); - const __m128i converted2 = _mm512_cvtepi64_epi16(quantized2); - - const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16(make_m256i_from_2x128(converted1, converted2)); - - mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - - input += 16; - output += 2 * BIT_WIDTH; - } - - if constexpr (iterations_8 > 0) { - const __m512d source = _mm512_loadu_pd(input); - const __m512i quantized = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source, scale_v)); - const __m128i converted = _mm512_cvtepi64_epi16(quantized); - - const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); - mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m512d source = mm512_loadu_elements_pd(remaining_elements, input); - const __m512i quantized = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source, scale_v)); - const __m128i converted = _mm512_cvtepi64_epi16(quantized); - - const __m128i packed = 
m128::mm_pack_epi16_avx512vbmi_9to16(converted); - mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_16 = elements_per_block / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - - const __m512d scale_v = _mm512_set1_pd(scale); - - if constexpr (iterations_16 > 0) { -#pragma GCC unroll 2 - for (uint32_t i = 0; i < iterations_16; ++i) { - const __m512d source1 = _mm512_loadu_pd(input); - const __m512d source2 = _mm512_loadu_pd(input + 8); - - const __m256i quantized1 = mm256_clamp_signed_epi32_avx512(mm512_quantize_pd_epi32(source1, scale_v)); - const __m256i quantized2 = mm256_clamp_signed_epi32_avx512(mm512_quantize_pd_epi32(source2, scale_v)); - - const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(make_m512i_from_2x256(quantized1, quantized2)); - mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - - input += 16; - output += 2 * BIT_WIDTH; - } - } - - if constexpr (iterations_8 > 0) { - const __m512d source = _mm512_loadu_pd(input); - const __m256i quantized = mm256_clamp_signed_epi32_avx512(mm512_quantize_pd_epi32(source, scale_v)); - - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(quantized); - mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m512d source = mm512_loadu_elements_pd(remaining_elements, input); - const __m256i quantized = mm256_clamp_signed_epi32_avx512(mm512_quantize_pd_epi32(source, 
scale_v)); - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(quantized); - - mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); - } - - return 0; -} -} // namespace internal - -/** - * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_block_avx512vbmi(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - std::memset(output, 0, BLOCK_SIZE); - - if constexpr (BIT_WIDTH <= 8) { - return internal::mm512_compress_block_avx512vbmi_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH <= 16) { - return internal::mm512_compress_block_avx512vbmi_9to16(input, scale, output); - } else { - return internal::mm512_compress_block_avx512vbmi_17to24(input, scale, output); - } -} - -/** - * @brief Compress a single block of double values using AVX-512 and AVX-512-VBMI instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code. - * - * @note This overload is declared for parity with the float path. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_block_avx512vbmi(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - std::memset(output, 0, BLOCK_SIZE); - - if constexpr (BIT_WIDTH <= 8) { - return internal::mm512_compress_block_avx512vbmi_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH <= 16) { - return internal::mm512_compress_block_avx512vbmi_9to16(input, scale, output); - } else { - return internal::mm512_compress_block_avx512vbmi_17to24(input, scale, output); - } -} - -/** - * @brief Compress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_blocks_avx512vbmi(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - const float_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } - - return 0; -} - -/** - * @brief Compress multiple blocks of double values using AVX-512 and AVX-512-VBMI instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. 
- * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of blocks to compress. - * @return int status code. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_compress_blocks_avx512vbmi(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - const double_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } - - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_compress_block_avx512vbmi(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. 
- */ -int mm512_compress_block_f64_avx512vbmi(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, - uint8_t* __restrict__ output); - -/** - * @brief Compress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_compress_blocks_avx512vbmi(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Compress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. 
- */ -int mm512_compress_blocks_f64_avx512vbmi(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, - uint8_t* __restrict__ output, uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_AVX512VBMI_COMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/decompression.h b/include/pernix/x86/avx512vbmi/decompression.h deleted file mode 100644 index 61280c9..0000000 --- a/include/pernix/x86/avx512vbmi/decompression.h +++ /dev/null @@ -1,672 +0,0 @@ -#ifndef PERNIX_AVX512VBMI_DECOMPRESSION_H -#define PERNIX_AVX512VBMI_DECOMPRESSION_H - -#include -#include -#include -#include - -#include - -namespace pernix { -namespace internal { -/** - * @brief Dequantize sixteen integer values to floats. - */ -[[gnu::always_inline]] inline __m512 mm512_dequantize_epi32(const __m512i& input, const __m512& scale) { - const __m512 converted = _mm512_cvtepi32_ps(input); - return _mm512_mul_ps(converted, scale); -} - -[[gnu::always_inline]] inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m512d& scale) { - const __m512d converted = _mm512_cvtepi64_pd(input); - return _mm512_mul_pd(converted, scale); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - const uint32_t iterations_64 = elements_per_block / 64; - const uint32_t iterations_32 = (elements_per_block % 64) / 32; - const uint32_t iterations_16 = (elements_per_block % 32) / 16; - const uint32_t remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - - const __m512 scale_v = _mm512_set1_ps(scale); - - if constexpr (iterations_64 > 0) { -#pragma GCC unroll 8 - for (uint32_t i = 0; i < iterations_64; ++i) { - const __m512i source = 
mm512_loadu_elements_epi64(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi8_avx512vbmi_1to8(source); - - const __m512i converted1 = _mm512_cvtepi8_epi32(_mm512_castsi512_si128(unpacked)); - const __m512i converted2 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 1)); - const __m512i converted3 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 2)); - const __m512i converted4 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 3)); - - const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); - const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); - const __m512 dequantized3 = mm512_dequantize_epi32(converted3, scale_v); - const __m512 dequantized4 = mm512_dequantize_epi32(converted4, scale_v); - - _mm512_storeu_ps(output, dequantized1); - _mm512_storeu_ps(output + 16, dequantized2); - _mm512_storeu_ps(output + 32, dequantized3); - _mm512_storeu_ps(output + 48, dequantized4); - - output += 64; - input += 8 * BIT_WIDTH; - } - } - - if constexpr (iterations_32 > 0) { - const __m256i source = mm256_loadu_elements_epi32(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi8_avx512vbmi_1to8(source); - - const __m512i converted1 = _mm512_cvtepi8_epi32(_mm256_castsi256_si128(unpacked)); - const __m512i converted2 = _mm512_cvtepi8_epi32(_mm256_extracti128_si256(unpacked, 1)); - - const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); - const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); - - _mm512_storeu_ps(output, dequantized1); - _mm512_storeu_ps(output + 16, dequantized2); - - output += 32; - input += 4 * BIT_WIDTH; - } - - if constexpr (iterations_16 > 0) { - const __m128i source = mm_loadu_elements_epi16(BIT_WIDTH, input); - const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8(source); - - const __m512i converted = _mm512_cvtepi8_epi32(unpacked); - - const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); - - 
_mm512_storeu_ps(output, dequantized); - - output += 16; - input += 2 * BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8(source); - - const __m512i converted = _mm512_cvtepi8_epi32(unpacked); - - const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); - - mm512_storeu_elements_ps(output, remaining_elements, dequantized); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - const uint32_t iterations_64 = elements_per_block / 64; - const uint32_t iterations_32 = (elements_per_block % 64) / 32; - const uint32_t iterations_16 = (elements_per_block % 32) / 16; - const uint32_t remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; - - const __m512d scale_v = _mm512_set1_pd(scale); - - if constexpr (iterations_64 > 0) { -#pragma GCC unroll 8 - for (uint32_t i = 0; i < iterations_64; ++i) { - const __m512i source = mm512_loadu_elements_epi64(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi8_avx512vbmi_1to8(source); - - const __m128i extracted1 = _mm512_castsi512_si128(unpacked); - const __m128i extracted2 = _mm512_extracti64x2_epi64(unpacked, 1); - const __m128i extracted3 = _mm512_extracti64x2_epi64(unpacked, 2); - const __m128i extracted4 = _mm512_extracti64x2_epi64(unpacked, 3); - - const __m512i converted1 = _mm512_cvtepi8_epi64(extracted1); - const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted1, 8)); - const __m512i converted3 = _mm512_cvtepi8_epi64(extracted2); - const __m512i converted4 = 
_mm512_cvtepi8_epi64(_mm_srli_si128(extracted2, 8)); - const __m512i converted5 = _mm512_cvtepi8_epi64(extracted3); - const __m512i converted6 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted3, 8)); - const __m512i converted7 = _mm512_cvtepi8_epi64(extracted4); - const __m512i converted8 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted4, 8)); - - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); - const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); - const __m512d dequantized5 = mm512_dequantize_epi64(converted5, scale_v); - const __m512d dequantized6 = mm512_dequantize_epi64(converted6, scale_v); - const __m512d dequantized7 = mm512_dequantize_epi64(converted7, scale_v); - const __m512d dequantized8 = mm512_dequantize_epi64(converted8, scale_v); - - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - _mm512_storeu_pd(output + 16, dequantized3); - _mm512_storeu_pd(output + 24, dequantized4); - _mm512_storeu_pd(output + 32, dequantized5); - _mm512_storeu_pd(output + 40, dequantized6); - _mm512_storeu_pd(output + 48, dequantized7); - _mm512_storeu_pd(output + 56, dequantized8); - - output += 64; - input += 8 * BIT_WIDTH; - } - - if constexpr (iterations_32 > 0) { - const __m256i source = mm256_loadu_elements_epi32(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi8_avx512vbmi_1to8(source); - - const __m128i extracted1 = _mm256_castsi256_si128(unpacked); - const __m128i extracted2 = _mm256_extracti64x2_epi64(unpacked, 1); - - const __m512i converted1 = _mm512_cvtepi8_epi64(extracted1); - const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted1, 8)); - const __m512i converted3 = _mm512_cvtepi8_epi64(extracted2); - const __m512i converted4 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted2, 8)); - - const 
__m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); - const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); - - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - _mm512_storeu_pd(output + 16, dequantized3); - _mm512_storeu_pd(output + 24, dequantized4); - - output += 32; - input += 4 * BIT_WIDTH; - } - - if constexpr (iterations_16 > 0) { - const __m128i source = mm_loadu_elements_epi16(BIT_WIDTH, input); - const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8(source); - - const __m512i converted1 = _mm512_cvtepi8_epi64(unpacked); - const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(unpacked, 8)); - - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - - output += 16; - input += 2 * BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8(source); - - const __m512i converted1 = _mm512_cvtepi8_epi64(unpacked); - const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(unpacked, 8)); - - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - - mm512_storeu_elements_pd(output, remaining_elements < 8 ? 
remaining_elements : 8, dequantized1); - if constexpr (remaining_elements > 8) { - mm512_storeu_elements_pd(output + 8, remaining_elements - 8, dequantized2); - } - } - } - - return 0; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_32 = elements_per_block / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - - const __m512 scale_v = _mm512_set1_ps(scale); - const __m256 scale_v256 = _mm256_set1_ps(scale); - - if constexpr (iterations_32 > 0) { -#pragma GCC unroll 4 - for (uint32_t i = 0; i < iterations_32; ++i) { - const __m512i source = mm512_loadu_elements_epi32(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi16_avx512vbmi_9to16(source); - - const __m512i converted1 = _mm512_cvtepi16_epi32(_mm512_castsi512_si256(unpacked)); - const __m512i converted2 = _mm512_cvtepi16_epi32(_mm512_extracti32x8_epi32(unpacked, 1)); - - const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); - const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); - - _mm512_storeu_ps(output, dequantized1); - _mm512_storeu_ps(output + 16, dequantized2); - - output += 32; - input += 4 * BIT_WIDTH; - } - } - - if constexpr (iterations_16 > 0) { - const __m256i source = mm256_loadu_elements_epi16(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi16_avx512vbmi_9to16(source); - - const __m512i converted = _mm512_cvtepi16_epi32(unpacked); - const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); - - 
_mm512_storeu_ps(output, dequantized); - - output += 16; - input += 2 * BIT_WIDTH; - } - - if constexpr (iterations_8 > 0) { - const __m128i source = mm_loadu_elements_epi8(BIT_WIDTH, input); - const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16(source); - - const __m256i converted = _mm256_cvtepi16_epi32(unpacked); - const __m256 dequantized = mm256_dequantize_epi32(converted, scale_v256); - - _mm256_storeu_ps(output, dequantized); - - output += 8; - input += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16(source); - - const __m256i converted = _mm256_cvtepi16_epi32(unpacked); - const __m256 dequantized = mm256_dequantize_epi32(converted, scale_v256); - - mm256_storeu_elements_ps(output, remaining_elements, dequantized); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_32 = elements_per_block / 32; - constexpr uint32_t iterations_16 = (elements_per_block % 32) / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; - - const __m512d scale_v = _mm512_set1_pd(scale); - - if constexpr (iterations_32 > 0) { -#pragma GCC unroll 4 - for (uint32_t i = 0; i < iterations_32; ++i) { - const __m512i source = mm512_loadu_elements_epi32(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi16_avx512vbmi_9to16(source); - - const __m512i converted1 = _mm512_cvtepi16_epi64(_mm512_castsi512_si128(unpacked)); - const 
__m512i converted2 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 1)); - const __m512i converted3 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 2)); - const __m512i converted4 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 3)); - - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); - const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); - - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - _mm512_storeu_pd(output + 16, dequantized3); - _mm512_storeu_pd(output + 24, dequantized4); - - output += 32; - input += 4 * BIT_WIDTH; - } - } - - if constexpr (iterations_16 > 0) { - const __m256i source = mm256_loadu_elements_epi16(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi16_avx512vbmi_9to16(source); - - const __m512i converted1 = _mm512_cvtepi16_epi64(_mm256_castsi256_si128(unpacked)); - const __m512i converted2 = _mm512_cvtepi16_epi64(_mm256_extracti64x2_epi64(unpacked, 1)); - - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - - output += 16; - input += 2 * BIT_WIDTH; - } - - if constexpr (iterations_8 > 0) { - const __m128i source = mm_loadu_elements_epi8(BIT_WIDTH, input); - const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16(source); - - const __m512i converted = _mm512_cvtepi16_epi64(unpacked); - - const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); - - _mm512_storeu_pd(output, dequantized); - - output += 8; - input += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, 
remaining_elements), input); - const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16(source); - - const __m512i converted = _mm512_cvtepi16_epi64(unpacked); - - const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); - - mm512_storeu_elements_pd(output, remaining_elements, dequantized); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_16 = elements_per_block / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - - const __m512 scale_v = _mm512_set1_ps(scale); - const __m256 scale_v256 = _mm256_set1_ps(scale); - - if constexpr (iterations_16 > 0) { -#pragma GCC unroll 2 - for (uint32_t i = 0; i < iterations_16; ++i) { - const __m512i source = mm512_loadu_elements_epi16(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi32_avx512vbmi_17to24(source); - - const __m512 dequantized = mm512_dequantize_epi32(unpacked, scale_v); - - _mm512_storeu_ps(output, dequantized); - - output += 16; - input += 2 * BIT_WIDTH; - } - } - - if constexpr (iterations_8 > 0) { - const __m256i source = mm256_loadu_elements_epi8(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24(source); - - const __m256 dequantized = mm256_dequantize_epi32(unpacked, scale_v256); - - _mm256_storeu_ps(output, dequantized); - - output += 8; - input += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m256i source = mm256_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24(source); - - 
const __m256 dequantized = mm256_dequantize_epi32(unpacked, scale_v256); - - mm256_storeu_elements_ps(output, remaining_elements, dequantized); - } - - return 0; -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - - constexpr uint32_t iterations_16 = elements_per_block / 16; - constexpr uint32_t iterations_8 = (elements_per_block % 16) / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; - - const __m512d scale_v = _mm512_set1_pd(scale); - - if constexpr (iterations_16 > 0) { -#pragma GCC unroll 2 - for (uint32_t i = 0; i < iterations_16; ++i) { - const __m512i source = mm512_loadu_elements_epi16(BIT_WIDTH, input); - const __m512i unpacked = m512::mm512_unpack_epi32_avx512vbmi_17to24(source); - - const __m512i converted1 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(unpacked)); - const __m512i converted2 = _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(unpacked, 1)); - - const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); - const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); - - _mm512_storeu_pd(output, dequantized1); - _mm512_storeu_pd(output + 8, dequantized2); - - output += 16; - input += 2 * BIT_WIDTH; - } - } - - if constexpr (iterations_8 > 0) { - const __m256i source = mm256_loadu_elements_epi8(BIT_WIDTH, input); - const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24(source); - - const __m512i converted = _mm512_cvtepi32_epi64(unpacked); - - const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); - - _mm512_storeu_pd(output, dequantized); - - output += 8; - input += BIT_WIDTH; - } - - if constexpr (remaining_elements > 0) { - const __m256i source = 
mm256_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); - const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24(source); - - const __m512i converted = _mm512_cvtepi32_epi64(unpacked); - - const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); - - mm512_storeu_elements_pd(output, remaining_elements, dequantized); - } - - return 0; -} -} // namespace internal - -/** - * @brief Decompress a single 512\-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::mm512_decompress_block_avx512vbmi_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::mm512_decompress_block_avx512vbmi_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::mm512_decompress_block_avx512vbmi_17to24(input, scale, output); - } - return 0; -} - -/** - * @brief Decompress a single block to double values using AVX-512 and AVX-512-VBMI instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed block. 
- * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int mm512_decompress_block_avx512vbmi(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - return internal::mm512_decompress_block_avx512vbmi_1to8(input, scale, output); - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - return internal::mm512_decompress_block_avx512vbmi_9to16(input, scale, output); - } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { - return internal::mm512_decompress_block_avx512vbmi_17to24(input, scale, output); - } - return 0; -} - -/** - * @brief Decompress multiple 512\-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} - -/** - * @brief Decompress multiple blocks to double values using AVX-512 and AVX-512-VBMI instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @param blocks number of blocks to decompress. - * @return int status code. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif -/** - * @brief Decompress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. 
- * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_decompress_block_avx512vbmi(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_decompress_block_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output); - -/** - * @brief Decompress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_decompress_blocks_avx512vbmi(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Decompress multiple 512-bit blocks using AVX-512 and AVX-512-VBMI instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). 
- * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX-512 and AVX-512-VBMI support. - */ -int mm512_decompress_blocks_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output, uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/packing.h b/include/pernix/x86/avx512vbmi/packing.h deleted file mode 100644 index c9f9db9..0000000 --- a/include/pernix/x86/avx512vbmi/packing.h +++ /dev/null @@ -1,327 +0,0 @@ -#ifndef PERNIX_AVX512VBMI_PACKING_H -#define PERNIX_AVX512VBMI_PACKING_H - -#include -#include - -namespace pernix::internal { -namespace m128 { -/** - * @brief Pack 8 16-bit values for bit widths 9 through 16 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = pack_tables_avx512_16; - const __m128i maskv = _mm_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); - const __m128i masked = _mm_and_si128(input, maskv); - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m128i permuted1 = _mm_permutexvar_epi16(tables::get_permute1(), masked); - const __m128i permuted2 = _mm_permutexvar_epi16(tables::get_permute2(), masked); - - const __m128i shifted1 = _mm_sllv_epi16(permuted1, tables::get_shift1()); - const __m128i shifted2 = _mm_srlv_epi16(permuted2, tables::get_shift2()); - - return _mm_or_si128(shifted1, shifted2); - } else { - const auto [mask1, mask2, mask3] = tables::get_permute_masks(); - - const __m128i permuted1 = _mm_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); - const __m128i permuted2 = _mm_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); - const __m128i permuted3 = _mm_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); - - const __m128i shifted1 = _mm_sllv_epi16(permuted1, tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi16(permuted2, tables::get_shift2()); - const __m128i shifted3 = _mm_srlv_epi16(permuted3, tables::get_shift3()); - - return _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); - } - } -} - -/** - * @brief Pack 16 8-bit values for bit widths 1 through 8 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - const __m128i maskv = _mm_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); - const __m128i masked = _mm_and_si128(input, maskv); - - if constexpr (BIT_WIDTH == 1) { - return _mm_set1_epi16(static_cast(_mm_cmpgt_epi8_mask(masked, _mm_setzero_si128()))); - } else if constexpr (BIT_WIDTH == 2) { - const __m128i shifted = _mm_srli_epi16(masked, 6); - const __m128i combined = _mm_or_si128(masked, shifted); - - const __m128i shifted2 = _mm_srli_epi32(combined, 12); - const __m128i combined2 = _mm_or_si128(shifted2, combined); - - return _mm_cvtepi32_epi8(combined2); - } else if constexpr (BIT_WIDTH == 3) { - const __m128i even = _mm_and_si128(masked, _mm_set1_epi16(0x00FF)); - const __m128i odd = _mm_and_si128(masked, _mm_set1_epi16(0xFF00)); - - const __m128i pair6 = _mm_or_si128(even, _mm_srli_epi16(odd, 5)); - const __m128i packed12 = _mm_or_si128(pair6, _mm_srli_epi32(pair6, 10)); - - return m128::mm_pack_epi16_avx512vbmi_9to16<12>(_mm_cvtepi32_epi16(packed12)); - } else if constexpr (BIT_WIDTH == 4) { - const __m128i shifted = _mm_srli_epi16(masked, 4); - const __m128i combined = _mm_or_si128(masked, shifted); - - return _mm_cvtepi16_epi8(combined); - } else { - const __m128i even = _mm_and_si128(masked, _mm_set1_epi16(0x00FF)); - const __m128i odd = _mm_and_si128(masked, _mm_set1_epi16(0xFF00)); - - const __m128i shifted = _mm_or_si128(even, _mm_srli_epi16(odd, 8 - BIT_WIDTH)); - return mm_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); - } - } -} - -/** - * @brief Pack 4 32-bit values for bit widths 17 through 24 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i& input) { - using tables = pack_tables_avx512_24; - - const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); - const __m128i masked = _mm_and_si128(input, maskv); - - const __m128 permuted1 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute1()); - const __m128 permuted2 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute2()); - const __m128 permuted3 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute3()); - - const __m128i shifted1 = _mm_sllv_epi32(_mm_castps_si128(permuted1), tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi32(_mm_castps_si128(permuted2), tables::get_shift2()); - const __m128i shifted3 = _mm_srlv_epi32(_mm_castps_si128(permuted3), tables::get_shift3()); - - return _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); -} -} // namespace m128 - -namespace m256 { -/** - * @brief Pack 16 16-bit values for bit widths 9 through 16 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = pack_tables_avx512_16; - const __m256i maskv = _mm256_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); - const __m256i masked = _mm256_and_si256(input, maskv); - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m256i permuted1 = _mm256_permutexvar_epi16(tables::get_permute1(), masked); - const __m256i permuted2 = _mm256_permutexvar_epi16(tables::get_permute2(), masked); - - const __m256i shifted1 = _mm256_sllv_epi16(permuted1, tables::get_shift1()); - const __m256i shifted2 = _mm256_srlv_epi16(permuted2, tables::get_shift2()); - - return _mm256_or_si256(shifted1, shifted2); - } else { - const auto [mask1, mask2, mask3] = tables::get_permute_masks(); - - const __m256i permuted1 = _mm256_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); - const __m256i permuted2 = _mm256_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); - const __m256i permuted3 = _mm256_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); - - const __m256i shifted1 = _mm256_sllv_epi16(permuted1, tables::get_shift1()); - const __m256i shifted2 = _mm256_sllv_epi16(permuted2, tables::get_shift2()); - const __m256i shifted3 = _mm256_srlv_epi16(permuted3, tables::get_shift3()); - - return _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); - } - } -} - -/** - * @brief Pack 32 8-bit values for bit widths 1 through 8 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - const __m256i maskv = _mm256_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); - const __m256i masked = _mm256_and_si256(input, maskv); - - if constexpr (BIT_WIDTH == 1) { - return _mm256_set1_epi32(static_cast(_mm256_cmpgt_epi8_mask(masked, _mm256_setzero_si256()))); - } else if constexpr (BIT_WIDTH == 2) { - const __m256i shifted = _mm256_srli_epi16(masked, 6); - const __m256i combined = _mm256_or_si256(masked, shifted); - - const __m256i shifted2 = _mm256_srli_epi32(combined, 12); - const __m256i combined2 = _mm256_or_si256(shifted2, combined); - - return _mm256_castsi128_si256(_mm256_cvtepi32_epi8(combined2)); - } else if constexpr (BIT_WIDTH == 3) { - const __m256i even = _mm256_and_si256(masked, _mm256_set1_epi16(0x00FF)); - const __m256i odd = _mm256_and_si256(masked, _mm256_set1_epi16(0xFF00)); - - const __m256i pair6 = _mm256_or_si256(even, _mm256_srli_epi16(odd, 5)); - const __m256i packed12 = _mm256_or_si256(pair6, _mm256_srli_epi32(pair6, 10)); - - return m256::mm256_pack_epi16_avx512vbmi_9to16<12>(_mm256_castsi128_si256(_mm256_cvtepi32_epi16(packed12))); - } else if constexpr (BIT_WIDTH == 4) { - const __m256i shifted = _mm256_srli_epi16(masked, 4); - const __m256i combined = _mm256_or_si256(masked, shifted); - - return _mm256_castsi128_si256(_mm256_cvtepi16_epi8(combined)); - } else { - const __m256i even = _mm256_and_si256(masked, _mm256_set1_epi16(0x00FF)); - const __m256i odd = _mm256_and_si256(masked, _mm256_set1_epi16(0xFF00)); - - const __m256i shifted = _mm256_or_si256(even, _mm256_srli_epi16(odd, 8 - BIT_WIDTH)); - return mm256_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); - } - } -} - -/** - * @brief Pack 8 32-bit values for bit widths 17 through 24 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i& input) { - using tables = pack_tables_avx512_24; - - const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); - const __m256i masked = _mm256_and_si256(input, maskv); - - const __m256i permuted1 = _mm256_permutexvar_epi32(tables::get_permute1(), masked); - const __m256i permuted2 = _mm256_permutexvar_epi32(tables::get_permute2(), masked); - const __m256i permuted3 = _mm256_permutexvar_epi32(tables::get_permute3(), masked); - - const __m256i shifted1 = _mm256_sllv_epi32(permuted1, tables::get_shift1()); - const __m256i shifted2 = _mm256_sllv_epi32(permuted2, tables::get_shift2()); - const __m256i shifted3 = _mm256_srlv_epi32(permuted3, tables::get_shift3()); - - return _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); -} -} // namespace m256 - -namespace m512 { -/** - * @brief Pack 32 16-bit values for bit widths 9 through 16 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = pack_tables_avx512_16; - const __m512i maskv = _mm512_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); - const __m512i masked = _mm512_and_si512(input, maskv); - - if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m512i permuted1 = _mm512_permutexvar_epi16(tables::get_permute1(), masked); - const __m512i permuted2 = _mm512_permutexvar_epi16(tables::get_permute2(), masked); - - const __m512i shifted1 = _mm512_sllv_epi16(permuted1, tables::get_shift1()); - const __m512i shifted2 = _mm512_srlv_epi16(permuted2, tables::get_shift2()); - - return _mm512_or_si512(shifted1, shifted2); - } else { - const auto [mask1, mask2, mask3] = tables::get_permute_masks(); - - const __m512i permuted1 = _mm512_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); - const __m512i permuted2 = _mm512_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); - const __m512i permuted3 = _mm512_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); - - const __m512i shifted1 = _mm512_sllv_epi16(permuted1, tables::get_shift1()); - const __m512i shifted2 = _mm512_sllv_epi16(permuted2, tables::get_shift2()); - const __m512i shifted3 = _mm512_srlv_epi16(permuted3, tables::get_shift3()); - - return _mm512_or_si512(_mm512_or_si512(shifted1, shifted2), shifted3); - } - } -} - -/** - * @brief Pack 64 8-bit values for bit widths 1 through 8 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - const __m512i maskv = _mm512_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); - const __m512i masked = _mm512_and_si512(input, maskv); - - if constexpr (BIT_WIDTH == 1) { - return _mm512_set1_epi64(static_cast(_mm512_cmpgt_epi8_mask(masked, _mm512_setzero_si512()))); - } else if constexpr (BIT_WIDTH == 2) { - const __m512i shifted = _mm512_srli_epi16(masked, 6); - const __m512i combined = _mm512_or_si512(masked, shifted); - - const __m512i shifted2 = _mm512_srli_epi32(combined, 12); - const __m512i combined2 = _mm512_or_si512(shifted2, combined); - - return _mm512_castsi128_si512(_mm512_cvtepi32_epi8(combined2)); - } else if constexpr (BIT_WIDTH == 3) { - const __m512i even = _mm512_and_si512(masked, _mm512_set1_epi16(0x00FF)); - const __m512i odd = _mm512_and_si512(masked, _mm512_set1_epi16(0xFF00)); - - const __m512i pair6 = _mm512_or_si512(even, _mm512_srli_epi16(odd, 5)); - const __m512i packed12 = _mm512_or_si512(pair6, _mm512_srli_epi32(pair6, 10)); - - return _mm512_castsi256_si512(m256::mm256_pack_epi16_avx512vbmi_9to16<12>(_mm512_cvtepi32_epi16(packed12))); - } else if constexpr (BIT_WIDTH == 4) { - const __m512i shifted = _mm512_srli_epi16(masked, 4); - const __m512i combined = _mm512_or_si512(masked, shifted); - - return _mm512_castsi256_si512(_mm512_cvtepi16_epi8(combined)); - } else { - const __m512i even = _mm512_and_si512(masked, _mm512_set1_epi16(0x00FF)); - const __m512i odd = _mm512_and_si512(masked, _mm512_set1_epi16(0xFF00)); - - const __m512i shifted = _mm512_or_si512(even, _mm512_srli_epi16(odd, 8 - BIT_WIDTH)); - return mm512_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); - } - } -} - -/** - * @brief Pack 16 32-bit values for bit widths 17 through 24 using VBMI. 
- */ -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i& input) { - using tables = pack_tables_avx512_24; - - const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); - const __m512i masked = _mm512_and_si512(input, maskv); - - const __m512i permuted1 = _mm512_permutexvar_epi32(tables::get_permute1(), masked); - const __m512i permuted2 = _mm512_permutexvar_epi32(tables::get_permute2(), masked); - const __m512i permuted3 = _mm512_permutexvar_epi32(tables::get_permute3(), masked); - - const __m512i shifted1 = _mm512_sllv_epi32(permuted1, tables::get_shift1()); - const __m512i shifted2 = _mm512_sllv_epi32(permuted2, tables::get_shift2()); - const __m512i shifted3 = _mm512_srlv_epi32(permuted3, tables::get_shift3()); - - return _mm512_or_si512(_mm512_or_si512(shifted1, shifted2), shifted3); -} -} // namespace m512 -} // namespace pernix::internal - -#endif // PERNIX_AVX512VBMI_PACKING_H diff --git a/include/pernix/x86/avx512vbmi/unpacking.h b/include/pernix/x86/avx512vbmi/unpacking.h deleted file mode 100644 index 1cfc29c..0000000 --- a/include/pernix/x86/avx512vbmi/unpacking.h +++ /dev/null @@ -1,500 +0,0 @@ -#ifndef PERNIX_AVX512VBMI_UNPACKING_H -#define PERNIX_AVX512VBMI_UNPACKING_H - -#include -#include - -namespace pernix::internal { -namespace m128 { -constexpr __mmask16 kAlternateByteMask16 = 0xAAAAULL; - -__always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi16(0x00ff); - const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); - const __m128i high_half = _mm_srlv_epi16(a, _mm_srli_epi16(count, 8)); - return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); -} - -__always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i count) { - const __m128i mask = _mm_set1_epi16(0xff00); - const __m128i low_half = 
_mm_sllv_epi16(a, _mm_andnot_si128(mask, count)); - const __m128i high_half = _mm_sllv_epi16(_mm_and_si128(mask, a), _mm_srli_epi16(count, 8)); - return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); -} - -__always_inline static __m128i _mm_slli_epi8(const __m128i a, const int8_t imm8) { - return _mm_sllv_epi8(a, _mm_set1_epi8(imm8)); -} - -__always_inline static __m128i _mm_srli_epi8(const __m128i a, const int imm8) { - const __m128i lo_mask = _mm_set1_epi16(0x00ff); - const __m128i hi_mask = _mm_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m128i lo = _mm_srl_epi16(_mm_and_si128(a, lo_mask), shift); - const __m128i hi = _mm_and_si128(_mm_srl_epi16(a, shift), hi_mask); - - return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); -} - -__always_inline static __m128i _mm_srai_epi8(const __m128i a, const int8_t imm8) { - const __m128i lo_mask = _mm_set1_epi16(0x00ff); - const __m128i hi_mask = _mm_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m128i hi = _mm_and_si128(_mm_sra_epi16(a, shift), hi_mask); - - const __m128i lo_as_hi = _mm_slli_epi16(_mm_and_si128(a, lo_mask), 8); - const __m128i lo = _mm_and_si128(_mm_srli_epi16(_mm_sra_epi16(lo_as_hi, shift), 8), lo_mask); - - return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m128i mm_unpack_epi8_avx512vbmi_1to8(const __m128i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - if constexpr (BIT_WIDTH == 1) { - const auto value = static_cast<__mmask16>(_mm_cvtsi128_si64(input)); - const __m128i source = _mm_movm_epi8(value); - const __m128i unpacked = _mm_abs_epi8(source); - return unpacked; - } else if constexpr (BIT_WIDTH == 2) { - __m128i values_shift0 = input; - __m128i values_shift2 = _mm_srli_epi16(values_shift0, 2); - const __m128i values_shift4 = _mm_srli_epi16(values_shift0, 4); - const __m128i values_shift6 = 
_mm_srli_epi16(values_shift0, 6); - - __m128i interleave_tmp = _mm_unpacklo_epi8(values_shift0, values_shift2); - values_shift0 = _mm_unpackhi_epi8(values_shift0, values_shift2); - values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); - - interleave_tmp = _mm_unpacklo_epi8(values_shift4, values_shift6); - values_shift2 = _mm_unpackhi_epi8(values_shift4, values_shift6); - values_shift2 = _mm_unpacklo_epi64(interleave_tmp, values_shift2); - - interleave_tmp = _mm_unpacklo_epi16(values_shift0, values_shift2); - values_shift0 = _mm_unpackhi_epi16(values_shift0, values_shift2); - values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); - values_shift0 = _mm_shuffle_epi32(values_shift0, 0xD8); - - values_shift0 = _mm_and_si128(values_shift0, _mm_set1_epi16(0x0303)); - - return values_shift0; - } else if constexpr (BIT_WIDTH == 4) { - __m128i values_shift0 = input; - const __m128i values_shift4 = _mm_srli_epi16(values_shift0, 4); - - const __m128i interleave_tmp = _mm_unpacklo_epi8(values_shift0, values_shift4); - values_shift0 = _mm_unpackhi_epi8(values_shift0, values_shift4); - values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); - values_shift0 = _mm_shuffle_epi32(values_shift0, 0xD8); - - values_shift0 = _mm_and_si128(values_shift0, _mm_set1_epi16(0x0F0F)); - - return values_shift0; - } else { - using tables = unpack_tables_avx512_8; - - const __m128i permuted1 = _mm_permutexvar_epi8(tables::get_permute1(), input); - const __m128i permuted2 = _mm_permutexvar_epi8(tables::get_permute2(), input); - - const __m128i shifted1 = _mm_srlv_epi8(permuted1, tables::get_shift1()); - const __m128i shifted2 = _mm_sllv_epi8(permuted2, tables::get_shift2()); - - const __mmask16 spill_mask = _mm_cmpneq_epi8_mask(tables::get_shift2(), _mm_setzero_si128()); - __m128i combined = _mm_or_si128(shifted1, _mm_maskz_mov_epi8(spill_mask, shifted2)); - - constexpr uint32_t shift = 8 - BIT_WIDTH; - combined = _mm_slli_epi8(combined, shift); - if 
(SIGN_VALUES) { - combined = _mm_srai_epi8(combined, shift); - } else { - combined = _mm_srli_epi8(combined, shift); - } - - return combined; - } - } -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m128i mm_unpack_epi16_avx512vbmi_9to16(const __m128i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = unpack_tables_avx512_16; - - const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute1(), input); - - __m128i shifted = _mm_srlv_epi16(permuted, tables::get_shift1()); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m128i permuted2 = _mm_permutexvar_epi8(tables::get_permute2(), input); - const __m128i shifted2 = _mm_sllv_epi16(permuted2, tables::get_shift2()); - shifted = _mm_or_si128(shifted, shifted2); - } - - constexpr uint32_t shift = 16 - BIT_WIDTH; - shifted = _mm_slli_epi16(shifted, shift); - if (SIGN_VALUES) { - shifted = _mm_srai_epi16(shifted, shift); - } else { - shifted = _mm_srli_epi16(shifted, shift); - } - - return shifted; - } -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m128i mm_unpack_epi32_avx512vbmi_17to24(const __m128i& input) { - using tables = unpack_tables_avx512_24; - - const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute(), input); - - constexpr uint32_t shift = 32 - BIT_WIDTH; - __m128i shifted = _mm_sllv_epi32(permuted, tables::get_shift()); - if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { - shifted = _mm_srai_epi32(shifted, shift); - } else { - shifted = _mm_srli_epi32(shifted, shift); - } - - return shifted; -} -} // namespace m128 - -namespace m256 { -constexpr __mmask32 kAlternateByteMask32 = 0xAAAAAAAAULL; - -__always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi16(0x00ff); - const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); - const 
__m256i high_half = _mm256_srlv_epi16(a, _mm256_srli_epi16(count, 8)); - return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); -} - -__always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i count) { - const __m256i mask = _mm256_set1_epi16(0xff00); - const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count)); - const __m256i high_half = _mm256_sllv_epi16(_mm256_and_si256(mask, a), _mm256_srli_epi16(count, 8)); - return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); -} - -__always_inline static __m256i _mm256_slli_epi8(const __m256i a, const int8_t imm8) { - return _mm256_sllv_epi8(a, _mm256_set1_epi8(imm8)); -} - -__always_inline static __m256i _mm256_srli_epi8(const __m256i a, const int8_t imm8) { - const __m256i lo_mask = _mm256_set1_epi16(0x00ff); - const __m256i hi_mask = _mm256_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m256i lo = _mm256_srl_epi16(_mm256_and_si256(a, lo_mask), shift); - const __m256i hi = _mm256_and_si256(_mm256_srl_epi16(a, shift), hi_mask); - - return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); -} - -__always_inline static __m256i _mm256_srai_epi8(const __m256i a, const int8_t imm8) { - const __m256i lo_mask = _mm256_set1_epi16(0x00ff); - const __m256i hi_mask = _mm256_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m256i hi = _mm256_and_si256(_mm256_sra_epi16(a, shift), hi_mask); - - const __m256i lo_as_hi = _mm256_slli_epi16(_mm256_and_si256(a, lo_mask), 8); - const __m256i lo = _mm256_and_si256(_mm256_srli_epi16(_mm256_sra_epi16(lo_as_hi, shift), 8), lo_mask); - - return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m256i mm256_unpack_epi8_avx512vbmi_1to8(const __m256i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - if constexpr (BIT_WIDTH == 1) { - const auto value = 
static_cast<__mmask32>(_mm_cvtsi128_si64(_mm256_castsi256_si128(input))); - const __m256i source = _mm256_movm_epi8(value); - const __m256i unpacked = _mm256_abs_epi8(source); - return unpacked; - } else if constexpr (BIT_WIDTH == 2) { - __m256i values_shift0 = input; - __m256i values_shift2 = _mm256_srli_epi16(values_shift0, 2); - const __m256i values_shift4 = _mm256_srli_epi16(values_shift0, 4); - const __m256i values_shift6 = _mm256_srli_epi16(values_shift0, 6); - - __m256i interleave_tmp = _mm256_unpacklo_epi8(values_shift0, values_shift2); - values_shift0 = _mm256_unpackhi_epi8(values_shift0, values_shift2); - values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00000000); - - interleave_tmp = _mm256_unpacklo_epi8(values_shift4, values_shift6); - values_shift2 = _mm256_unpackhi_epi8(values_shift4, values_shift6); - values_shift2 = _mm256_shuffle_i32x4(interleave_tmp, values_shift2, 0b00000000); - - interleave_tmp = _mm256_unpacklo_epi16(values_shift0, values_shift2); - values_shift0 = _mm256_unpackhi_epi16(values_shift0, values_shift2); - values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00); - values_shift0 = _mm256_shuffle_i32x4(values_shift0, values_shift0, 0b00); - - values_shift0 = _mm256_and_si256(values_shift0, _mm256_set1_epi16(0x0303)); - - return values_shift0; - } else if constexpr (BIT_WIDTH == 4) { - __m256i values_shift0 = input; - const __m256i values_shift4 = _mm256_srli_epi16(values_shift0, 4); - - __m256i interleave_tmp = _mm256_unpacklo_epi8(values_shift0, values_shift4); - values_shift0 = _mm256_unpackhi_epi8(values_shift0, values_shift4); - values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00); - values_shift0 = _mm256_shuffle_i32x4(values_shift0, values_shift0, 0b00); - - values_shift0 = _mm256_and_si256(values_shift0, _mm256_set1_epi16(0x0F0F)); - - return values_shift0; - } else { - using tables = unpack_tables_avx512_8; - - const __m256i permuted1 = 
_mm256_permutexvar_epi8(tables::get_permute1(), input); - const __m256i permuted2 = _mm256_permutexvar_epi8(tables::get_permute2(), input); - - const __m256i shifted1 = _mm256_srlv_epi8(permuted1, tables::get_shift1()); - const __m256i shifted2 = _mm256_sllv_epi8(permuted2, tables::get_shift2()); - - const __mmask32 spill_mask = _mm256_cmpneq_epi8_mask(tables::get_shift2(), _mm256_setzero_si256()); - __m256i combined = _mm256_or_si256(shifted1, _mm256_maskz_mov_epi8(spill_mask, shifted2)); - - constexpr uint32_t shift = 8 - BIT_WIDTH; - combined = _mm256_slli_epi8(combined, shift); - if (SIGN_VALUES) { - combined = _mm256_srai_epi8(combined, shift); - } else { - combined = _mm256_srli_epi8(combined, shift); - } - - return combined; - } - } -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m256i mm256_unpack_epi16_avx512vbmi_9to16(const __m256i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = unpack_tables_avx512_16; - - const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute1(), input); - - __m256i shifted = _mm256_srlv_epi16(permuted, tables::get_shift1()); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m256i permuted2 = _mm256_permutexvar_epi8(tables::get_permute2(), input); - const __m256i shifted2 = _mm256_sllv_epi16(permuted2, tables::get_shift2()); - shifted = _mm256_or_si256(shifted, shifted2); - } - - constexpr uint32_t shift = 16 - BIT_WIDTH; - shifted = _mm256_slli_epi16(shifted, shift); - if (SIGN_VALUES) { - shifted = _mm256_srai_epi16(shifted, shift); - } else { - shifted = _mm256_srli_epi16(shifted, shift); - } - - return shifted; - } -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m256i mm256_unpack_epi32_avx512vbmi_17to24(const __m256i& input) { - using tables = unpack_tables_avx512_24; - - const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute(), input); - - 
constexpr uint32_t shift = 32 - BIT_WIDTH; - __m256i shifted = _mm256_sllv_epi32(permuted, tables::get_shift()); - if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { - shifted = _mm256_srai_epi32(shifted, shift); - } else { - shifted = _mm256_srli_epi32(shifted, shift); - } - - return shifted; -} -} // namespace m256 - -namespace m512 { -constexpr __mmask64 kAlternateByteMask64 = 0xAAAAAAAAAAAAAAAAULL; - -__always_inline static __m512i _mm512_srlv_epi8(const __m512i a, const __m512i count) { - const __m512i mask = _mm512_set1_epi16(0x00ff); - const __m512i low_half = _mm512_srlv_epi16(_mm512_and_si512(mask, a), _mm512_and_si512(mask, count)); - const __m512i high_half = _mm512_srlv_epi16(a, _mm512_srli_epi16(count, 8)); - return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half); -} - -__always_inline static __m512i _mm512_sllv_epi8(const __m512i a, const __m512i count) { - const __m512i mask = _mm512_set1_epi16(0xff00); - const __m512i low_half = _mm512_sllv_epi16(a, _mm512_andnot_si512(mask, count)); - const __m512i high_half = _mm512_sllv_epi16(_mm512_and_si512(mask, a), _mm512_srli_epi16(count, 8)); - return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half); -} - -__always_inline static __m512i _mm512_slli_epi8(const __m512i a, const int8_t imm8) { - return _mm512_sllv_epi8(a, _mm512_set1_epi8(imm8)); -} - -__always_inline static __m512i _mm512_srli_epi8(const __m512i a, const int8_t imm8) { - const __m512i lo_mask = _mm512_set1_epi16(0x00ff); - const __m512i hi_mask = _mm512_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m512i lo = _mm512_srl_epi16(_mm512_and_si512(a, lo_mask), shift); - const __m512i hi = _mm512_and_si512(_mm512_srl_epi16(a, shift), hi_mask); - - return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi); -} - -__always_inline static __m512i _mm512_srai_epi8(const __m512i a, const int8_t imm8) { - const __m512i lo_mask = _mm512_set1_epi16(0x00ff); - const __m512i hi_mask = 
_mm512_set1_epi16(0xff00); - const __m128i shift = _mm_cvtsi32_si128(imm8); - - const __m512i hi = _mm512_and_si512(_mm512_sra_epi16(a, shift), hi_mask); - - const __m512i lo_as_hi = _mm512_slli_epi16(_mm512_and_si512(a, lo_mask), 8); - const __m512i lo = _mm512_and_si512(_mm512_srli_epi16(_mm512_sra_epi16(lo_as_hi, shift), 8), lo_mask); - - return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi); -} - -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline __m512i mm512_unpack_epi8_avx512vbmi_1to8(const __m512i& input) { - if constexpr (BIT_WIDTH == 8) { - return input; - } else { - if constexpr (BIT_WIDTH == 1) { - const auto value = static_cast<__mmask64>(_mm_cvtsi128_si64(_mm512_castsi512_si128(input))); - const __m512i source = _mm512_movm_epi8(value); - const __m512i unpacked = _mm512_abs_epi8(source); - return unpacked; - } else if constexpr (BIT_WIDTH == 2) { - __m512i values_shift0 = input; - __m512i values_shift2 = _mm512_srli_epi16(values_shift0, 2); - const __m512i values_shift4 = _mm512_srli_epi16(values_shift0, 4); - const __m512i values_shift6 = _mm512_srli_epi16(values_shift0, 6); - - __m512i interleave_tmp = _mm512_unpacklo_epi8(values_shift0, values_shift2); - values_shift0 = _mm512_unpackhi_epi8(values_shift0, values_shift2); - values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0b00000000); - - interleave_tmp = _mm512_unpacklo_epi8(values_shift4, values_shift6); - values_shift2 = _mm512_unpackhi_epi8(values_shift4, values_shift6); - values_shift2 = _mm512_shuffle_i32x4(interleave_tmp, values_shift2, 0b00000000); - - interleave_tmp = _mm512_unpacklo_epi16(values_shift0, values_shift2); - values_shift0 = _mm512_unpackhi_epi16(values_shift0, values_shift2); - values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0x88); - values_shift0 = _mm512_shuffle_i32x4(values_shift0, values_shift0, 0xD8); - - values_shift0 = _mm512_and_si512(values_shift0, _mm512_set1_epi16(0x0303)); - - return values_shift0; - 
} else if constexpr (BIT_WIDTH == 4) { - __m512i values_shift0 = input; - const __m512i values_shift4 = _mm512_srli_epi16(values_shift0, 4); - - __m512i interleave_tmp = _mm512_unpacklo_epi8(values_shift0, values_shift4); - values_shift0 = _mm512_unpackhi_epi8(values_shift0, values_shift4); - values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0x44); - values_shift0 = _mm512_shuffle_i32x4(values_shift0, values_shift0, 0xD8); - - values_shift0 = _mm512_and_si512(values_shift0, _mm512_set1_epi16(0x0F0F)); - - return values_shift0; - } else { - using tables = unpack_tables_avx512_8; - - const __m512i permuted1 = _mm512_permutexvar_epi8(tables::get_permute1(), input); - const __m512i permuted2 = _mm512_permutexvar_epi8(tables::get_permute2(), input); - - const __m512i shifted1 = _mm512_srlv_epi8(permuted1, tables::get_shift1()); - const __m512i shifted2 = _mm512_sllv_epi8(permuted2, tables::get_shift2()); - - const __mmask64 spill_mask = _mm512_cmpneq_epi8_mask(tables::get_shift2(), _mm512_setzero_si512()); - __m512i combined = _mm512_or_si512(shifted1, _mm512_maskz_mov_epi8(spill_mask, shifted2)); - - constexpr uint32_t shift = 8 - BIT_WIDTH; - combined = _mm512_slli_epi8(combined, shift); - if (SIGN_VALUES) { - combined = _mm512_srai_epi8(combined, shift); - } else { - combined = _mm512_srli_epi8(combined, shift); - } - - return combined; - } - } -} - -template - requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline __m512i mm512_unpack_epi16_avx512vbmi_9to16(const __m512i& input) { - if constexpr (BIT_WIDTH == 16) { - return input; - } else { - using tables = unpack_tables_avx512_16; - - const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute1(), input); - __m512i shifted = _mm512_srlv_epi16(permuted, tables::get_shift1()); - - if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { - const __m512i permuted2 = _mm512_permutexvar_epi8(tables::get_permute2(), input); - const __m512i shifted2 = 
_mm512_sllv_epi16(permuted2, tables::get_shift2()); - shifted = _mm512_or_si512(shifted, shifted2); - } - - constexpr uint32_t shift = 16 - BIT_WIDTH; - shifted = _mm512_slli_epi16(shifted, shift); - if (SIGN_VALUES) { - shifted = _mm512_srai_epi16(shifted, shift); - } else { - shifted = _mm512_srli_epi16(shifted, shift); - } - - return shifted; - } -} - -template - requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline __m512i mm512_unpack_epi32_avx512vbmi_17to24(const __m512i& input) { - using tables = unpack_tables_avx512_24; - - const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute(), input); - __m512i shifted = _mm512_sllv_epi32(permuted, tables::get_shift()); - - constexpr uint32_t shift = 32 - BIT_WIDTH; - if constexpr (SIGN_VALUES) { - shifted = _mm512_srai_epi32(shifted, shift); - } else { - shifted = _mm512_srli_epi32(shifted, shift); - } - - return shifted; -} -} // namespace m512 -} // namespace pernix::internal - -#endif // PERNIX_AVX512VBMI_UNPACKING_H diff --git a/include/pernix/x86/bmi2/compression.h b/include/pernix/x86/bmi2/compression.h deleted file mode 100644 index ca4b976..0000000 --- a/include/pernix/x86/bmi2/compression.h +++ /dev/null @@ -1,371 +0,0 @@ -#ifndef PERNIX_BMI2_COMPRESSION_H -#define PERNIX_BMI2_COMPRESSION_H - -#include -#include -#include - -#include -#include -#include - -namespace pernix { -namespace internal { -/** - * @brief Build the masks and shift constants used by the BMI2 packers. - * - * @tparam BIT_WIDTH bit width per packed value. - * @return std::tuple mask tuple used by the BMI2 helpers. - */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32) -static constexpr std::tuple pack_avx2_bmi2_constants() { - uint32_t mask = BIT_WIDTH == 32 ? 
std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; - uint64_t pext_mask; - uint16_t shift1 = BIT_WIDTH * 4; - uint16_t shift2 = 64 - shift1; - - if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 8) { - pext_mask = 0x0101010101010101ULL * mask; - } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { - pext_mask = 0x0001000100010001ULL * mask; - } else { - pext_mask = 0x0000000100000001ULL * mask; - } - - return { - mask, - pext_mask, - shift1, - shift2, - }; -} - -/** - * @brief Pack four 32-bit values with BMI2 extract instructions. - * - * @tparam BIT_WIDTH bit width per packed value. - * @param input SIMD register containing four quantized values. - * @return __m128i packed bitstream in the low bytes of the result. - */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32) -static inline auto mm_pack_epi32_bmi2(const __m128i& input) -> __m128i { - const auto [mask, pext_mask, shift1, shift2] = pack_avx2_bmi2_constants(); - - if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 16) { - const __m128i packed = _mm_packs_epi32(input, _mm_setzero_si128()); - const uint64_t value = _pext_u64(_mm_extract_epi64(packed, 0), pext_mask); - - const __m128i result = _mm_set_epi64x(0, value); - return result; - } else { - alignas(16) uint64_t values[2]; - values[0] = _pext_u64(_mm_extract_epi64(input, 0), pext_mask); - - const uint64_t temp_combined = _pext_u64(_mm_extract_epi64(input, 1), pext_mask); - values[1] = temp_combined >> shift2; - values[0] |= (temp_combined << shift1); - - const __m128i result = _mm_set_epi64x(static_cast(values[1]), static_cast(values[0])); - return result; - } -} - -/** - * @brief Pack eight 32-bit values with BMI2 extract instructions. - * - * @tparam BIT_WIDTH bit width per packed value. - * @param input SIMD register containing eight quantized values. - * @return __m256i packed bitstream in the low bytes of the result. 
- */ -template - requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) -static inline auto mm256_pack_epi32_bmi2(const __m256i& input) -> __m256i { - const auto [mask, pext_mask, shift1, shift2] = pack_avx2_bmi2_constants(); - - if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 8) { - const __m256i packed16 = _mm256_packs_epi32(input, _mm256_setzero_si256()); - const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0)); - const __m256i packed8 = _mm256_packs_epi16(permuted, _mm256_setzero_si256()); - const uint64_t value = _pext_u64(_mm256_extract_epi64(packed8, 0), pext_mask); - - const __m256i result = _mm256_setr_epi64x(static_cast(value), 0, 0, 0); - return result; - } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { - const __m256i packed16 = _mm256_packs_epi32(input, _mm256_setzero_si256()); - alignas(16) int64_t values[2] = {}; - values[0] = _pext_u64(_mm256_extract_epi64(packed16, 0), pext_mask); - - const uint64_t temp_combined = _pext_u64(_mm256_extract_epi64(packed16, 2), pext_mask); - values[1] = temp_combined >> shift2; - if constexpr (BIT_WIDTH != 16) { - values[0] |= static_cast(temp_combined << shift1); - } - - const __m256i result = _mm256_setr_epi64x(values[0], values[1], 0, 0); - return result; - } else { - constexpr uint32_t chunk_bits = BIT_WIDTH * 2; // bits extracted per 64-bit lane - static_assert(chunk_bits < 64); - - constexpr uint64_t chunk_mask = (chunk_bits == 64) ? 
~uint64_t{0} : ((uint64_t{1} << chunk_bits) - 1); - - const uint64_t x0 = _pext_u64(_mm256_extract_epi64(input, 0), pext_mask) & chunk_mask; - const uint64_t x1 = _pext_u64(_mm256_extract_epi64(input, 1), pext_mask) & chunk_mask; - const uint64_t x2 = _pext_u64(_mm256_extract_epi64(input, 2), pext_mask) & chunk_mask; - const uint64_t x3 = _pext_u64(_mm256_extract_epi64(input, 3), pext_mask) & chunk_mask; - - uint64_t out0 = 0; - uint64_t out1 = 0; - uint64_t out2 = 0; - - auto append_bits = [&](uint64_t value, uint32_t bit_offset) { - const uint32_t word = bit_offset >> 6; // / 64 - const uint32_t off = bit_offset & 63; // % 64 - - if (word == 0) { - out0 |= value << off; - if (off + chunk_bits > 64) { - out1 |= value >> (64 - off); - } - } else if (word == 1) { - out1 |= value << off; - if (off + chunk_bits > 64) { - out2 |= value >> (64 - off); - } - } else { - out2 |= value << off; - } - }; - - append_bits(x0, 0 * chunk_bits); - append_bits(x1, 1 * chunk_bits); - append_bits(x2, 2 * chunk_bits); - append_bits(x3, 3 * chunk_bits); - - return _mm256_setr_epi64x(static_cast(out0), static_cast(out1), static_cast(out2), 0); - } -} -} // namespace internal - -/** - * @brief Compress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_bmi2(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - - std::memset(output, 0, BLOCK_SIZE); - - const __m256 scale_v = _mm256_set1_ps(scale); -#pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256 source = _mm256_loadu_ps(input); - const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v); - const __m256i packed_input = internal::mm256_clamp_signed_epi32(quantized); - const __m256i packed = internal::mm256_pack_epi32_bmi2(packed_input); - std::memcpy(output, &packed, BIT_WIDTH); - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining) { - std::vector block_values(remaining); -#pragma GCC unroll 8 - for (uint32_t i = 0; i < remaining; i++) { - block_values[i] = - static_cast(internal::clamp_signed_quantized(internal::quantize_ps_epi32(input[i], scale))); - } - - internal::pack_epi32_fallback(block_values, output); - } - - return 0; -} - -/** - * @brief Compress a single block of double values using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_block_bmi2(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - - std::memset(output, 0, BLOCK_SIZE); - - const __m256d scale_v = _mm256_set1_pd(scale); -#pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256d source1 = _mm256_loadu_pd(input); - const __m256d source2 = _mm256_loadu_pd(input + 4); - const __m128i quantized1 = internal::mm256_quantize_pd_epi32(source1, scale_v); - const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v); - __m256i combined = _mm256_castsi128_si256(quantized1); - combined = _mm256_inserti128_si256(combined, quantized2, 1); - const __m256i packed = internal::mm256_pack_epi32_bmi2(internal::mm256_clamp_signed_epi32(combined)); - std::memcpy(output, &packed, BIT_WIDTH); - input += 8; - output += BIT_WIDTH; - } - - if constexpr (remaining) { - std::vector block_values(remaining); -#pragma GCC unroll 8 - for (uint32_t i = 0; i < remaining; i++) { - block_values[i] = - static_cast(internal::clamp_signed_quantized(internal::quantize_pd_epi64(input[i], scale))); - } - - internal::pack_epi32_fallback(block_values, output); - } - return 0; -} - -/** - * @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). 
- * - * @note This function requires AVX2 and BMI2 support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_bmi2(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - const float_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_compress_block_bmi2(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } - - return 0; -} - -/** - * @brief Compress multiple blocks of double values using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_compress_blocks_bmi2(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - const double_t* block_input = input; - uint8_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_compress_block_bmi2(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; - block_output += BLOCK_SIZE; - } - - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Compress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input float values. 
- * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_block_bmi2(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_block_f64_bmi2(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); - -/** - * @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input float values. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_blocks_bmi2(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 24). - * @param input pointer to the start of the input double values. 
- * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where compressed bytes will be stored. - * @param blocks number of 512-bit blocks to compress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_compress_blocks_f64_bmi2(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_BMI2_COMPRESSION_H diff --git a/include/pernix/x86/bmi2/decompression.h b/include/pernix/x86/bmi2/decompression.h deleted file mode 100644 index 443d673..0000000 --- a/include/pernix/x86/bmi2/decompression.h +++ /dev/null @@ -1,411 +0,0 @@ -#ifndef PERNIX_BMI2_DECOMPRESSION_H -#define PERNIX_BMI2_DECOMPRESSION_H - -#include -#include - -#include -#include -#include -#include - -namespace pernix { -namespace internal { -/** - * @brief Sign-extend packed values after BMI2 expansion into 32-bit lanes. - * - * @tparam BIT_WIDTH original encoded bit width. - * @param source register containing unpacked values. - * @return __m128i sign-extended values. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__m128i mm_sign_extend32(__m128i source) { - if constexpr (BIT_WIDTH == 1) { - // Keep 1-bit values as 0/1 to match fallback decoding semantics. - return source; - } - - constexpr uint16_t shift = 32 - BIT_WIDTH; - source = _mm_slli_epi32(source, shift); - return _mm_srai_epi32(source, shift); -} - -/** - * @brief Sign-extend packed values after BMI2 expansion into eight 32-bit lanes. - * - * @tparam BIT_WIDTH original encoded bit width. - * @param source register containing unpacked values. - * @return __m256i sign-extended values. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__m256i mm256_sign_extend32(__m256i source) { - if constexpr (BIT_WIDTH == 1) { - // Keep 1-bit values as 0/1 to match fallback decoding semantics. - return source; - } - - constexpr uint16_t shift = 32 - BIT_WIDTH; - source = _mm256_slli_epi32(source, shift); - return _mm256_srai_epi32(source, shift); -} - -/** - * @brief Unpack four values from a BMI2-packed input buffer. - * - * @tparam BIT_WIDTH bit width per packed value. - * @tparam SIGN_VALUES whether to sign-extend the unpacked values. - * @param input pointer to the packed input buffer. - * @return __m128i unpacked values. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH > 0 && BIT_WIDTH <= 24) -__m128i mm_unpack_epi32_bmi2(const uint8_t* __restrict__ input) { - constexpr uint32_t mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; - constexpr std::size_t packed_bytes = (4 * BIT_WIDTH + 7) / 8; - - __m128i result; - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - constexpr uint64_t pdep_mask = 0x0101010101010101ULL * mask; - - uint32_t temp_value = 0; - std::memcpy(&temp_value, input, packed_bytes); - - const int32_t value = _pdep_u32(temp_value, static_cast(pdep_mask)); - const __m128i source = _mm_insert_epi32(_mm_setzero_si128(), value, 0); - - result = _mm_cvtepi8_epi32(source); - } else if constexpr (BIT_WIDTH == 16) { - const __m128i source = _mm_loadu_si64(input); - result = _mm_cvtepi16_epi32(source); - } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { - constexpr uint64_t pdep_mask = 0x0001000100010001ULL * mask; - - uint64_t temp_value = 0; - std::memcpy(&temp_value, input, packed_bytes); - - const int64_t value = _pdep_u64(temp_value, pdep_mask); - const __m128i source = _mm_insert_epi64(_mm_setzero_si128(), value, 0); - - result = _mm_cvtepi16_epi32(source); - } else { - constexpr uint64_t pdep_mask = 0x0000000100000001ULL * mask; - constexpr uint32_t shift1 = BIT_WIDTH * 2; - constexpr 
uint32_t shift2 = 64 - shift1; - - alignas(16) uint64_t temp_values[2]{}; - std::memcpy(temp_values, input, packed_bytes); - - alignas(16) int64_t values[2]; - values[0] = _pdep_u64(temp_values[0], pdep_mask); - values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); - - result = _mm_set_epi64x(values[1], values[0]); - } - - if constexpr (SIGN_VALUES) { - result = internal::mm_sign_extend32(result); - } - return result; -} - -/** - * @brief Unpack eight values from a BMI2-packed input buffer. - * - * @tparam BIT_WIDTH bit width per packed value. - * @tparam SIGN_VALUES whether to sign-extend the unpacked values. - * @param input pointer to the packed input buffer. - * @return __m256i unpacked values. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) -__m256i mm256_unpack_epi32_bmi2(const uint8_t* __restrict__ input) { - constexpr uint32_t mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; - constexpr std::size_t packed_bytes = BIT_WIDTH; - - __m256i result; - if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { - constexpr uint64_t pdep_mask = 0x0101010101010101ULL * mask; - - uint64_t temp_value = 0; - std::memcpy(&temp_value, input, packed_bytes); - - const int64_t value = _pdep_u64(temp_value, pdep_mask); - const __m128i source = _mm_insert_epi64(_mm_setzero_si128(), value, 0); - - result = _mm256_cvtepi8_epi32(source); - } else if constexpr (BIT_WIDTH == 16) { - const __m128i source = _mm_loadu_si128(reinterpret_cast(input)); - result = _mm256_cvtepi16_epi32(source); - } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { - constexpr uint64_t pdep_mask = 0x0001000100010001ULL * mask; - constexpr uint64_t shift1 = BIT_WIDTH * 4; - constexpr uint64_t shift2 = 64 - shift1; - - alignas(16) uint64_t temp_values[2]{}; - std::memcpy(temp_values, input, packed_bytes); - - alignas(16) int64_t values[2]; - values[0] = _pdep_u64(temp_values[0], pdep_mask); - values[1] = _pdep_u64((temp_values[0] >> 
shift1) | (temp_values[1] << shift2), pdep_mask); - - const __m128i source = _mm_set_epi64x(values[1], values[0]); - result = _mm256_cvtepi16_epi32(source); - } else { - constexpr uint64_t pdep_mask = 0x0000000100000001ULL * mask; - constexpr uint32_t shift1 = BIT_WIDTH * 2; - constexpr uint32_t shift2 = 64 - shift1; - - alignas(16) uint64_t temp_values[4]{}; - std::memcpy(temp_values, input, packed_bytes); - - if constexpr ((BIT_WIDTH % 2) == 0) { - std::memcpy(temp_values + 2, input + BIT_WIDTH / 2, packed_bytes - BIT_WIDTH / 2); - } else { - constexpr uint32_t second_group_bit_offset = BIT_WIDTH * 4; - constexpr uint32_t second_group_byte_offset = second_group_bit_offset / 8; - constexpr uint32_t second_group_shift = second_group_bit_offset % 8; - - alignas(16) uint64_t raw_values[2]{}; - std::memcpy(raw_values, input + second_group_byte_offset, packed_bytes - second_group_byte_offset); - - temp_values[2] = (raw_values[0] >> second_group_shift) | (raw_values[1] << (64 - second_group_shift)); - temp_values[3] = raw_values[1] >> second_group_shift; - } - - alignas(16) uint64_t values[4]; - values[0] = _pdep_u64((temp_values[0]), pdep_mask); - values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); - values[2] = _pdep_u64((temp_values[2]), pdep_mask); - values[3] = _pdep_u64((temp_values[2] >> shift1) | (temp_values[3] << shift2), pdep_mask); - - result = _mm256_set_epi64x(static_cast(values[3]), static_cast(values[2]), static_cast(values[1]), - static_cast(values[0])); - } - - if constexpr (SIGN_VALUES) { - result = internal::mm256_sign_extend32(result); - } - return result; -} -} // namespace internal - -/** - * @brief Decompress a single 512\-bit block using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 16). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). 
- * - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_bmi2(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - - const __m256 scale_v = _mm256_set1_ps(scale); -#pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256i unpacked = internal::mm256_unpack_epi32_bmi2(input); - const __m256 dequantized = internal::mm256_dequantize_epi32(unpacked, scale_v); - _mm256_storeu_ps(output, dequantized); - input += BIT_WIDTH; - output += 8; - } - - if constexpr (remaining > 0) { - const std::vector tail_values = internal::unpack_epi32_fallback(input, remaining); - for (uint32_t i = 0; i < remaining; i++) { - output[i] = internal::dequantize_epi32(tail_values[i], scale); - } - } - - return 0; -} - -/** - * @brief Decompress a single block to double values using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 16). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @return int status code (0 for success). 
- * - * @note This function requires AVX2 and BMI2 support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_block_bmi2(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint8_t remaining = elements_per_block - iterations_8 * 8; - const __m256d scale_v = _mm256_set1_pd(scale); -#pragma GCC unroll 4 - for (uint32_t iter = 0; iter < iterations_8; iter++) { - const __m256i unpacked = internal::mm256_unpack_epi32_bmi2(input); - const __m256i extend1 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(unpacked)); - const __m256i extend2 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(unpacked, 1)); - - const __m256d dequantized1 = internal::mm256_dequantize_epi64_pd(extend1, scale_v); - const __m256d dequantized2 = internal::mm256_dequantize_epi64_pd(extend2, scale_v); - - _mm256_storeu_pd(output, dequantized1); - _mm256_storeu_pd(output + 4, dequantized2); - - input += BIT_WIDTH; - output += 8; - } - - if constexpr (remaining > 0) { - const std::vector tail_values = internal::unpack_epi32_fallback(input, remaining); - for (uint32_t i = 0; i < remaining; i++) { - output[i] = internal::dequantize_epi64(tail_values[i], scale); - } - } - - return 0; -} - -/** - * @brief Decompress multiple 512\-bit blocks using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. 
- * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_bmi2(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_decompress_block_bmi2(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} - -/** - * @brief Decompress multiple blocks to double values using AVX2 and BMI2 instructions. - * - * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). - * @tparam SIGN_VALUES whether the values are signed or unsigned. - * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). - * - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed double values will be stored. - * @param blocks number of blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. 
- */ -template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int mm256_decompress_blocks_bmi2(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; block++) { - mm256_decompress_block_bmi2(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } - - return 0; -} -} // namespace pernix - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -/** - * @brief Decompress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_decompress_block_bmi2(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -/** - * @brief Decompress a single 512-bit block using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the compressed block. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_decompress_block_f64_bmi2(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); - -/** - * @brief Decompress multiple 512-bit blocks using AVX2 and BMI2 instructions. 
- * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. - */ -int mm256_decompress_blocks_bmi2(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - uint32_t blocks); - -/** - * @brief Decompress multiple 512-bit blocks using AVX2 and BMI2 instructions. - * - * @param bit_width bit width per value in the packed representation (1 to 16). - * @param input pointer to the start of the compressed data. - * @param scale scaling factor used during quantization. - * @param output pointer to the output buffer where decompressed float values will be stored. - * @param blocks number of 512-bit blocks to decompress. - * @return int status code (0 for success). - * - * @note This function requires AVX2 and BMI2 support. 
- */ -int mm256_decompress_blocks_f64_bmi2(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); - -#ifdef __cplusplus -} -} // namespace pernix -#endif - -#endif // PERNIX_BMI2_DECOMPRESSION_H diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b7e8f23..aab5c67 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,42 +1,275 @@ include(GNUInstallDirs) +include(CMakePackageConfigHelpers) +add_library(pernix SHARED) +target_sources(pernix + PRIVATE + pernix.cpp + + fallback/fallback_compression.cpp + fallback/fallback_decompression.cpp + + dispatch/select.cpp +) + +target_compile_features(pernix PUBLIC cxx_std_20) + +set_target_properties(pernix PROPERTIES + OUTPUT_NAME "pernix" + VERSION ${NORMALIZED_VERSION} + LINKER_LANGUAGE CXX +) + +target_include_directories(pernix + PUBLIC + $ + $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/internal +) + +target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_LIB=1 +) + +if (PERNIX_USE_SIMDE) + # target_link_libraries(pernix PUBLIC simde::simde) + target_compile_definitions(pernix PUBLIC PERNIX_USE_SIMDE=1) +endif () + +if (PERNIX_TARGET_IS_X86) + target_sources(pernix + PRIVATE + dispatch/cpu_features_x86.cpp + ) + + if (MSVC) + set_source_files_properties( + dispatch/cpu_features_x86.cpp + PROPERTIES + COMPILE_OPTIONS "/arch:AVX512" + ) + else () + set_source_files_properties( + dispatch/cpu_features_x86.cpp + PROPERTIES + COMPILE_OPTIONS "-mavx;-mbmi;-mbmi2;-mavx2;-mavx512f;-mavx512bw;-mavx512vl;-mavx512dq;-mavx512cd;-mavx512vbmi" + ) + endif () +endif () + +if (PERNIX_TARGET_IS_ARM64) + target_sources(pernix + PRIVATE + dispatch/cpu_features_arm.cpp + ) + + set_target_properties(pernix PROPERTIES + COMPILE_OPTIONS "-march=armv8-a+simd+sve2" + ) +endif () + +if (PERNIX_ENABLE_X86_BMI2 AND PERNIX_TARGET_IS_X86) + target_sources(pernix + PRIVATE + x86/bmi2/bmi2_compression.cpp + x86/bmi2/bmi2_decompression.cpp + ) + + 
target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_X86_BMI2=1 + ) + + if (MSVC) + # BMI2 is not exposed with MSVC + else () + set_source_files_properties( + x86/bmi2/bmi2_compression.cpp + x86/bmi2/bmi2_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-mbmi;-mbmi2;-mavx2" + ) + endif () + +endif () + +if (PERNIX_ENABLE_X86_AVX2 AND PERNIX_TARGET_IS_X86) + target_sources(pernix + PRIVATE + x86/avx2/avx2_compression.cpp + x86/avx2/avx2_decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_X86_AVX2=1 + ) + + if (MSVC) + set_source_files_properties( + x86/avx2/avx2_compression.cpp + x86/avx2/avx2_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "/arch:AVX2" + ) + else () + set_source_files_properties( + x86/avx2/avx2_compression.cpp + x86/avx2/avx2_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-mavx2" + ) + endif () +endif () + +if (PERNIX_ENABLE_X86_AVX512VBMI AND PERNIX_TARGET_IS_X86) + target_sources(pernix + PRIVATE + x86/avx512vbmi/avx512vbmi_compression.cpp + x86/avx512vbmi/avx512vbmi_decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_X86_AVX512_VBMI=1 + ) + + if (MSVC) + set_source_files_properties( + x86/avx512vbmi/avx512vbmi_compression.cpp + x86/avx512vbmi/avx512vbmi_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "/arch:AVX512" + ) + else () + set_source_files_properties( + x86/avx512vbmi/avx512vbmi_compression.cpp + x86/avx512vbmi/avx512vbmi_decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-mavx512f;-mavx512bw;-mavx512vl;-mavx512dq;-mavx512cd;-mavx512vbmi" + ) + endif () +endif () + +if (PERNIX_ENABLE_ARM64_NEON AND PERNIX_TARGET_IS_ARM64) + target_sources(pernix + PRIVATE + arm64/neon/compression.cpp + arm64/neon/decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_ARM64_NEON=1 + ) + + set_source_files_properties( + arm64/neon/compression.cpp + arm64/neon/decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-march=armv8-a+simd" + ) +endif () + +if 
(PERNIX_ENABLE_ARM64_SVE2 AND PERNIX_TARGET_IS_ARM64) + target_sources(pernix + PRIVATE + arm64/sve2/compression.cpp + arm64/sve2/decompression.cpp + ) + + target_compile_definitions(pernix + PRIVATE + PERNIX_BUILD_ARM64_SVE2=1 + ) + + set_source_files_properties( + arm64/sve2/compression.cpp + arm64/sve2/decompression.cpp + PROPERTIES + COMPILE_OPTIONS "-march=armv8.2-a+simd+sve2" + ) +endif () + +#[===[ file(GLOB_RECURSE PERNIX_COMMON_SOURCES ./fallback/*.cpp - ./pernix.cpp - ${PROJECT_SOURCE_DIR}/include/pernix/*.h + pernix.cpp + dispatch/select.cpp ) set(PERNIX_SOURCES ${PERNIX_COMMON_SOURCES}) -set(PERNIX_TARGET_IS_X86 OFF) -if (PERNIX_USE_SIMDE OR CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") - set(PERNIX_TARGET_IS_X86 ON) -endif () - -if (PERNIX_TARGET_IS_X86) +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") file(GLOB_RECURSE PERNIX_X86_SOURCES + CONFIGURE_DEPENDS ./x86/*.cpp - ${PROJECT_SOURCE_DIR}/include/pernix/x86/*.h ) list(APPEND PERNIX_SOURCES ${PERNIX_X86_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") + file(GLOB_RECURSE + PERNIX_ARM64_NEON_SOURCES + CONFIGURE_DEPENDS + ./arm64/neon/*.cpp + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_NEON_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") + file(GLOB_RECURSE + PERNIX_ARM64_SVE_SOURCES + CONFIGURE_DEPENDS + ./arm64/sve/*.cpp + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") + file(GLOB_RECURSE + PERNIX_ARM64_SVE2_SOURCES + CONFIGURE_DEPENDS + ./arm64/sve2/*.cpp + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE2_SOURCES}) endif () + add_library(pernix SHARED ${PERNIX_SOURCES}) +add_library(pernix::pernix ALIAS pernix) set_target_properties(pernix PROPERTIES OUTPUT_NAME "pernix" VERSION ${NORMALIZED_VERSION} ) -target_include_directories(pernix PUBLIC +target_compile_features(pernix PUBLIC cxx_std_20) +target_compile_options(pernix PRIVATE 
${PERNIX_PRIVATE_COMPILE_OPTIONS}) +target_include_directories(pernix + PUBLIC $ + $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/internal ) +if (PERNIX_ENABLE_LTO) + set_target_properties(pernix PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE) +endif () if (PERNIX_USE_SIMDE) target_link_libraries(pernix PUBLIC simde::simde) target_compile_definitions(pernix PUBLIC PERNIX_USE_SIMDE=1) endif () +target_compile_definitions(pernix PUBLIC "PERNIX_BACKEND_${PERNIX_SELECTED_ARCH_BACKEND}=1") + +if (PERNIX_DISABLE_BMI2) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_BMI2=1) +endif () +if (PERNIX_DISABLE_AVX2) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_AVX2=1) +endif () +if (PERNIX_DISABLE_AVX512) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_AVX512=1) +endif () + set_target_properties(pernix PROPERTIES LINKER_LANGUAGE CXX) configure_file( @@ -45,23 +278,50 @@ configure_file( if (PERNIX_ENABLE_INSTALL) install(TARGETS pernix + EXPORT pernixTargets LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) - install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/pernix ${PROJECT_BINARY_DIR}/include/pernix + install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/pernix" DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} FILES_MATCHING PATTERN "*.*h" ) + if (PERNIX_USE_SIMDE AND PERNIX_BUNDLE_SIMDE_FOR_INSTALL) + install(DIRECTORY "${simde_SOURCE_DIR}/simde" + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + FILES_MATCHING PATTERN "*.h" + ) + endif () + + configure_package_config_file( + "${PROJECT_SOURCE_DIR}/cmake/pernixConfig.cmake.in" + "${PROJECT_BINARY_DIR}/pernixConfig.cmake" + INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) + write_basic_package_version_file( + "${PROJECT_BINARY_DIR}/pernixConfigVersion.cmake" + VERSION ${NORMALIZED_VERSION} + COMPATIBILITY SameMajorVersion + ) + install( + FILES + 
"${PROJECT_BINARY_DIR}/pernixConfig.cmake" + "${PROJECT_BINARY_DIR}/pernixConfigVersion.cmake" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) + install( + EXPORT pernixTargets + FILE pernixTargets.cmake + NAMESPACE pernix:: + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) install( FILES ${PROJECT_BINARY_DIR}/pernix.pc DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig" ) - - add_custom_target(uninstall COMMAND xargs rm -vf < ${PROJECT_BINARY_DIR}/install_manifest.txt) endif () if (PERNIX_ENABLE_DOXYGEN) @@ -87,8 +347,10 @@ if (PERNIX_ENABLE_DOXYGEN) WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} COMMENT "Generating documentation with Doxygen." ) - if (BENCHMARK_ENABLE_INSTALL AND BENCHMARK_INSTALL_DOCS) + if (PERNIX_ENABLE_INSTALL AND PERNIX_INSTALL_DOCS) install(DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/html/" DESTINATION ${CMAKE_INSTALL_DOCDIR}) endif () endif () + +]===] \ No newline at end of file diff --git a/src/arm64/neon/compression.cpp b/src/arm64/neon/compression.cpp new file mode 100644 index 0000000..3594303 --- /dev/null +++ b/src/arm64/neon/compression.cpp @@ -0,0 +1,28 @@ +#include +#include + +namespace pernix::internal { + Kernel select_neon_compress_block_f32(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"neon", nullptr}; + } + + Kernel select_neon_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"neon", nullptr}; + } + + Kernel select_neon_compress_block_f64(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"neon", nullptr}; + } + + Kernel select_neon_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"neon", nullptr}; + } +} diff --git a/src/arm64/neon/decompression.cpp b/src/arm64/neon/decompression.cpp new file mode 100644 index 0000000..3083926 --- /dev/null +++ b/src/arm64/neon/decompression.cpp @@ -0,0 +1,208 @@ 
+#include +#include + +using pernix::arm64::neon::neon_decompress_block; +using pernix::arm64::neon::neon_decompress_blocks; + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("neon", &neon_decompress_block); \ + return Kernel("neon", &neon_decompress_block) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("neon", &neon_decompress_blocks); \ + return Kernel("neon", &neon_decompress_blocks) + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("neon", &neon_decompress_block); \ + return Kernel("neon", &neon_decompress_block) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("neon", &neon_decompress_blocks); \ + return Kernel("neon", &neon_decompress_blocks) + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"neon", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"neon", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"neon", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"neon", nullptr}; \ + } \ + break + +Kernel select_neon_decompress_block_f32(const u8 bit_width, const u32 block_size, bool sign_values) { + switch (block_size) { + 
PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: + return {"neon", nullptr}; + } +} + +Kernel select_neon_decompress_blocks_f32(const u8 bit_width, const u32 block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: + return {"neon", nullptr}; + } +} + +Kernel select_neon_decompress_block_f64(const u8 bit_width, const u32 block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: + return {"neon", nullptr}; + } +} + +Kernel select_neon_decompress_blocks_f64(const u8 bit_width, const u32 block_size, bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: + return {"neon", nullptr}; + } +} + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/arm64/sve2/compression.cpp 
b/src/arm64/sve2/compression.cpp new file mode 100644 index 0000000..1839f12 --- /dev/null +++ b/src/arm64/sve2/compression.cpp @@ -0,0 +1,28 @@ +#include +#include + +namespace pernix::internal { + Kernel select_sve2_compress_block_f32(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"sve2", nullptr}; + } + + Kernel select_sve2_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"sve2", nullptr}; + } + + Kernel select_sve2_compress_block_f64(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"sve2", nullptr}; + } + + Kernel select_sve2_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + (void) bit_width; + (void) block_size; + return {"sve2", nullptr}; + } +} diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp new file mode 100644 index 0000000..8e65f6e --- /dev/null +++ b/src/arm64/sve2/decompression.cpp @@ -0,0 +1,208 @@ +#include +#include + +using pernix::arm64::sve2::sve2_decompress_block; +using pernix::arm64::sve2::sve2_decompress_blocks; + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("sve2", &sve2_decompress_block); \ + return Kernel("sve2", &sve2_decompress_block) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("sve2", &sve2_decompress_blocks); \ + return Kernel("sve2", &sve2_decompress_blocks) + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("sve2", &sve2_decompress_block); \ + return Kernel("sve2", &sve2_decompress_block) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("sve2", &sve2_decompress_blocks); \ + return Kernel("sve2", &sve2_decompress_blocks) + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) 
{ \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"sve2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); 
\ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"sve2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"sve2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"sve2", nullptr}; \ + } \ + break + + Kernel select_sve2_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"sve2", nullptr}; + } + } + + Kernel select_sve2_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"sve2", nullptr}; + } + } + + Kernel select_sve2_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + 
PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"sve2", nullptr}; + } + } + + Kernel select_sve2_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"sve2", nullptr}; + } + } + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/dispatch/cpu_features_arm.cpp b/src/dispatch/cpu_features_arm.cpp new file mode 100644 index 0000000..525962b --- /dev/null +++ b/src/dispatch/cpu_features_arm.cpp @@ -0,0 +1,34 @@ +#include + +#if defined(__linux__) && (defined(__aarch64__) || defined(_M_ARM64)) +#include +#ifndef HWCAP_SVE +#define HWCAP_SVE (1 << 22) +#endif +#ifndef HWCAP2_SVE2 +#define HWCAP2_SVE2 (1 << 1) +#endif +#endif + +namespace pernix::internal { +CpuFeatures detect_cpu_features() { + CpuFeatures features{}; + + // neon +#if defined(__aarch64__) || defined(_M_ARM64) + features.neon = true; +#elif defined(__ARM_NEON) || defined(__ARM_NEON__) + features.neon = true; +#endif + + // sve / sve2 — runtime detection via getauxval on Linux +#if defined(__linux__) && (defined(__aarch64__) || defined(_M_ARM64)) + unsigned long hwcap = getauxval(AT_HWCAP); + unsigned long hwcap2 = getauxval(AT_HWCAP2); + features.sve = (hwcap & HWCAP_SVE) != 0; + features.sve2 = (hwcap2 & HWCAP2_SVE2) != 0; +#endif 
+
+    return features;
+}
+}
\ No newline at end of file
diff --git a/src/dispatch/cpu_features_x86.cpp b/src/dispatch/cpu_features_x86.cpp
new file mode 100644
index 0000000..973df4d
--- /dev/null
+++ b/src/dispatch/cpu_features_x86.cpp
@@ -0,0 +1,95 @@
+#include
+
+#include
+
+#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86)
+#if defined(_MSC_VER)
+#include
+#else
+#include
+#include
+#endif
+
+
+namespace pernix::internal {
+namespace {
+#if defined(_MSC_VER)
+
+void cpuid(int out[4], int leaf, int subleaf) {
+    __cpuidex(out, leaf, subleaf);
+}
+
+u64 xgetbv(unsigned int index) {
+    return _xgetbv(index);
+}
+
+#else
+
+void cpuid(int out[4], int leaf, int subleaf) {
+    __cpuid_count(leaf, subleaf, out[0], out[1], out[2], out[3]);
+}
+
+u64 xgetbv(unsigned int index) {
+    return _xgetbv(index);
+}
+
+#endif
+// True when bit `bit` (0..31) of `value` is set; shifts in unsigned so bit 31 never touches the sign bit.
+bool bit_set(int value, int bit) {
+    return ((static_cast<unsigned int>(value) >> bit) & 1u) != 0;
+}
+} // namespace
+// Reports the SIMD extensions that both the CPU (CPUID) and the OS (XCR0 state bits) have enabled.
+CpuFeatures detect_cpu_features() {
+    CpuFeatures features{};
+
+    int regs[4]{};
+    // Leaf 1: regs[2] is ECX; bit 27 = OSXSAVE, bit 28 = AVX.
+    cpuid(regs, 1, 0);
+
+    const bool osxsave = bit_set(regs[2], 27);
+    const bool avx = bit_set(regs[2], 28);
+    // XGETBV below needs OSXSAVE; without AVX every flag (BMI2 included) is left false.
+    if (!osxsave || !avx) {
+        return features;
+    }
+
+    const u64 xcr0 = xgetbv(0);
+
+    const bool xmm_enabled = (xcr0 & 0x2) != 0;
+    const bool ymm_enabled = (xcr0 & 0x4) != 0;
+    const bool zmm_enabled =
+        (xcr0 & 0x20) != 0 &&
+        (xcr0 & 0x40) != 0 &&
+        (xcr0 & 0x80) != 0;
+
+    if (!xmm_enabled || !ymm_enabled) {
+        return features;
+    }
+    // Leaf 7, subleaf 0: regs[1] is EBX, regs[2] is ECX.
+    cpuid(regs, 7, 0);
+
+    features.avx2 = bit_set(regs[1], 5);
+    features.bmi2 = bit_set(regs[1], 8);
+    // AVX-512 flags are only reported when XCR0 bits 5, 6 and 7 (opmask and ZMM state) are all set.
+    if (zmm_enabled) {
+        features.avx512f = bit_set(regs[1], 16);
+        features.avx512dq = bit_set(regs[1], 17);  // EBX bit 17; bit 29 is SHA, not AVX512DQ
+        features.avx512bw = bit_set(regs[1], 30);
+        features.avx512vl = bit_set(regs[1], 31);
+        features.avx512vbmi = bit_set(regs[2], 1);
+    }
+
+    return features;
+}
+} // namespace pernix::internal
+#else
+
+namespace pernix::internal {
+CpuFeatures detect_cpu_features() {
+    return {};
+}
+} // namespace
pernix::internal + +#endif + diff --git a/src/dispatch/select.cpp b/src/dispatch/select.cpp new file mode 100644 index 0000000..ca4673e --- /dev/null +++ b/src/dispatch/select.cpp @@ -0,0 +1,720 @@ +#include +#include + +namespace pernix::internal { + Kernel select_compress_block_f32(Backend backend, u8 bit_width, u32 block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_block_f32(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_block_f32(bit_width, block_size); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_block_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_compress_block_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_compress_block_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_compress_block_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_compress_block_f32(bit_width, block_size); + } + return {"invalid_backend", nullptr}; + } +#endif + default: + return {"invalid_backend", nullptr}; + } + } + + Kernel select_compress_blocks_f32(Backend backend, u8 bit_width, u32 block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_blocks_f32(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_blocks_f32(bit_width, block_size); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_blocks_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_compress_blocks_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case 
Backend::X86Bmi2: + return select_bmi2_compress_blocks_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_compress_blocks_f32(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_compress_blocks_f32(bit_width, block_size); + } + return {"invalid_backend", nullptr}; + } +#endif + default: + return {"invalid_backend", nullptr}; + } + } + + Kernel select_compress_block_f64(Backend backend, u8 bit_width, u32 block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_block_f64(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_block_f64(bit_width, block_size); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_block_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_compress_block_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_compress_block_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_compress_block_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_compress_block_f64(bit_width, block_size); + } + return {"invalid_backend", nullptr}; + } +#endif + default: + return {"invalid_backend", nullptr}; + } + } + + Kernel select_compress_blocks_f64(Backend backend, u8 bit_width, u32 block_size) { + switch (backend) { + case Backend::Auto: + return select_auto_compress_blocks_f64(bit_width, block_size); + case Backend::Fallback: + return select_fallback_compress_blocks_f64(bit_width, block_size); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case 
Backend::X86Avx512Vbmi: + return select_avx512vbmi_compress_blocks_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_compress_blocks_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_compress_blocks_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_compress_blocks_f64(bit_width, block_size); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_compress_blocks_f64(bit_width, block_size); + } + return {"invalid_backend", nullptr}; + } +#endif + default: + return {"invalid_backend", nullptr}; + } + } + + Kernel select_decompress_block_f32(Backend backend, u8 bit_width, u32 block_size, + bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_block_f32(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_block_f32(bit_width, block_size, sign_values); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_block_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_decompress_block_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_decompress_block_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_decompress_block_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_decompress_block_f32(bit_width, block_size, sign_values); + } + return {"invalid_backend", nullptr}; + } +#endif + default: + 
return {"invalid_backend", nullptr}; + } + } + + Kernel select_decompress_blocks_f32(Backend backend, u8 bit_width, u32 block_size, + bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_blocks_f32(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_blocks_f32(bit_width, block_size, sign_values); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_decompress_blocks_f32(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_decompress_blocks_f32(bit_width, block_size, sign_values); + } + return {"invalid_backend", nullptr}; + } +#endif + default: + return {"invalid_backend", nullptr}; + } + } + + Kernel select_decompress_block_f64(Backend backend, u8 bit_width, u32 block_size, + bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_block_f64(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_block_f64(bit_width, block_size, sign_values); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_block_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_decompress_block_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + 
case Backend::X86Bmi2: + return select_bmi2_decompress_block_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_decompress_block_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_decompress_block_f64(bit_width, block_size, sign_values); + } + return {"invalid_backend", nullptr}; + } +#endif + default: + return {"invalid_backend", nullptr}; + } + } + + Kernel select_decompress_blocks_f64(Backend backend, u8 bit_width, u32 block_size, + bool sign_values) { + switch (backend) { + case Backend::Auto: + return select_auto_decompress_blocks_f64(bit_width, block_size, sign_values); + case Backend::Fallback: + return select_fallback_decompress_blocks_f64(bit_width, block_size, sign_values); +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + case Backend::X86Avx512Vbmi: + return select_avx512vbmi_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_AVX2) + case Backend::X86Avx2: + return select_avx2_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_X86_BMI2) + case Backend::X86Bmi2: + return select_bmi2_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_NEON) + case Backend::Arm64Neon: + return select_neon_decompress_blocks_f64(bit_width, block_size, sign_values); +#endif +#if defined(PERNIX_BUILD_ARM64_SVE2) + case Backend::Arm64Sve: { + if (get_cached_cpu_features().sve2) { + return select_sve2_decompress_blocks_f64(bit_width, block_size, sign_values); + } + return {"invalid_backend", nullptr}; + } +#endif + default: + return {"invalid_backend", nullptr}; + } + } + + Kernel select_auto_compress_block_f32(u8 bit_width, u32 block_size) { +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || 
defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2)
+        const CpuFeatures features = get_cached_cpu_features();
+#endif
+        // First built backend whose CPU feature is present and that yields a kernel wins; else fallback.
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI)
+        if (
+            features.avx512f &&
+            features.avx512dq &&
+            features.avx512bw &&
+            features.avx512vl &&
+            features.avx512vbmi
+        ) {
+            if (auto kernel = select_avx512vbmi_compress_block_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_AVX2)
+        if (features.avx2) {
+            if (auto kernel = select_avx2_compress_block_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_BMI2)
+        if (features.bmi2) {
+            if (auto kernel = select_bmi2_compress_block_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_SVE2) // probed ahead of NEON, mirroring the widest-first x86 order
+        if (features.sve2) {
+            if (auto kernel = select_sve2_compress_block_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_NEON)
+        if (features.neon) {
+            if (auto kernel = select_neon_compress_block_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+        return select_fallback_compress_block_f32(bit_width, block_size);
+    }
+
+    Kernel select_auto_compress_blocks_f32(u8 bit_width, u32 block_size) {
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2)
+        const CpuFeatures features = get_cached_cpu_features();
+#endif
+        // First built backend whose CPU feature is present and that yields a kernel wins; else fallback.
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI)
+        if (
+            features.avx512f &&
+            features.avx512dq &&
+            features.avx512bw &&
+            features.avx512vl &&
+            features.avx512vbmi
+        ) {
+            if (auto kernel = select_avx512vbmi_compress_blocks_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_AVX2)
+        if (features.avx2) {
+            if (auto kernel = select_avx2_compress_blocks_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_BMI2)
+        if (features.bmi2) {
+            if (auto kernel = select_bmi2_compress_blocks_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_SVE2) // probed ahead of NEON, mirroring the widest-first x86 order
+        if (features.sve2) {
+            if (auto kernel = select_sve2_compress_blocks_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_NEON)
+        if (features.neon) {
+            if (auto kernel = select_neon_compress_blocks_f32(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+        return select_fallback_compress_blocks_f32(bit_width, block_size);
+    }
+
+    Kernel select_auto_compress_block_f64(u8 bit_width, u32 block_size) {
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2)
+        const CpuFeatures features = get_cached_cpu_features();
+#endif
+        // First built backend whose CPU feature is present and that yields a kernel wins; else fallback.
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI)
+        if (
+            features.avx512f &&
+            features.avx512dq &&
+            features.avx512bw &&
+            features.avx512vl &&
+            features.avx512vbmi
+        ) {
+            if (auto kernel = select_avx512vbmi_compress_block_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_AVX2)
+        if (features.avx2) {
+            if (auto kernel = select_avx2_compress_block_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_BMI2)
+        if (features.bmi2) {
+            if (auto kernel = select_bmi2_compress_block_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_SVE2) // probed ahead of NEON, mirroring the widest-first x86 order
+        if (features.sve2) {
+            if (auto kernel = select_sve2_compress_block_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_NEON)
+        if (features.neon) {
+            if (auto kernel = select_neon_compress_block_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+        return select_fallback_compress_block_f64(bit_width, block_size);
+    }
+
+    Kernel select_auto_compress_blocks_f64(u8 bit_width, u32 block_size) {
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2)
+        const CpuFeatures features = get_cached_cpu_features();
+#endif
+        // First built backend whose CPU feature is present and that yields a kernel wins; else fallback.
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI)
+        if (
+            features.avx512f &&
+            features.avx512dq &&
+            features.avx512bw &&
+            features.avx512vl &&
+            features.avx512vbmi
+        ) {
+            if (auto kernel = select_avx512vbmi_compress_blocks_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_AVX2)
+        if (features.avx2) {
+            if (auto kernel = select_avx2_compress_blocks_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_BMI2)
+        if (features.bmi2) {
+            if (auto kernel = select_bmi2_compress_blocks_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_SVE2) // probed ahead of NEON, mirroring the widest-first x86 order
+        if (features.sve2) {
+            if (auto kernel = select_sve2_compress_blocks_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_NEON)
+        if (features.neon) {
+            if (auto kernel = select_neon_compress_blocks_f64(bit_width, block_size)) {
+                return kernel;
+            }
+        }
+#endif
+
+        return select_fallback_compress_blocks_f64(bit_width, block_size);
+    }
+
+    Kernel select_auto_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values) {
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2)
+        const CpuFeatures features = get_cached_cpu_features();
+#endif
+        // First built backend whose CPU feature is present and that yields a kernel wins; else fallback.
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI)
+        if (
+            features.avx512f &&
+            features.avx512dq &&
+            features.avx512bw &&
+            features.avx512vl &&
+            features.avx512vbmi
+        ) {
+            if (auto kernel = select_avx512vbmi_decompress_block_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_AVX2)
+        if (features.avx2) {
+            if (auto kernel = select_avx2_decompress_block_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_BMI2)
+        if (features.bmi2) {
+            if (auto kernel = select_bmi2_decompress_block_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_SVE2) // probed ahead of NEON, mirroring the widest-first x86 order
+        if (features.sve2) {
+            if (auto kernel = select_sve2_decompress_block_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_NEON)
+        if (features.neon) {
+            if (auto kernel = select_neon_decompress_block_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+        return select_fallback_decompress_block_f32(bit_width, block_size, sign_values);
+    }
+
+    Kernel select_auto_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values) {
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2)
+        const CpuFeatures features = get_cached_cpu_features();
+#endif
+        // First built backend whose CPU feature is present and that yields a kernel wins; else fallback.
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI)
+        if (
+            features.avx512f &&
+            features.avx512dq &&
+            features.avx512bw &&
+            features.avx512vl &&
+            features.avx512vbmi
+        ) {
+            if (auto kernel = select_avx512vbmi_decompress_blocks_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_AVX2)
+        if (features.avx2) {
+            if (auto kernel = select_avx2_decompress_blocks_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_BMI2)
+        if (features.bmi2) {
+            if (auto kernel = select_bmi2_decompress_blocks_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_SVE2) // probed ahead of NEON, mirroring the widest-first x86 order
+        if (features.sve2) {
+            if (auto kernel = select_sve2_decompress_blocks_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_NEON)
+        if (features.neon) {
+            if (auto kernel = select_neon_decompress_blocks_f32(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+        return select_fallback_decompress_blocks_f32(bit_width, block_size, sign_values);
+    }
+
+    Kernel select_auto_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values) {
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2)
+        const CpuFeatures features = get_cached_cpu_features();
+#endif
+        // First built backend whose CPU feature is present and that yields a kernel wins; else fallback.
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI)
+        if (
+            features.avx512f &&
+            features.avx512dq &&
+            features.avx512bw &&
+            features.avx512vl &&
+            features.avx512vbmi
+        ) {
+            if (auto kernel = select_avx512vbmi_decompress_block_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_AVX2)
+        if (features.avx2) {
+            if (auto kernel = select_avx2_decompress_block_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_BMI2)
+        if (features.bmi2) {
+            if (auto kernel = select_bmi2_decompress_block_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_SVE2) // probed ahead of NEON, mirroring the widest-first x86 order
+        if (features.sve2) {
+            if (auto kernel = select_sve2_decompress_block_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_NEON)
+        if (features.neon) {
+            if (auto kernel = select_neon_decompress_block_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+        return select_fallback_decompress_block_f64(bit_width, block_size, sign_values);
+    }
+
+    Kernel select_auto_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values) {
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI) || defined(PERNIX_BUILD_X86_AVX2) || defined(PERNIX_BUILD_X86_BMI2) || defined(PERNIX_BUILD_ARM64_NEON) || defined(PERNIX_BUILD_ARM64_SVE2)
+        const CpuFeatures features = get_cached_cpu_features();
+#endif
+        // First built backend whose CPU feature is present and that yields a kernel wins; else fallback.
+#if defined(PERNIX_BUILD_X86_AVX512_VBMI)
+        if (
+            features.avx512f &&
+            features.avx512dq &&
+            features.avx512bw &&
+            features.avx512vl &&
+            features.avx512vbmi
+        ) {
+            if (auto kernel = select_avx512vbmi_decompress_blocks_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_AVX2)
+        if (features.avx2) {
+            if (auto kernel = select_avx2_decompress_blocks_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_X86_BMI2)
+        if (features.bmi2) {
+            if (auto kernel = select_bmi2_decompress_blocks_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_SVE2) // probed ahead of NEON, mirroring the widest-first x86 order
+        if (features.sve2) {
+            if (auto kernel = select_sve2_decompress_blocks_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+#if defined(PERNIX_BUILD_ARM64_NEON)
+        if (features.neon) {
+            if (auto kernel = select_neon_decompress_blocks_f64(bit_width, block_size, sign_values)) {
+                return kernel;
+            }
+        }
+#endif
+
+        return select_fallback_decompress_blocks_f64(bit_width, block_size, sign_values);
+    }
+}
diff --git a/src/fallback/compression.cpp b/src/fallback/compression.cpp
deleted file mode 100644
index 1ed56e8..0000000
--- a/src/fallback/compression.cpp
+++ /dev/null
@@ -1,149 +0,0 @@
-#include
-
-#ifdef __cplusplus
-namespace pernix {
-extern "C" {
-#endif
-
-#define PERNIX_COMPRESS_BLOCK_CASE(N) \
-    case N: \
-        return compress_block_fallback(input, scale, output);
-
-#define PERNIX_COMPRESS_BLOCKS_CASE(N) \
-    case N: \
-        return compress_blocks_fallback(input, scale, output, blocks);
-
-int compress_block_fallback(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) {
-    switch (bit_width) {
-
PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int compress_block_fallback_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int compress_blocks_fallback(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - switch (bit_width) { - 
PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int compress_blocks_fallback_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix 
-#endif // __cplusplus diff --git a/src/fallback/decompression.cpp b/src/fallback/decompression.cpp deleted file mode 100644 index e43b1df..0000000 --- a/src/fallback/decompression.cpp +++ /dev/null @@ -1,150 +0,0 @@ -#include - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return decompress_block_fallback(input, scale, output); - -#define PERNIX_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return decompress_blocks_fallback(input, scale, output, blocks); - -int decompress_block_fallback(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int decompress_block_fallback_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - 
PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int decompress_blocks_fallback(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int decompress_blocks_fallback_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - 
PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_DECOMPRESS_BLOCK_CASE -#undef PERNIX_DECOMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus \ No newline at end of file diff --git a/src/fallback/fallback_compression.cpp b/src/fallback/fallback_compression.cpp new file mode 100644 index 0000000..4f1e955 --- /dev/null +++ b/src/fallback/fallback_compression.cpp @@ -0,0 +1,194 @@ +#include +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_COMPRESS_BLOCK_32(N, BS) \ +case N: return Kernel("fallback", &compress_block_fallback) + +#define PERNIX_CASE_COMPRESS_BLOCKS_32(N, BS) \ +case N: return Kernel("fallback", &compress_blocks_fallback) + +#define PERNIX_CASE_COMPRESS_BLOCK_64(N, BS) \ +case N: return Kernel("fallback", &compress_block_fallback) + +#define PERNIX_CASE_COMPRESS_BLOCKS_64(N, BS) \ +case N: return Kernel("fallback", &compress_blocks_fallback) + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(4, BS); \ + 
PERNIX_CASE_COMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(24, BS); \ + default: return {"fallback", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(23, BS); 
\ + PERNIX_CASE_COMPRESS_BLOCKS_32(24, BS); \ + default: return {"fallback", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(24, BS); \ + default: return {"fallback", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(15, 
BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(24, BS); \ + default: return {"fallback", nullptr}; \ + } + + Kernel select_fallback_compress_block_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"fallback", nullptr}; + } + } + + Kernel select_fallback_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"fallback", nullptr}; + } + } + + Kernel select_fallback_compress_block_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"fallback", nullptr}; + } + } + + Kernel select_fallback_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + 
PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"fallback", nullptr}; + } + } + +#undef PERNIX_CASE_COMPRESS_BLOCK_32 +#undef PERNIX_CASE_COMPRESS_BLOCKS_32 +#undef PERNIX_CASE_COMPRESS_BLOCK_64 +#undef PERNIX_CASE_COMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64 +} diff --git a/src/fallback/fallback_decompression.cpp b/src/fallback/fallback_decompression.cpp new file mode 100644 index 0000000..184397f --- /dev/null +++ b/src/fallback/fallback_decompression.cpp @@ -0,0 +1,205 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("fallback", &decompress_block_fallback); \ + return Kernel("fallback", &decompress_block_fallback) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("fallback", &decompress_blocks_fallback); \ + return Kernel("fallback", &decompress_blocks_fallback) + + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("fallback", &decompress_block_fallback); \ + return Kernel("fallback", &decompress_block_fallback) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("fallback", &decompress_blocks_fallback); \ + return Kernel("fallback", &decompress_blocks_fallback) +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"fallback", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"fallback", nullptr}; \ + } \ + break + +#define 
PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"fallback", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"fallback", nullptr}; \ + } \ + break + + Kernel select_fallback_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"fallback", nullptr}; + } + } + + Kernel select_fallback_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"fallback", nullptr}; + } + } + + Kernel select_fallback_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"fallback", nullptr}; + } + } + + Kernel select_fallback_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + 
PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"fallback", nullptr}; + } + } + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/internal/pernix/arm64/neon/common.h b/src/internal/pernix/arm64/neon/common.h new file mode 100644 index 0000000..6cd9ad0 --- /dev/null +++ b/src/internal/pernix/arm64/neon/common.h @@ -0,0 +1,223 @@ +#ifndef PERNIX_ARM64_NEON_COMMON_H +#define PERNIX_ARM64_NEON_COMMON_H + +#include + +#include + +namespace pernix::arm64::neon::internal { + struct float64x2x8_t { + float64x2_t val[8]; + }; + + static constexpr u32 tail_bytes(const u8 bit_width, const u32 remaining_elements) { + const u32 tail_bits = remaining_elements * bit_width; + const u32 tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; + } + +__always_inline int32x4x4_t neon_convert_int8x16_int32x4x4(const int8x16_t &input) { + const int16x8_t s16_lo = vmovl_s8(vget_low_s8(input)); + const int16x8_t s16_hi = vmovl_s8(vget_high_s8(input)); + + return { + { + vmovl_s16(vget_low_s16(s16_lo)), + vmovl_s16(vget_high_s16(s16_lo)), + vmovl_s16(vget_low_s16(s16_hi)), + vmovl_s16(vget_high_s16(s16_hi)), + } + }; + } + +__always_inline int32x4x2_t neon_convert_int16x8_int32x4x2(const int16x8_t &input) { + return { + { + vmovl_s16(vget_low_s16(input)), + vmovl_s16(vget_high_s16(input)), + } + }; + } + +__always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t &input, const float32x4_t &scale) { + return { + { + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + 
vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[2]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[3]), scale), + } + }; + } + +__always_inline float32x4x2_t neon_dequantize_epi32(const int32x4x2_t &input, const float32x4_t &scale) { + return { + { + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + } + }; + } + +__always_inline float32x4_t neon_dequantize_epi32(const int32x4_t &input, const float32x4_t &scale) { + return vmulq_f32(vcvtq_f32_s32(input), scale); + } + +__always_inline float64x2_t neon_dequantize_epi32_f64(const int32x2_t &input, const float64x2_t &scale) { + return vmulq_f64(vcvtq_f64_s64(vmovl_s32(input)), scale); + } + +__always_inline float64x2x2_t neon_dequantize_epi32_f64(const int32x4_t &input, const float64x2_t &scale) { + return { + { + neon_dequantize_epi32_f64(vget_low_s32(input), scale), + neon_dequantize_epi32_f64(vget_high_s32(input), scale), + } + }; + } + +__always_inline float64x2x4_t neon_dequantize_epi32_f64(const int32x4x2_t &input, const float64x2_t &scale) { + const float64x2x2_t dequantized_low = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized_high = neon_dequantize_epi32_f64(input.val[1], scale); + + return { + { + dequantized_low.val[0], + dequantized_low.val[1], + dequantized_high.val[0], + dequantized_high.val[1], + } + }; + } + +__always_inline float64x2x8_t neon_dequantize_epi32_f64(const int32x4x4_t &input, const float64x2_t &scale) { + const float64x2x2_t dequantized0 = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized1 = neon_dequantize_epi32_f64(input.val[1], scale); + const float64x2x2_t dequantized2 = neon_dequantize_epi32_f64(input.val[2], scale); + const float64x2x2_t dequantized3 = neon_dequantize_epi32_f64(input.val[3], scale); + + return { + { + dequantized0.val[0], + dequantized0.val[1], + dequantized1.val[0], + dequantized1.val[1], + dequantized2.val[0], + 
dequantized2.val[1], + dequantized3.val[0], + dequantized3.val[1], + } + }; + } + +__always_inline uint8x16_t neon_load_tail_elements_int8(const u8 *input, const u32 tail_bytes_count) { + u8 buffer[16] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u8(buffer); + } + +__always_inline uint16x8_t neon_load_tail_elements_int16(const u8 *input, const u32 tail_bytes_count) { + u16 buffer[8] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u16(buffer); + } + +__always_inline uint32x4_t neon_load_tail_elements_int32(const u8 *input, const u32 tail_bytes_count) { + u32 buffer[4] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u32(buffer); + } + +__always_inline float32x4_t neon_load_tail_elements_f32(const u8 *input, const u32 tail_elements) { + float32_t buffer[4] = {0.0f}; + std::memcpy(buffer, input, tail_elements * sizeof(float32_t)); + return vld1q_f32(buffer); + } + +__always_inline float64x2_t neon_load_tail_elements_f64(const u8 *input, const u32 tail_elements) { + float64_t buffer[2] = {0.0}; + std::memcpy(buffer, input, tail_elements * sizeof(float64_t)); + return vld1q_f64(buffer); + } + +__always_inline void neon_store_tail_elements_int8(u8 *output, const uint8x16x4_t &data, + const u32 tail_elements) { + u8 buffer[16 * 4]; + for (u32 i = 0; i < 4; ++i) { + vst1q_u8(buffer + i * 16, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(u8)); + } + +__always_inline void neon_store_tail_elements_int16(u16 *output, const uint16x8x4_t &data, + const u32 tail_elements) { + u16 buffer[8 * 4]; + for (u32 i = 0; i < 4; ++i) { + vst1q_u16(buffer + i * 8, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(u16)); + } + +__always_inline void neon_store_tail_elements_int32(u32 *output, const uint32x4x4_t &data, + const u32 tail_elements) { + u32 buffer[4 * 4]; + for (u32 i = 0; i < 4; ++i) { + vst1q_u32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, 
tail_elements * sizeof(u32)); + } + +__always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4x4_t &data, + const u32 tail_elements) { + float32_t buffer[16 * 4]; + for (u32 i = 0; i < 4; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); + } + +__always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4x2_t &data, + const u32 tail_elements) { + float32_t buffer[8 * 2]; + for (u32 i = 0; i < 2; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); + } + +__always_inline void neon_store_tail_elements_f32(float32_t *output, const float32x4_t &data, + const u32 tail_elements) { + float32_t buffer[4]; + vst1q_f32(buffer, data); + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); + } + +__always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x4_t &data, + const u32 tail_elements) { + float64_t buffer[2 * 4]; + for (u32 i = 0; i < 4; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); + } + +__always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x2_t &data, + const u32 tail_elements) { + float64_t buffer[2 * 2]; + for (u32 i = 0; i < 2; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); + } + +__always_inline void neon_store_tail_elements_f64(float64_t *output, const float64x2x8_t &data, + const u32 tail_elements) { + float64_t buffer[2 * 8]; + for (u32 i = 0; i < 8; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); + } +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_COMMON_H diff --git a/src/internal/pernix/arm64/neon/compression.h b/src/internal/pernix/arm64/neon/compression.h new file 
mode 100644 index 0000000..83d46c3 --- /dev/null +++ b/src/internal/pernix/arm64/neon/compression.h @@ -0,0 +1,121 @@ +#ifndef PERNIX_ARM64_NEON_COMPRESSION_H +#define PERNIX_ARM64_NEON_COMPRESSION_H + +#include +#include + +#include +#include + +namespace pernix::arm64::neon { + namespace internal { + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; + } + } // namespace internal + + template + 
requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int neon_compress_blocks(const u8 * __restrict__ input, const f32 scale, f32 * __restrict__ output, + const u32 blocks) { + const u8 *block_input = input; + f32 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int neon_compress_blocks(const u8 * __restrict__ input, const f64 scale, f64 * __restrict__ output, + const u32 blocks) { + const u8 *block_input = input; + f64 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + 
neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; + } +} // namespace pernix::arm64::neon + +#endif // PERNIX_ARM64_NEON_COMPRESSION_H diff --git a/src/internal/pernix/arm64/neon/decompression.h b/src/internal/pernix/arm64/neon/decompression.h new file mode 100644 index 0000000..375adc2 --- /dev/null +++ b/src/internal/pernix/arm64/neon/decompression.h @@ -0,0 +1,376 @@ +#ifndef PERNIX_ARM64_NEON_DECOMPRESSION_H +#define PERNIX_ARM64_NEON_DECOMPRESSION_H + +#include +#include +#include + +#include +#include + +namespace pernix::arm64::neon { +namespace internal { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_16 = elements_per_block / 16; + constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16; + + const float32x4_t scale_v = vdupq_n_f32(scale); + + for (uint32_t i = 0; i < iterations_16; ++i) { + const uint8x16_t source = vld1q_u8(input); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8 + (source); + + const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); + const float32x4x4_t dequantized = neon_dequantize_epi32(converted, scale_v); + + for (uint32_t j = 0; j < 4; ++j) { + vst1q_f32(output, dequantized.val[j]); + output += 4; + } + + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint8x16_t tail_source = neon_load_tail_elements_int8( + input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8 + (tail_source); + + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); + const float32x4x4_t tail_dequantized = 
neon_dequantize_epi32(tail_converted, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + + const float32x4_t scale_v = vdupq_n_f32(scale); + + for (uint32_t i = 0; i < iterations_8; ++i) { + const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16 + (source); + + const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); + const float32x4x2_t dequantized = neon_dequantize_epi32(converted, scale_v); + + for (uint32_t j = 0; j < 2; ++j) { + vst1q_f32(output, dequantized.val[j]); + output += 4; + } + + input += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint16x8_t tail_source = neon_load_tail_elements_int16( + input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16 + (tail_source); + + const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); + const float32x4x2_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr 
uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + const float32x4_t scale_v = vdupq_n_f32(scale); + + for (uint32_t i = 0; i < iterations_4; ++i) { + const uint32_t group_bit_start = i * 4u * BIT_WIDTH; + const uint8_t* group_input = input + group_bit_start / 8u; + const uint32x4_t source = vld1q_u32(reinterpret_cast(group_input)); + + int32x4_t unpacked; + if constexpr (BIT_WIDTH % 2 == 0) { + unpacked = b128::neon_unpack_epi32_17to24(source); + } else { + if (i % 2 == 0) { + unpacked = b128::neon_unpack_epi32_17to24(source); + } else { + unpacked = b128::neon_unpack_epi32_17to24(source); + } + } + + const float32x4_t dequantized = neon_dequantize_epi32(unpacked, scale_v); + + vst1q_f32(output, dequantized); + + output += 4; + } + + if constexpr (remaining_elements > 0) { + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; + + constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + + int32x4_t tail_unpacked; + if constexpr (tail_bit_offset == 0) { + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); + } else { + tail_unpacked = b128::neon_unpack_epi32_17to24( + tail_source); + } + + const float32x4_t tail_dequantized = neon_dequantize_epi32(tail_unpacked, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_16 = elements_per_block / 16; + constexpr uint32_t remaining_elements = 
elements_per_block - iterations_16 * 16; + + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_16; ++i) { + const uint8x16_t source = vld1q_u8(input); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8 + (source); + + const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); + const float64x2x8_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); + + for (uint32_t j = 0; j < 8; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint8x16_t tail_source = neon_load_tail_elements_int8( + input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8 + (tail_source); + + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); + const float64x2x8_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_8; ++i) { + const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16 + (source); + + const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); + const float64x2x4_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); + + for (uint32_t j = 0; j < 4; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + 
} + + input += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint16x8_t tail_source = neon_load_tail_elements_int16( + input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16 + (tail_source); + + const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); + const float64x2x4_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_4; ++i) { + const uint32_t group_bit_start = i * 4u * BIT_WIDTH; + const uint8_t* group_input = input + group_bit_start / 8u; + const uint32x4_t source = vld1q_u32(reinterpret_cast(group_input)); + + int32x4_t unpacked; + if constexpr (BIT_WIDTH % 2 == 0) { + unpacked = b128::neon_unpack_epi32_17to24(source); + } else { + if (i % 2 == 0) { + unpacked = b128::neon_unpack_epi32_17to24(source); + } else { + unpacked = b128::neon_unpack_epi32_17to24(source); + } + } + + const float64x2x2_t dequantized = neon_dequantize_epi32_f64(unpacked, scale_v); + + for (uint32_t j = 0; j < 2; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + } + + if constexpr (remaining_elements > 0) { + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; + + constexpr uint32_t 
tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + + int32x4_t tail_unpacked; + if constexpr (tail_bit_offset == 0) { + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); + } else { + tail_unpacked = b128::neon_unpack_epi32_17to24( + tail_source); + } + + const float64x2x2_t tail_dequantized = neon_dequantize_epi32_f64(tail_unpacked, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); + } + + return 0; +} +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block(const void* __restrict__ input, const float_t scale, + void* __restrict__ output) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_decompress_block_1to8(typed_input, scale, typed_output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_decompress_block_9to16(typed_input, scale, typed_output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_decompress_block_17to24(typed_input, scale, typed_output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block(const void* __restrict__ input, const double_t scale, + void* __restrict__ output) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_decompress_block_1to8(typed_input, scale, typed_output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_decompress_block_9to16(typed_input, scale, typed_output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return 
internal::neon_decompress_block_17to24(typed_input, scale, typed_output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_blocks(const void* __restrict__ input, const float_t scale, void* __restrict__ output, + const uint32_t blocks) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + const uint8_t* block_input = typed_input; + float_t* block_output = typed_output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_blocks(const void* __restrict__ input, const double_t scale, void* __restrict__ output, + const uint32_t blocks) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + const uint8_t* block_input = typed_input; + double_t* block_output = typed_output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} +} // namespace pernix::arm64::neon + +#endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/src/internal/pernix/arm64/neon/packing.h b/src/internal/pernix/arm64/neon/packing.h new file mode 100644 index 0000000..538b5a8 --- /dev/null +++ b/src/internal/pernix/arm64/neon/packing.h @@ -0,0 +1,9 @@ +#ifndef PERNIX_ARM64_NEON_PACKING_H +#define PERNIX_ARM64_NEON_PACKING_H + +#include + +namespace pernix::arm64::neon::internal { +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_PACKING_H diff --git a/src/internal/pernix/arm64/neon/tables.h b/src/internal/pernix/arm64/neon/tables.h new file mode 100644 index 0000000..beb7c47 --- /dev/null +++ 
b/src/internal/pernix/arm64/neon/tables.h @@ -0,0 +1,214 @@ +#ifndef PERNIX_ARM64_NEON_TABLES_H +#define PERNIX_ARM64_NEON_TABLES_H + +#include +#include +#include +#include + +namespace pernix::arm64::neon::internal { + namespace detail { + inline constexpr std::size_t neon_vector_width = 128; + inline constexpr u8 inactive_lane = 0xff; + + template + constexpr bool table_indices_are_valid(const std::array &table) { + return std::ranges::all_of(table, [](const u8 index) { + return index == inactive_lane || index < Elements; + }); + } + + template + constexpr std::array make_primary_permute() { + static_assert(LANE_BITS % 8 == 0); + + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t base = entry * lane_bytes; + + for (std::size_t lane_byte = 0; lane_byte < lane_bytes; ++lane_byte) { + table[base + lane_byte] = static_cast(first_byte + lane_byte); + } + } + + return table; + } + + template + constexpr std::array make_spill_permute() { + static_assert(LANE_BITS % 8 == 0); + + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t bit_offset = bit_start % 8; + const std::size_t base = entry * lane_bytes; + + if (bit_offset + BIT_WIDTH > LANE_BITS) { + table[base] = static_cast(first_byte + lane_bytes); + } + } + + return table; + } + + template + constexpr std::array make_shift_right() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const 
std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + + table[entry] = -static_cast(bit_offset); + } + + return table; + } + + template + constexpr std::array make_shift_left_for_spill() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + const bool spills = bit_offset + BIT_WIDTH > LANE_BITS; + + table[entry] = spills ? static_cast(LANE_BITS - bit_offset) : 0; + } + + return table; + } + + template + constexpr std::array make_contiguous_permute_32() { + static_assert(ELEMENTS % 4 == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / 4; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + const std::size_t bit_end = bit_start + BIT_WIDTH - 1; + const std::size_t first_byte = bit_start / 8; + const std::size_t last_byte = bit_end / 8; + const std::size_t base = entry * 4; + + for (std::size_t byte = first_byte; byte <= last_byte; ++byte) { + table[base + (byte - first_byte)] = static_cast(byte); + } + } + + return table; + } + + template + constexpr std::array make_shift_right_32() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + + table[entry] = -static_cast(bit_start % 8u); + } + + return table; + } + } // namespace detail + + template + struct table_unpacking; + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && VECTOR_WIDTH == detail::neon_vector_width) + struct table_unpacking { + private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 8; + + public: + static constexpr u8 bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + 
detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 16); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); + }; + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && VECTOR_WIDTH == detail::neon_vector_width) + struct table_unpacking { + private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 16; + + public: + static constexpr u8 bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = + detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 8); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); + }; + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && VECTOR_WIDTH == detail::neon_vector_width && START_BIT_OFFSET < + 8) + struct table_unpacking { + private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 32; + + public: + static constexpr u8 bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute = + detail::make_contiguous_permute_32(); + alignas(64) static constexpr std::array shift = + detail::make_shift_right_32(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 4); + 
static_assert(detail::table_indices_are_valid(permute)); + }; + + template + struct table_packing; +} // namespace pernix::arm64::internal + +#endif // PERNIX_ARM64_NEON_TABLES_H diff --git a/src/internal/pernix/arm64/neon/unpacking.h b/src/internal/pernix/arm64/neon/unpacking.h new file mode 100644 index 0000000..cc9bcf0 --- /dev/null +++ b/src/internal/pernix/arm64/neon/unpacking.h @@ -0,0 +1,101 @@ +#ifndef PERNIX_ARM64_NEON_UNPACKING_H +#define PERNIX_ARM64_NEON_UNPACKING_H + +#include +#include + +using namespace pernix::arm64::neon::internal; + +namespace pernix::arm64::neon::internal::b128 { + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t &input) { + if constexpr (BIT_WIDTH == 8) { + return vreinterpretq_s8_u8(input); + } else if constexpr (BIT_WIDTH == 1) { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + const uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + return vreinterpretq_s8_u8(vandq_u8(shifted, vdupq_n_u8(1))); + } else { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + + uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute2.data())); + + shifted = vorrq_u8(shifted, vshlq_u8(permuted2_u8, vld1q_s8(tables::shift2.data()))); + } + + constexpr int shift = 8 - BIT_WIDTH; + shifted = vshlq_n_u8(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s8(vreinterpretq_s8_u8(shifted), vdupq_n_s8(-shift)); + } else { + return vreinterpretq_s8_u8(vshlq_u8(shifted, vdupq_n_s8(-shift))); + } + } + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline int16x8_t 
neon_unpack_epi16_9to16(const uint16x8_t &input) { + if constexpr (BIT_WIDTH == 16) { + return vreinterpretq_s16_u16(input); + } else { + using tables = table_unpacking; + + const uint8x16_t input_u8 = vreinterpretq_u8_u16(input); + + const uint8x16_t permuted1_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute1.data())); + + uint16x8_t shifted = vshlq_u16(vreinterpretq_u16_u8(permuted1_u8), vld1q_s16(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute2.data())); + + const uint16x8_t shifted2 = vshlq_u16(vreinterpretq_u16_u8(permuted2_u8), + vld1q_s16(tables::shift2.data())); + + shifted = vorrq_u16(shifted, shifted2); + } + + constexpr int shift = 16 - BIT_WIDTH; + shifted = vshlq_n_u16(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s16(vreinterpretq_s16_u16(shifted), vdupq_n_s16(-shift)); + } else { + return vreinterpretq_s16_u16(vshlq_u16(shifted, vdupq_n_s16(-shift))); + } + } + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline int32x4_t neon_unpack_epi32_17to24(const uint32x4_t &input) { + using tables = table_unpacking; + + const uint8x16_t input_8 = vreinterpretq_u8_u32(input); + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input_8, vld1q_u8(tables::permute.data())); + + const uint32x4_t value = vshlq_u32(vreinterpretq_u32_u8(permuted_u8), vld1q_s32(tables::shift.data())); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return vshrq_n_s32(vreinterpretq_s32_u32(vshlq_n_u32(value, sign_shift)), sign_shift); + } else { + constexpr u32 mask = (u32{1} << BIT_WIDTH) - 1u; + return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); + } + } +} // namespace pernix::arm64::neon::internal::b128 + +#endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/src/internal/pernix/arm64/sve2/compression.h b/src/internal/pernix/arm64/sve2/compression.h 
new file mode 100644 index 0000000..33f48fe --- /dev/null +++ b/src/internal/pernix/arm64/sve2/compression.h @@ -0,0 +1,48 @@ +#ifndef PERNIX_ARM64_SVE2_COMPRESSION_H +#define PERNIX_ARM64_SVE2_COMPRESSION_H + +#include + +#include +#include + +namespace pernix { + namespace internal { + template + inline constexpr bool sve2_compression_unimplemented_v = false; + } // namespace internal + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int sve2_compress_block(const f32 *, f32, u8 *) { + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); + return -1; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int sve2_compress_block(const f64 *, f64, u8 *) { + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); + return -1; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int sve2_compress_blocks(const f32 *, f32, u8 *, u32) { + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); + return -1; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int sve2_compress_blocks(const f64 *, f64, u8 *, u32) { + static_assert(internal::sve2_compression_unimplemented_v, + "ARM64 SVE2 compression is not implemented yet"); + return -1; + } +} // namespace pernix + +#endif // PERNIX_ARM64_SVE2_COMPRESSION_H diff --git a/src/internal/pernix/arm64/sve2/decompression.h b/src/internal/pernix/arm64/sve2/decompression.h new file mode 100644 index 0000000..350a196 --- /dev/null +++ b/src/internal/pernix/arm64/sve2/decompression.h @@ -0,0 +1,452 @@ +#ifndef PERNIX_ARM64_SVE2_DECOMPRESSION_H +#define PERNIX_ARM64_SVE2_DECOMPRESSION_H + +#include +#include +#include + +#include +#include +#include +#include + +namespace pernix::arm64::sve2 { +namespace internal { +template 
+[[nodiscard]] __always_inline constexpr uint32_t packed_bytes(const uint32_t elements) { + return (elements * BIT_WIDTH + 7) / 8; +} + +[[nodiscard]] __always_inline svuint8_t sve2_load_packed_bytes(const uint8_t* __restrict__ input, + const uint32_t bytes) { + const svbool_t pg = svwhilelt_b8(uint64_t{0}, static_cast(bytes)); + return svld1_u8(pg, input); +} + +template +__always_inline void sve2_store_dequantized_i8_f32(svint8_t values, const svfloat32_t scale_v, + float_t* __restrict__ output, + const uint32_t count) { + alignas(64) std::vector temp(svcntb()); + + svst1_s8(svptrue_b8(), temp.data(), values); + + uint32_t offset = 0; + while (offset < count) { + const svbool_t pg = svwhilelt_b32(static_cast(offset), static_cast(count)); + + svfloat32_t dequantized; + + if constexpr (SIGN_VALUES) { + const svint32_t widened = svld1sb_s32(pg, temp.data() + offset); + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); + } else { + const svuint32_t widened = svld1ub_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + } + + svst1_f32(pg, output + offset, dequantized); + + offset += static_cast(svcntw()); + } +} + +template +__always_inline void sve2_store_dequantized_i8_f64(svint8_t values, const double_t scale, + double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcntb()); + + svst1_s8(svptrue_b8(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template +__always_inline void sve2_store_dequantized_i16_f32(svint16_t values, const svfloat32_t scale_v, + float_t* __restrict__ output, + const uint32_t count) { + alignas(64) std::vector temp(svcnth()); + + svst1_s16(svptrue_b16(), temp.data(), values); + + uint32_t offset = 0; + while (offset < count) { + const svbool_t pg = 
svwhilelt_b32(static_cast(offset), static_cast(count)); + + svfloat32_t dequantized; + if constexpr (SIGN_VALUES) { + const svint32_t widened = svld1sh_s32(pg, temp.data() + offset); + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); + } else { + const svuint32_t widened = + svld1uh_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + } + + svst1_f32(pg, output + offset, dequantized); + + offset += static_cast(svcntw()); + } +} + +template +__always_inline void sve2_store_dequantized_i16_f64(svint16_t values, const double_t scale, + double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcnth()); + + svst1_s16(svptrue_b16(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template +__always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfloat32_t scale_v, + float_t* __restrict__ output, + const uint32_t count) { + const svbool_t pg = svwhilelt_b32(uint64_t{0}, static_cast(count)); + + svfloat32_t dequantized; + if constexpr (SIGN_VALUES) { + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, values), scale_v); + } else { + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, svreinterpret_u32_s32(values)), scale_v); + } + + svst1_f32(pg, output, dequantized); +} + +template +__always_inline void sve2_store_dequantized_i32_f64(svint32_t values, const double_t scale, + double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcntw()); + + svst1_s32(svptrue_b32(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE 
% 32 == 0) +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntb()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint8_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint8_t spill_shift = svdup_n_u8(0); + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint8_t unpacked = sve2_unpack_epi8_1to8 + (source, permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i8_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcnth()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint16_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + 
svuint16_t spill_shift = svdup_n_u16(0); + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint16_t unpacked = + sve2_unpack_epi16_9to16 + (svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i16_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntw()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint8_t* chunk_input = input + input_bit_offset / 8; + const uint32_t bit_offset = input_bit_offset % 8; + const uint32_t bytes = (bit_offset + count * BIT_WIDTH + 7u) / 8u; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + svint32_t unpacked; + if (bit_offset == 0) { + unpacked = sve2_unpack_epi32_17to24(source); + } else { + unpacked = sve2_unpack_epi32_17to24(source); + } + + sve2_store_dequantized_i32_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + 
input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntb()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint8_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint8_t spill_shift = svdup_n_u8(0); + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint8_t unpacked = sve2_unpack_epi8_1to8 + (source, permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i8_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcnth()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint16_t shift = table.shift(); + svuint8_t 
spill_permute = svdup_n_u8(0); + svuint16_t spill_shift = svdup_n_u16(0); + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint16_t unpacked = + sve2_unpack_epi16_9to16 + (svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i16_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntw()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint8_t* chunk_input = input + input_bit_offset / 8; + const uint32_t bit_offset = input_bit_offset % 8; + const uint32_t bytes = (bit_offset + count * BIT_WIDTH + 7u) / 8u; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + svint32_t unpacked; + if (bit_offset == 0) { + unpacked = sve2_unpack_epi32_17to24(source); + } else { + unpacked = sve2_unpack_epi32_17to24(source); + } + + sve2_store_dequantized_i32_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * 
BIT_WIDTH; + } + + return 0; +} +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_block(const void* __restrict__ input, const float_t scale, void* __restrict__ output) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve2_decompress_block_1to8(typed_input, scale, typed_output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve2_decompress_block_9to16(typed_input, scale, typed_output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve2_decompress_block_17to24(typed_input, scale, typed_output); + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_block(const void* __restrict__ input, const double_t scale, + void* __restrict__ output) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve2_decompress_block_1to8(typed_input, scale, typed_output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve2_decompress_block_9to16(typed_input, scale, typed_output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve2_decompress_block_17to24(typed_input, scale, typed_output); + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const void* __restrict__ input, const float_t scale, void* __restrict__ output, + const uint32_t blocks) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + const uint8_t* block_input = typed_input; + float_t* block_output = typed_output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + 
block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const void* __restrict__ input, const double_t scale, void* __restrict__ output, + const uint32_t blocks) { + const auto* typed_input = static_cast(input); + auto* typed_output = static_cast(output); + const uint8_t* block_input = typed_input; + double_t* block_output = typed_output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} +} // namespace pernix::arm64::sve2 + +#endif // PERNIX_ARM64_SVE2_DECOMPRESSION_H diff --git a/src/internal/pernix/arm64/sve2/packing.h b/src/internal/pernix/arm64/sve2/packing.h new file mode 100644 index 0000000..7d644f2 --- /dev/null +++ b/src/internal/pernix/arm64/sve2/packing.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE2_PACKING_H +#define PERNIX_ARM64_SVE2_PACKING_H + +#include + +namespace pernix::arm64::sve2::internal { + template + inline constexpr bool packing_unimplemented_v = false; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_PACKING_H diff --git a/src/internal/pernix/arm64/sve2/tables.h b/src/internal/pernix/arm64/sve2/tables.h new file mode 100644 index 0000000..a11fcde --- /dev/null +++ b/src/internal/pernix/arm64/sve2/tables.h @@ -0,0 +1,117 @@ +#ifndef PERNIX_ARM64_SVE2_TABLES_H +#define PERNIX_ARM64_SVE2_TABLES_H + +#include + +namespace pernix::arm64::sve2::internal { +template +struct table_unpacking { + static constexpr u8 bit_width = BIT_WIDTH; + + static svbool_t pg_b8() { return svptrue_b8(); } + + static svbool_t pg_b16() { return svptrue_b16(); } + + static svbool_t pg_b32() { return svptrue_b32(); } +}; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +struct table_unpacking { + static constexpr u8 
bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + return svlsr_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 3); + } + + static svuint8_t spill_permute() { + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 1); + } + + static svuint8_t shift() { + const svbool_t pg = svptrue_b8(); + return svand_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 7); + } + + static svuint8_t spill_shift() { + const svbool_t pg = svptrue_b8(); + return svsub_u8_x(pg, svdup_n_u8(8), shift()); + } +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +struct table_unpacking { + static constexpr u8 bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 1); + const svuint8_t byte = svand_n_u8_x(pg, lane, 1); + + svuint8_t first; + if constexpr (BIT_WIDTH == 16) { + first = svlsl_n_u8_x(pg, elem, 1); + } else { + constexpr u8 extra_bits = BIT_WIDTH - 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low = svlsr_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), 3); + first = svadd_u8_x(pg, elem, svadd_u8_x(pg, high, low)); + } + + return svadd_u8_x(pg, first, byte); + } + + static svuint8_t spill_permute() { + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 2); + } + + static svuint16_t shift() { + const svbool_t pg = svptrue_b16(); + return svand_n_u16_x(pg, svmul_n_u16_x(pg, svindex_u16(0, 1), BIT_WIDTH), 7); + } + + static svuint16_t spill_shift() { + const svbool_t pg = svptrue_b16(); + const svuint16_t bit_shift = shift(); + const svuint16_t spill = svsub_u16_x(pg, svdup_n_u16(16), bit_shift); + return svsel_u16(svcmpgt_n_u16(pg, bit_shift, 16u - BIT_WIDTH), spill, svdup_n_u16(16)); + } +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && START_BIT_OFFSET < 8) +struct table_unpacking { + static constexpr 
u8 bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 2); + const svuint8_t byte = svand_n_u8_x(pg, lane, 3); + + svuint8_t first = svmul_n_u8_x(pg, elem, BIT_WIDTH / 8u); + if constexpr (BIT_WIDTH % 8u != 0) { + constexpr u8 extra_bits = BIT_WIDTH % 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low_bits = + svadd_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), START_BIT_OFFSET); + first = svadd_u8_x(pg, first, svadd_u8_x(pg, high, svlsr_n_u8_x(pg, low_bits, 3))); + } + + return svadd_u8_x(pg, first, byte); + } + + static svuint32_t shift() { + const svbool_t pg = svptrue_b32(); + return svand_n_u32_x(pg, svadd_n_u32_x(pg, svmul_n_u32_x(pg, svindex_u32(0, 1), BIT_WIDTH), + START_BIT_OFFSET), 7); + } +}; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_TABLES_H diff --git a/src/internal/pernix/arm64/sve2/unpacking.h b/src/internal/pernix/arm64/sve2/unpacking.h new file mode 100644 index 0000000..3d0825a --- /dev/null +++ b/src/internal/pernix/arm64/sve2/unpacking.h @@ -0,0 +1,94 @@ +#ifndef PERNIX_ARM64_SVE2_UNPACKING_H +#define PERNIX_ARM64_SVE2_UNPACKING_H + +#include + +#include "tables.h" + +namespace pernix::arm64::sve2::internal { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline svint8_t sve2_unpack_epi8_1to8(const svuint8_t input, const svuint8_t permute, const svuint8_t shift, + const svuint8_t spill_permute, const svuint8_t spill_shift) { + if constexpr (BIT_WIDTH == 8) { + return svreinterpret_s8(input); + } else { + const svbool_t pg = svptrue_b8(); + + const svuint8_t permuted = svtbl_u8(input, permute); + svuint8_t unpacked = svlsr_u8_x(pg, permuted, shift); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const svuint8_t spill_permuted_values = 
svtbl_u8(input, spill_permute); + const svuint8_t spill_shifted = svlsl_u8_x(pg, spill_permuted_values, spill_shift); + unpacked = svorr_u8_x(pg, unpacked, spill_shifted); + } + + if constexpr (BIT_WIDTH == 1) { + unpacked = svand_n_u8_x(pg, unpacked, 1); + return svreinterpret_s8(unpacked); + } else { + constexpr int sign_shift = 8 - BIT_WIDTH; + + unpacked = svlsl_n_u8_x(pg, unpacked, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s8_x(pg, svreinterpret_s8_u8(unpacked), sign_shift); + } else { + return svreinterpret_s8_u8(svlsr_n_u8_x(pg, unpacked, sign_shift)); + } + } + } +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline svint16_t sve2_unpack_epi16_9to16(const svuint16_t input, const svuint8_t permute, + const svuint16_t shift, + const svuint8_t spill_permute, const svuint16_t spill_shift) { + if constexpr (BIT_WIDTH == 16) { + return svreinterpret_s16(input); + } else { + const svbool_t pg = svptrue_b16(); + + const svuint8_t permuted = svtbl_u8(svreinterpret_u8_u16(input), permute); + svuint16_t shifted = svlsr_u16_x(pg, svreinterpret_u16_u8(permuted), shift); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const svuint8_t spill_permuted_values = svtbl_u8(svreinterpret_u8_u16(input), spill_permute); + const svuint16_t spill_shifted = svlsl_u16_x(pg, svreinterpret_u16_u8(spill_permuted_values), + spill_shift); + shifted = svorr_u16_x(pg, shifted, spill_shifted); + } + + constexpr int sign_shift = 16 - BIT_WIDTH; + shifted = svlsl_n_u16_x(pg, shifted, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s16_x(pg, svreinterpret_s16_u16(shifted), sign_shift); + } else { + return svreinterpret_s16_u16(svlsr_n_u16_x(pg, shifted, sign_shift)); + } + } +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline svint32_t sve2_unpack_epi32_17to24(const svuint8_t input) { + using table = table_unpacking; + + const svbool_t pg = svptrue_b32(); + const 
svuint8_t permuted = svtbl_u8(input, table::permute()); + const svuint32_t unpacked = svlsr_u32_x(pg, svreinterpret_u32_u8(permuted), table::shift()); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return svasr_n_s32_x(pg, svreinterpret_s32_u32(svlsl_n_u32_x(pg, unpacked, sign_shift)), sign_shift); + } else { + constexpr u32 mask = (u32{1} << BIT_WIDTH) - 1u; + return svreinterpret_s32_u32(svand_n_u32_x(pg, unpacked, mask)); + } +} +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/src/internal/pernix/dispatch/cpu_features.h b/src/internal/pernix/dispatch/cpu_features.h new file mode 100644 index 0000000..bcf3e7d --- /dev/null +++ b/src/internal/pernix/dispatch/cpu_features.h @@ -0,0 +1,26 @@ +#ifndef PERNIX_CPU_FEATURES_H +#define PERNIX_CPU_FEATURES_H + +namespace pernix::internal { +struct CpuFeatures { + bool avx2 = false; + bool bmi2 = false; + bool avx512f = false; + bool avx512dq = false; + bool avx512bw = false; + bool avx512vl = false; + bool avx512vbmi = false; + bool neon = false; + bool sve = false; + bool sve2 = false; +}; + +CpuFeatures detect_cpu_features(); + +inline CpuFeatures get_cached_cpu_features() { + static const CpuFeatures features = detect_cpu_features(); + return features; +} +} + +#endif //PERNIX_CPU_FEATURES_H diff --git a/src/internal/pernix/dispatch/kernel.h b/src/internal/pernix/dispatch/kernel.h new file mode 100644 index 0000000..b7c0618 --- /dev/null +++ b/src/internal/pernix/dispatch/kernel.h @@ -0,0 +1,27 @@ +#ifndef PERNIX_KERNEL_H +#define PERNIX_KERNEL_H + +#include +#include + +namespace pernix::internal { +using KernelBlockF32Func = i32 (*)(const void*, f32, void*); +using KernelBlocksF32Func = i32 (*)(const void*, f32, void*, u32); +using KernelBlockF64Func = i32 (*)(const void*, f64, void*); +using KernelBlocksF64Func = i32 (*)(const void*, f64, void*, u32); + +template +struct Kernel { + std::string_view name; + FuncType func; + + 
explicit operator bool() const noexcept { + return func != nullptr; + } + + Kernel(const std::string_view name, FuncType func) : name(name), func(func) { + } +}; +} + +#endif //PERNIX_KERNEL_H diff --git a/src/internal/pernix/dispatch/select.h b/src/internal/pernix/dispatch/select.h new file mode 100644 index 0000000..d00357c --- /dev/null +++ b/src/internal/pernix/dispatch/select.h @@ -0,0 +1,171 @@ +#ifndef PERNIX_SELECT_H +#define PERNIX_SELECT_H + +#include +#include + +namespace pernix::internal { + Kernel select_compress_block_f32(Backend backend, u8 bit_width, u32 block_size); + + Kernel select_compress_blocks_f32(Backend backend, u8 bit_width, u32 block_size); + + Kernel select_compress_block_f64(Backend backend, u8 bit_width, u32 block_size); + + Kernel select_compress_blocks_f64(Backend backend, u8 bit_width, u32 block_size); + + Kernel select_decompress_block_f32(Backend backend, u8 bit_width, u32 block_size, + bool sign_values); + + Kernel select_decompress_blocks_f32(Backend backend, u8 bit_width, u32 block_size, + bool sign_values); + + Kernel select_decompress_block_f64(Backend backend, u8 bit_width, u32 block_size, + bool sign_values); + + Kernel select_decompress_blocks_f64(Backend backend, u8 bit_width, u32 block_size, + bool sign_values); + + + Kernel select_auto_compress_block_f32(u8 bit_width, u32 block_size); + + Kernel select_auto_compress_blocks_f32(u8 bit_width, u32 block_size); + + Kernel select_auto_compress_block_f64(u8 bit_width, u32 block_size); + + Kernel select_auto_compress_blocks_f64(u8 bit_width, u32 block_size); + + Kernel select_auto_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_auto_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_auto_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_auto_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); + + + Kernel select_fallback_compress_block_f32(u8 
bit_width, u32 block_size); + + Kernel select_fallback_compress_blocks_f32(u8 bit_width, u32 block_size); + + Kernel select_fallback_compress_block_f64(u8 bit_width, u32 block_size); + + Kernel select_fallback_compress_blocks_f64(u8 bit_width, u32 block_size); + + Kernel + select_fallback_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_fallback_decompress_blocks_f32(u8 bit_width, u32 block_size, + bool sign_values); + + Kernel + select_fallback_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_fallback_decompress_blocks_f64(u8 bit_width, u32 block_size, + bool sign_values); + +#if defined(PERNIX_BUILD_X86_AVX2) + + Kernel select_avx2_compress_block_f32(u8 bit_width, u32 block_size); + + Kernel select_avx2_compress_blocks_f32(u8 bit_width, u32 block_size); + + Kernel select_avx2_compress_block_f64(u8 bit_width, u32 block_size); + + Kernel select_avx2_compress_blocks_f64(u8 bit_width, u32 block_size); + + Kernel select_avx2_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_avx2_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_avx2_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_avx2_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); + +#endif + +#if defined(PERNIX_BUILD_X86_BMI2) + + Kernel select_bmi2_compress_block_f32(u8 bit_width, u32 block_size); + + Kernel select_bmi2_compress_blocks_f32(u8 bit_width, u32 block_size); + + Kernel select_bmi2_compress_block_f64(u8 bit_width, u32 block_size); + + Kernel select_bmi2_compress_blocks_f64(u8 bit_width, u32 block_size); + + Kernel select_bmi2_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_bmi2_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_bmi2_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); + + Kernel 
select_bmi2_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); + +#endif + +#if defined(PERNIX_BUILD_X86_AVX512_VBMI) + + Kernel select_avx512vbmi_compress_block_f32(u8 bit_width, u32 block_size); + + Kernel select_avx512vbmi_compress_blocks_f32(u8 bit_width, u32 block_size); + + Kernel select_avx512vbmi_compress_block_f64(u8 bit_width, u32 block_size); + + Kernel select_avx512vbmi_compress_blocks_f64(u8 bit_width, u32 block_size); + + Kernel select_avx512vbmi_decompress_block_f32(u8 bit_width, u32 block_size, + bool sign_values); + + Kernel select_avx512vbmi_decompress_blocks_f32(u8 bit_width, u32 block_size, + bool sign_values); + + Kernel select_avx512vbmi_decompress_block_f64(u8 bit_width, u32 block_size, + bool sign_values); + + Kernel select_avx512vbmi_decompress_blocks_f64(u8 bit_width, u32 block_size, + bool sign_values); + +#endif + +#if defined(PERNIX_BUILD_ARM64_NEON) + + Kernel select_neon_compress_block_f32(u8 bit_width, u32 block_size); + + Kernel select_neon_compress_blocks_f32(u8 bit_width, u32 block_size); + + Kernel select_neon_compress_block_f64(u8 bit_width, u32 block_size); + + Kernel select_neon_compress_blocks_f64(u8 bit_width, u32 block_size); + + Kernel select_neon_decompress_block_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_neon_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_neon_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_neon_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); + +#endif + +#if defined(PERNIX_BUILD_ARM64_SVE2) + + Kernel select_sve2_compress_block_f32(u8 bit_width, u32 block_size); + + Kernel select_sve2_compress_blocks_f32(u8 bit_width, u32 block_size); + + Kernel select_sve2_compress_block_f64(u8 bit_width, u32 block_size); + + Kernel select_sve2_compress_blocks_f64(u8 bit_width, u32 block_size); + + Kernel select_sve2_decompress_block_f32(u8 bit_width, u32 block_size, 
bool sign_values); + + Kernel select_sve2_decompress_blocks_f32(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_sve2_decompress_block_f64(u8 bit_width, u32 block_size, bool sign_values); + + Kernel select_sve2_decompress_blocks_f64(u8 bit_width, u32 block_size, bool sign_values); + +#endif +} + +#endif //PERNIX_SELECT_H diff --git a/src/internal/pernix/fallback/avx2_compression.h b/src/internal/pernix/fallback/avx2_compression.h new file mode 100644 index 0000000..e2c7200 --- /dev/null +++ b/src/internal/pernix/fallback/avx2_compression.h @@ -0,0 +1,266 @@ +#ifndef PERNIX_FALLBACK_COMPRESSION_H +#define PERNIX_FALLBACK_COMPRESSION_H + +#include + +#include +#include +#include +#include +#include +#include + +namespace pernix { + namespace internal { + /** + * @brief Quantize a single float value to i32 using the provided scale. + * + * @param input input float value to be quantized. + * @param scale scaling factor used during quantization. + * @return i32 quantized integer value. + */ +__always_inline i32 quantize_ps_epi32(const float input, const float scale) { + return static_cast(std::lroundf(input * scale)); + } + + /** + * @brief Quantize a single double value to i64 using the provided scale. + * + * @param input input double value to be quantized. + * @param scale scaling factor used during quantization. + * @return i64 quantized integer value. + */ +__always_inline i64 quantize_pd_epi64(const f64 input, const f64 scale) { + return std::llround(input * scale); + } + + /** + * @brief Quantize and clamp without narrowing through an out-of-range integer type. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_floating_point_v) +__always_inline i32 quantize_clamped(const T input, const T scale) { + constexpr i64 min_value = BIT_WIDTH == 1 ? 0 : -(i64{1} << (BIT_WIDTH - 1)); + constexpr i64 max_value = BIT_WIDTH == 1 ? 
1 : ((i64{1} << (BIT_WIDTH - 1)) - 1); + + const long double scaled = static_cast(input) * static_cast(scale); + if (std::isnan(scaled)) { + return 0; + } + if (scaled <= static_cast(min_value)) { + return static_cast(min_value); + } + if (scaled >= static_cast(max_value)) { + return static_cast(max_value); + } + + return static_cast(std::llround(scaled)); + } + + /** + * @brief Clamp a signed quantized value to the representable range of BIT_WIDTH bits. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) +__always_inline i32 clamp_signed_quantized(const i64 value) { + if constexpr (BIT_WIDTH == 1) { + // 1-bit fallback is treated as binary quantization (0/1). + return static_cast(std::clamp(value, 0, 1)); + } + + constexpr i32 min_value = -(1 << (BIT_WIDTH - 1)); + constexpr i32 max_value = (1 << (BIT_WIDTH - 1)) - 1; + return static_cast(std::clamp(value, min_value, max_value)); + } + + /** + * @brief Append packed scalar values into an output buffer using the selected + * storage width. + * + * @tparam T unsigned integer type used as the packing word. + * @tparam BIT_WIDTH bit width per value in the packed representation. + * @param input vector of quantized values to pack. + * @param bit_offset starting bit offset in the destination buffer. + * @param destination pointer to the output buffer. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) + void pack_epi32_fallback_inner(const std::vector &input, const u8 bit_offset, + u8 * __restrict__ destination) { + constexpr u32 bits_in_type = sizeof(T) * 8; + constexpr u32 bitmask = BIT_WIDTH == bits_in_type + ? std::numeric_limits::max() + : (1U << BIT_WIDTH) - 1U; + + std::size_t idx = 0; + std::size_t bits_in_buffer = bit_offset; + u64 buffer = bit_offset ? 
static_cast(destination[0] & ((1U << bit_offset) - 1U)) : 0; + +#pragma GCC unroll 64 + for (u32 raw_value: input) { + const u32 next_value = raw_value & bitmask; + + buffer |= static_cast(next_value) << bits_in_buffer; + bits_in_buffer += BIT_WIDTH; + + while (bits_in_buffer >= 8) { + destination[idx++] = static_cast(buffer & 0xFFU); + buffer >>= 8; + bits_in_buffer -= 8; + } + } + + if (bits_in_buffer > 0) { + destination[idx] = static_cast(buffer & 0xFFU); + } + } + + /** + * @brief Pack a vector of u32 values into a compact byte representation using fallback scalar implementation. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). + * + * @param input vector of u32 values to be packed. + * @param destination pointer to the output buffer where packed bytes will be stored. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + void pack_epi32_fallback(const std::vector &input, u8 * __restrict__ destination) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::pack_epi32_fallback_inner(input, 0, destination); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::pack_epi32_fallback_inner(input, 0, destination); + } else { + return internal::pack_epi32_fallback_inner(input, 0, destination); + } + } + } // namespace internal + + /** + * @brief Compress a single 512-bit block using fallback scalar implementation. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). + * + * @param input pointer to the start of the input float values. + * @param scale scaling factor used during quantization. + * @param output pointer to the output buffer where compressed bytes will be stored. + * @return int status code (0 for success). 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int compress_block_fallback(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + std::memset(output, 0, BLOCK_SIZE); + + std::vector block_values(elements_per_block); +#pragma GCC unroll 64 + for (u32 i = 0; i < elements_per_block; i++) { + const i32 quantized = internal::quantize_clamped(input[i], scale); + block_values[i] = static_cast(quantized); + } + + internal::pack_epi32_fallback(block_values, output); + return 0; + } + + /** + * @brief Compress a single block of double values using the fallback scalar implementation. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). + * @param input pointer to the start of the input double values. + * @param scale scaling factor used during quantization. + * @param output pointer to the output buffer where compressed bytes will be stored. + * @return int status code (0 for success). + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int compress_block_fallback(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + std::memset(output, 0, BLOCK_SIZE); + + std::vector block_values(elements_per_block); +#pragma GCC unroll 32 + for (u32 i = 0; i < elements_per_block; i++) { + const i32 quantized = internal::quantize_clamped(input[i], scale); + block_values[i] = static_cast(quantized); + } + + internal::pack_epi32_fallback(block_values, output); + return 0; + } + + /** + * @brief Compress multiple 512-bit blocks using fallback scalar implementation. 
+ * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). + * + * @param input pointer to the start of the input float values. + * @param scale scaling factor used during quantization. + * @param output pointer to the output buffer where compressed bytes will be stored. + * @param blocks number of 512-bit blocks to compress. + * @return int status code (0 for success). + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int compress_blocks_fallback(const void * __restrict__ input_ptr, float scale, void * __restrict__ output_ptr, + u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f32 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + compress_block_fallback(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } + + return 0; + } + + /** + * @brief Compress multiple blocks of double values using the fallback scalar implementation. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam BLOCK_SIZE size of each block in bytes (default 64 for 512 bits). + * @param input pointer to the start of the input double values. + * @param scale scaling factor used during quantization. + * @param output pointer to the output buffer where compressed bytes will be stored. + * @param blocks number of blocks to compress. + * @return int status code (0 for success). 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + int compress_blocks_fallback(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const unsigned int blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f64 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + compress_block_fallback(block_input, scale, block_output); + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_output += BLOCK_SIZE; + } + return 0; + } +} // namespace pernix +#endif // PERNIX_FALLBACK_COMPRESSION_H diff --git a/src/internal/pernix/fallback/avx2_decompression.h b/src/internal/pernix/fallback/avx2_decompression.h new file mode 100644 index 0000000..24431e5 --- /dev/null +++ b/src/internal/pernix/fallback/avx2_decompression.h @@ -0,0 +1,248 @@ +#ifndef PERNIX_FALLBACK_DECOMPRESSION_H +#define PERNIX_FALLBACK_DECOMPRESSION_H + +#include + +#include +#include +#include +#include + +namespace pernix { + namespace internal { + /** +* @brief Dequantize a single i32 value to float using the provided scale. +* +* @param input input i32 value to be dequantized. +* @param scale scaling factor used during quantization. +* @return float dequantized float value. +*/ +__always_inline float dequantize_epi32(const i32 input, const float scale) { + return static_cast(input) * scale; + } + + /** +* @brief Dequantize a single i64 value to double using the provided scale. +* +* @param input input i64 value to be dequantized. +* @param scale scaling factor used during quantization. +* @return f64 dequantized double value. +*/ +__always_inline f64 dequantize_epi64(const i64 input, const f64 scale) { + return static_cast(input) * scale; + } + + /** +* @brief Sign-extend a packed integer value stored in the low bits of a 32-bit word. +* +* @tparam BIT_WIDTH number of significant bits in the encoded value. +* @param value unsigned packed value. 
+* @return i32 sign-extended value. +*/ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __always_inline auto sign_extend(const u32 value) -> i32 { + if constexpr (BIT_WIDTH == 1) { + return static_cast(value & 1U); + } + + constexpr u32 sign_bit = u32{1} << (BIT_WIDTH - 1); + constexpr u32 mask = (u32{1} << BIT_WIDTH) - 1; + const u32 masked = value & mask; + return static_cast((static_cast(masked ^ sign_bit)) - static_cast(sign_bit)); + } + + /** +* @brief Unpack bit-packed values from a typed input span into signed 32-bit integers. +* +* @tparam T unsigned integer type used to read the source buffer. +* @tparam BIT_WIDTH bit width per packed value. +* @tparam SIGN_VALUES whether to sign-extend unpacked values. +* @param input pointer to the typed packed input buffer. +* @param bit_offset starting bit offset in the first input word. +* @param elements number of values to unpack. +* @return std::vector unpacked values. +*/ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24 && std::is_integral_v && std::is_unsigned_v) + __always_inline auto unpack_epi32_fallback_inner(const u8 * __restrict__ input, const u8 bit_offset, + const std::size_t elements) + -> std::vector { + constexpr u32 bits_in_type = sizeof(T) * 8; + constexpr u32 bitmask = BIT_WIDTH == bits_in_type + ? 
std::numeric_limits::max() + : (1U << BIT_WIDTH) - 1U; + + std::vector output(elements); + + std::size_t idx = 0; + u8 bits_in_buffer = 8 - bit_offset; + u64 buffer = static_cast(input[idx++]) >> bit_offset; + +#pragma GCC unroll 64 + for (u32 i = 0; i < elements; i++) { + while (BIT_WIDTH > bits_in_buffer) { + const auto next_value = static_cast(input[idx++]) << bits_in_buffer; + buffer |= next_value; + bits_in_buffer += 8; + } + + const u32 raw_value = static_cast(buffer & bitmask); + if constexpr (SIGN_VALUES) { + output[i] = sign_extend(raw_value); + } else { + output[i] = static_cast(raw_value); + } + + buffer >>= BIT_WIDTH; + bits_in_buffer -= BIT_WIDTH; + } + + return output; + } + + /** +* @brief Unpack packed i32 values from the input buffer using fallback scalar implementation. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam SIGN_VALUES whether the values are signed or unsigned. +* @param input pointer to the start of the packed data. +* @param elements number of elements to unpack. +* @return std::vector unpacked i32 values. +*/ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __always_inline auto unpack_epi32_fallback(const u8 * __restrict__ input, + const std::size_t elements) -> std::vector { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return unpack_epi32_fallback_inner(input, 0, elements); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return unpack_epi32_fallback_inner(input, 0, elements); + } else { + return unpack_epi32_fallback_inner(input, 0, elements); + } + } + } // namespace internal + + /** +* @brief Decompress a single 512\-bit block using fallback scalar implementation. +* +* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). +* @tparam SIGN_VALUES whether the values are signed or unsigned. +* @param input pointer to the start of the compressed block. +* @param scale scaling factor used during quantization. 
+* @param output pointer to the output buffer where decompressed float values will be stored.
+* @return int status code (0 for success).
+*/
+    template
+    requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24)
+    int decompress_block_fallback(const void * __restrict__ input_ptr, const f32 scale,
+                                  void * __restrict__ output_ptr) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        // Number of BIT_WIDTH-bit values that fit into BLOCK_SIZE bytes (rounded down).
+        constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH;
+
+        // Unpack the whole block to integers first, then scale each value.
+        const std::vector block_values = internal::unpack_epi32_fallback(
+            input, elements_per_block);
+
+#pragma GCC unroll 512
+        for (u32 i = 0; i < elements_per_block; i++) {
+            output[i] = internal::dequantize_epi32(block_values[i], scale);
+        }
+
+        return 0;
+    }
+
+    /**
+* @brief Decompress a single block to double values using the fallback scalar implementation.
+*
+* Same flow as the float overload: unpack elements_per_block integers, then
+* dequantize each one with internal::dequantize_epi64.
+*
+* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+* @tparam SIGN_VALUES whether the values are signed or unsigned.
+* @param input_ptr pointer to the start of the compressed block.
+* @param scale scaling factor used during quantization.
+* @param output_ptr pointer to the output buffer where decompressed double values will be stored.
+* @return int status code (always 0 in this implementation).
+*/
+    template
+    requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24)
+    int decompress_block_fallback(const void * __restrict__ input_ptr, const f64 scale,
+                                  void * __restrict__ output_ptr) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        // Number of BIT_WIDTH-bit values that fit into BLOCK_SIZE bytes (rounded down).
+        constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH;
+
+        const std::vector block_values = internal::unpack_epi32_fallback(
+            input, elements_per_block);
+
+#pragma GCC unroll 512
+        for (u32 i = 0; i < elements_per_block; i++) {
+            output[i] = internal::dequantize_epi64(block_values[i], scale);
+        }
+
+        return 0;
+    }
+
+    /**
+* @brief Decompress multiple 512\-bit blocks using fallback scalar implementation.
+*
+* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+* @tparam SIGN_VALUES whether the values are signed or unsigned.
+* @param input pointer to the start of the compressed data.
+* @param scale scaling factor used during quantization.
+* @param output pointer to the output buffer where decompressed float values will be stored.
+* @param blocks number of 512-bit blocks to decompress.
+* @return int status code (0 for success).
+*/
+    template
+    requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24)
+    int decompress_blocks_fallback(const void * __restrict__ input_ptr, const f32 scale,
+                                   void * __restrict__ output_ptr, const u32 blocks) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        const u8 *block_input = input;
+        f32 *block_output = output;
+
+        for (u32 block = 0; block < blocks; block++) {
+            // The per-block status is not inspected here.
+            decompress_block_fallback(block_input, scale, block_output);
+            // Each block consumes BLOCK_SIZE input bytes and yields (BLOCK_SIZE * 8) / BIT_WIDTH values.
+            block_input += BLOCK_SIZE;
+            block_output += (BLOCK_SIZE * 8) / BIT_WIDTH;
+        }
+
+        return 0;
+    }
+
+    /**
+* @brief Decompress multiple blocks to double values using the fallback scalar implementation.
+*
+* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+* @tparam SIGN_VALUES whether the values are signed or unsigned.
+* @param input pointer to the start of the compressed data.
+* @param scale scaling factor used during quantization.
+* @param output pointer to the output buffer where decompressed double values will be stored.
+* @param blocks number of blocks to decompress.
+* @return int status code (0 for success).
+*/
+    template
+    requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24)
+    int decompress_blocks_fallback(const void * __restrict__ input_ptr, const f64 scale,
+                                   void * __restrict__ output_ptr, const u32 blocks) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        const u8 *block_input = input;
+        f64 *block_output = output;
+
+        for (u32 block = 0; block < blocks; block++) {
+            // The per-block status is not inspected here.
+            decompress_block_fallback(block_input, scale, block_output);
+            // Each block consumes BLOCK_SIZE input bytes and yields (BLOCK_SIZE * 8) / BIT_WIDTH values.
+            block_input += BLOCK_SIZE;
+            block_output += (BLOCK_SIZE * 8) / BIT_WIDTH;
+        }
+
+        return 0;
+    }
+} // namespace pernix
+
+#endif // PERNIX_FALLBACK_DECOMPRESSION_H
diff --git a/src/internal/pernix/simd_compat.h b/src/internal/pernix/simd_compat.h
new file mode 100644
index 0000000..d55aafa
--- /dev/null
+++ b/src/internal/pernix/simd_compat.h
@@ -0,0 +1,36 @@
+#ifndef PERNIX_SIMD_COMPAT_H
+#define PERNIX_SIMD_COMPAT_H
+
+#include
+#include
+#include
+
+// Portable path: SIMDe is configured through the macros below before its headers are included.
+#if defined(PERNIX_USE_SIMDE)
+#define SIMDE_ENABLE_NATIVE_ALIASES
+#undef SIMDE_X86_AVX512FP16_NATIVE
+#if defined(__clang__)
+#define SIMDE_X86_AVX512BF16_NATIVE
+#endif
+// #define SIMDE_NO_NATIVE
+// Header set is chosen by the PERNIX_BACKEND_* macro.
+#if defined(PERNIX_BACKEND_X86)
+#include
+#include
+#include
+#elif defined(PERNIX_BACKEND_ARM64_NEON)
+#include
+#elif defined(PERNIX_BACKEND_ARM64_SVE) || defined(PERNIX_BACKEND_ARM64_SVE2)
+#include
+#endif
+
+// Native path: choose headers from the compiler's architecture macros.
+#elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)
+#include
+#elif defined(__aarch64__) || defined(__arm64ec__)
+#ifdef __ARM_FEATURE_SVE
+#include
+#endif
+#ifdef __ARM_NEON
+#include
+#endif
+#endif
+
+#endif // PERNIX_SIMD_COMPAT_H
diff --git a/src/internal/pernix/x86/avx2/avx2_compression.h b/src/internal/pernix/x86/avx2/avx2_compression.h
new file mode 100644
index 0000000..01c9626
--- /dev/null
+++ b/src/internal/pernix/x86/avx2/avx2_compression.h
@@ -0,0 +1,607 @@
+#ifndef PERNIX_AVX2_COMPRESSION_H
+#define PERNIX_AVX2_COMPRESSION_H
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace pernix {
+    namespace internal {
+        /**
+* @brief Clamp eight signed 32-bit lanes to the range representable in BIT_WIDTH bits.
+*
+* The range is [-(2^(BIT_WIDTH-1)), 2^(BIT_WIDTH-1) - 1], except for BIT_WIDTH == 1
+* where it is [0, 1].
+*
+* @param input lanes to clamp.
+* @return __m256i clamped lanes.
+*/
+        template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24)
+__always_inline __m256i mm256_clamp_signed_epi32(__m256i input) {
+            constexpr i32 min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1));
+            constexpr i32 max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1);
+            // max() applies the lower bound first, then min() applies the upper bound.
+            return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)),
+                                    _mm256_set1_epi32(max_value));
+        }
+
+        /**
+* @brief Quantize four float values into signed 32-bit integers.
+*
+* Multiplies by the scale and converts with _mm_cvtps_epi32 (no separate rounding step).
+*
+* @param input source float lane values.
+* @param scale per-lane scale factor.
+* @return __m128i quantized values.
+*/
+__always_inline __m128i mm_quantize_ps_epi32(const __m128 &input, const __m128 &scale) {
+            const __m128 scaled = _mm_mul_ps(input, scale);
+            // const __m128 rounded = _mm_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+            return _mm_cvtps_epi32(scaled);
+        }
+
+        /**
+* @brief Quantize two double values into a partially filled 128-bit integer register.
+*
+* @param input source double lane values.
+* @param scale per-lane scale factor.
+* @return __m128i quantized values in the low lanes.
+*/
+__always_inline __m128i mm_quantize_pd_epi32(const __m128d &input, const __m128d &scale) {
+            const __m128d scaled = _mm_mul_pd(input, scale);
+            return _mm_cvtpd_epi32(scaled);
+        }
+
+        /**
+* @brief Quantize eight float values into signed 32-bit integers.
+*
+* @param input source float lane values.
+* @param scale per-lane scale factor.
+* @return __m256i quantized values.
+*/
+__always_inline __m256i mm256_quantize_ps_epi32(const __m256 &input, const __m256 &scale) {
+            const __m256 scaled = _mm256_mul_ps(input, scale);
+            // const __m256 rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+            return _mm256_cvtps_epi32(scaled);
+        }
+
+        /**
+* @brief Quantize four double values into signed 32-bit integers.
+*
+* @param input source double lane values.
+* @param scale per-lane scale factor.
+* @return __m128i quantized values in the low lanes.
+*/
+__always_inline __m128i mm256_quantize_pd_epi32(const __m256d &input, const __m256d &scale) {
+            const __m256d scaled = _mm256_mul_pd(input, scale);
+            return _mm256_cvtpd_epi32(scaled);
+        }
+
+#ifndef PERNIX_USE_SIMDE
+        /**
+* @brief Emulate per-16-bit left shifts on AVX2.
+*
+* Each 32-bit lane is processed twice with the 32-bit variable shift: once for the
+* low 16-bit half and once for the high half; the halves are then recombined.
+*
+* @param a source values.
+* @param count per-lane shift amounts.
+* @return __m128i shifted values.
+*/
+__always_inline static __m128i _mm_sllv_epi16(const __m128i a, const __m128i count) {
+            const __m128i mask = _mm_set1_epi32(0xffff0000);
+            // Low half: shift the whole lane by the low count; spill into the high half is discarded by the blend.
+            const __m128i low_half = _mm_sllv_epi32(a, _mm_andnot_si128(mask, count));
+            // High half: isolate the high 16 bits and shift by the high count.
+            const __m128i high_half = _mm_sllv_epi32(_mm_and_si128(mask, a), _mm_srli_epi32(count, 16));
+            // 0xaa takes the odd 16-bit lanes from high_half and the even ones from low_half.
+            return _mm_blend_epi16(low_half, high_half, 0xaa);
+        }
+
+        /**
+* @brief Emulate per-16-bit right shifts on AVX2.
+*
+* @param a source values.
+* @param count per-lane shift amounts.
+* @return __m128i shifted values.
+*/
+__always_inline static __m128i _mm_srlv_epi16(const __m128i a, const __m128i count) {
+            const __m128i mask = _mm_set1_epi32(0x0000ffff);
+            // Low half: clear the high 16 bits first so they cannot shift down into the result.
+            const __m128i low_half = _mm_srlv_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, count));
+            const __m128i high_half = _mm_srlv_epi32(a, _mm_srli_epi32(count, 16));
+            return _mm_blend_epi16(low_half, high_half, 0xaa);
+        }
+
+        /**
+* @brief Emulate per-16-bit left shifts on 256-bit AVX2 vectors.
+*
+* @param a source values.
+* @param count per-lane shift amounts.
+* @return __m256i shifted values.
+*/
+__always_inline static __m256i _mm256_sllv_epi16(const __m256i a, const __m256i count) {
+            const __m256i mask = _mm256_set1_epi32(0xffff0000);
+            const __m256i low_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count));
+            const __m256i high_half = _mm256_sllv_epi32(_mm256_and_si256(mask, a), _mm256_srli_epi32(count, 16));
+            return _mm256_blend_epi16(low_half, high_half, 0xaa);
+        }
+
+        /**
+* @brief Emulate per-16-bit right shifts on 256-bit AVX2 vectors.
+*/
+__always_inline static __m256i _mm256_srlv_epi16(const __m256i a, const __m256i count) {
+            const __m256i mask = _mm256_set1_epi32(0x0000ffff);
+            // Low half: clear the high 16 bits first so they cannot shift down into the result.
+            const __m256i low_half = _mm256_srlv_epi32(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count));
+            const __m256i high_half = _mm256_srlv_epi32(a, _mm256_srli_epi32(count, 16));
+            return _mm256_blend_epi16(low_half, high_half, 0xaa);
+        }
+
+        /**
+* @brief Blend 8-bit lanes by expanding a scalar mask value.
+*
+* Bit (i % 8) of M selects the source of byte i: 0 takes X, 1 takes Y. This mirrors
+* how _mm_blend_epi16 interprets its immediate, so 0xaa picks the odd bytes from Y.
+*
+* @param X bytes used where the mask bit is clear.
+* @param Y bytes used where the mask bit is set.
+* @param M 8-bit selection pattern, repeated for every group of eight bytes.
+* @return __m128i blended vector.
+*/
+__always_inline static __m128i mm_blend_epi8(const __m128i X, const __m128i Y, const i8 M) {
+            // _mm_blendv_epi8 only looks at the top bit of each mask byte, so broadcasting M
+            // directly would select the same source for every byte. Expand each bit of M
+            // into a full 0x00 / 0xFF byte instead.
+            const __m128i lane_bit = _mm_setr_epi8(1, 2, 4, 8, 16, 32, 64, -128,
+                                                   1, 2, 4, 8, 16, 32, 64, -128);
+            const __m128i select = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(M), lane_bit), lane_bit);
+            return _mm_blendv_epi8(X, Y, select);
+        }
+
+        /**
+* @brief Blend 8-bit lanes in 256-bit vectors by expanding a scalar mask value.
+*
+* Bit (i % 8) of M selects the source of byte i: 0 takes X, 1 takes Y.
+*
+* @param X bytes used where the mask bit is clear.
+* @param Y bytes used where the mask bit is set.
+* @param M 8-bit selection pattern, repeated for every group of eight bytes.
+* @return __m256i blended vector.
+*/
+__always_inline static __m256i mm256_blend_epi8(const __m256i X, const __m256i Y, const i8 M) {
+            // Same bit-to-byte expansion as the 128-bit helper.
+            const __m256i lane_bit = _mm256_setr_epi8(1, 2, 4, 8, 16, 32, 64, -128,
+                                                      1, 2, 4, 8, 16, 32, 64, -128,
+                                                      1, 2, 4, 8, 16, 32, 64, -128,
+                                                      1, 2, 4, 8, 16, 32, 64, -128);
+            const __m256i select = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_set1_epi8(M), lane_bit), lane_bit);
+            return _mm256_blendv_epi8(X, Y, select);
+        }
+
+        /**
+* @brief Emulate per-byte left shifts on 128-bit vectors.
+*
+* @param a source values.
+* @param count per-byte shift amounts.
+* @return __m128i shifted values.
+*/
+__always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i count) {
+            const __m128i mask = _mm_set1_epi16(0xff00);
+            // Even bytes: shift the 16-bit lane by the low-byte count; the spill into the odd byte is discarded.
+            const __m128i low_half = _mm_sllv_epi16(a, _mm_andnot_si128(mask, count));
+            // Odd bytes: isolate the high byte and shift it by the high-byte count.
+            const __m128i high_half = _mm_sllv_epi16(_mm_and_si128(mask, a), _mm_srli_epi16(count, 8));
+            return mm_blend_epi8(low_half, high_half, 0xaa);
+        }
+
+        /**
+* @brief Emulate per-byte right shifts on 128-bit vectors.
+*
+* @param a source values.
+* @param count per-byte shift amounts.
+* @return __m128i shifted values.
+*/
+__always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i count) {
+            const __m128i mask = _mm_set1_epi16(0x00ff);
+            // Even bytes: clear the odd byte first so it cannot shift down into the result.
+            const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count));
+            const __m128i high_half = _mm_srlv_epi16(a, _mm_srli_epi16(count, 8));
+            return mm_blend_epi8(low_half, high_half, 0xaa);
+        }
+
+        /**
+* @brief Emulate per-byte left shifts on 256-bit vectors.
+*/
+__always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i count) {
+            const __m256i mask = _mm256_set1_epi16(0xff00);
+            // Even bytes: shift the 16-bit lane by the low-byte count.
+            const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count));
+            // Odd bytes: isolate the high byte and shift it by the high-byte count.
+            const __m256i high_half = _mm256_sllv_epi16(_mm256_and_si256(mask, a), _mm256_srli_epi16(count, 8));
+            // NOTE(review): relies on mm256_blend_epi8(…, 0xaa) taking only the odd bytes from high_half — confirm the helper does that.
+            return mm256_blend_epi8(low_half, high_half, 0xaa);
+        }
+
+        /**
+* @brief Emulate per-byte right shifts on 256-bit vectors.
+*
+* @param a source values.
+* @param count per-byte shift amounts.
+* @return __m256i shifted values.
+*/
+__always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i count) {
+            const __m256i mask = _mm256_set1_epi16(0x00ff);
+            const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count));
+            const __m256i high_half = _mm256_srlv_epi16(a, _mm256_srli_epi16(count, 8));
+            // NOTE(review): relies on mm256_blend_epi8(…, 0xaa) taking only the odd bytes from high_half — confirm the helper does that.
+            return mm256_blend_epi8(low_half, high_half, 0xaa);
+        }
+#endif
+
+        /**
+* @brief Pack four 32-bit values for bit widths 1 through 3.
+*
+* The four lanes are masked to BIT_WIDTH bits, spilled to a scalar array and
+* combined with scalar shifts; lane 0 ends up in the least-significant bits.
+*
+* @param input four 32-bit lanes to pack.
+* @return __m128i with the packed bits in the low 32-bit lane and the rest zeroed.
+*/
+        template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3)
+        __always_inline auto mm_pack_epi32_avx2_1to3(__m128i &input) -> __m128i {
+            constexpr u32 bitmask = (1U << BIT_WIDTH) - 1U;
+            const __m128i masked = _mm_and_si128(input, _mm_set1_epi32(static_cast(bitmask)));
+
+            alignas(16) u32 lanes[4];
+            _mm_storeu_si128(reinterpret_cast<__m128i *>(lanes), masked);
+
+            const u32 packed = (lanes[0] & bitmask) | ((lanes[1] & bitmask) << BIT_WIDTH) | (
+                                   (lanes[2] & bitmask) << (2 * BIT_WIDTH)) |
+                               ((lanes[3] & bitmask) << (3 * BIT_WIDTH));
+
+            return _mm_cvtsi32_si128(static_cast(packed));
+        }
+
+        /**
+* @brief Pack eight 32-bit values for bit widths 1 through 3.
+*/
+        template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 3)
+__always_inline __m256i mm256_pack_epi32_avx2_1to3(const __m256i &input) {
+            constexpr u32 bitmask = (1u << BIT_WIDTH) - 1u;
+
+            const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(static_cast(bitmask)));
+
+            // Lane i is moved to bit position i * BIT_WIDTH (at most 7 * 3 = 21, so it stays inside 32 bits).
+            const __m256i shifts = _mm256_setr_epi32(0 * BIT_WIDTH, 1 * BIT_WIDTH, 2 * BIT_WIDTH, 3 * BIT_WIDTH,
+                                                     4 * BIT_WIDTH, 5 * BIT_WIDTH,
+                                                     6 * BIT_WIDTH, 7 * BIT_WIDTH);
+
+            const __m256i shifted = _mm256_sllv_epi32(masked, shifts);
+
+            // Horizontal OR: fold the upper 128 bits onto the lower, then 64 bits, then 32 bits.
+            __m128i x = _mm_or_si128(_mm256_castsi256_si128(shifted), _mm256_extracti128_si256(shifted, 1));
+
+            x = _mm_or_si128(x, _mm_srli_si128(x, 8));
+            x = _mm_or_si128(x, _mm_srli_si128(x, 4));
+
+            // The packed result is in the lowest 32-bit lane; higher lanes hold partial ORs.
+            return _mm256_castsi128_si256(x);
+        }
+
+        /**
+* @brief Pack eight 32-bit values into eight 4-bit nibbles (four bytes).
+*
+* Each lane is reduced to its low four bits, so negative inputs keep their
+* two's complement nibble. Lane 0 ends up in the low nibble of byte 0.
+*
+* @param input eight 32-bit lanes to pack.
+* @return __m256i with the four packed bytes in bytes 0..3 of the low 128-bit lane.
+*/
+__always_inline __m256i mm256_pack_epi32_avx2_4(const __m256i &input) {
+            const __m256i zero = _mm256_setzero_si256();
+
+            // Mask first: the unsigned-saturating packs below would turn negative lanes into 0
+            // and lanes above 15 would bleed into the neighbouring nibble.
+            const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(0x0F));
+
+            const __m256i packed16 = _mm256_packus_epi32(masked, zero);
+            // packus works per 128-bit lane; bring both groups of four values into the low lane.
+            const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0));
+            const __m256i packed8 = _mm256_packus_epi16(permuted, zero);
+
+            // Within each byte pair, move the odd byte's nibble next to the even byte's nibble.
+            const __m256i combined = _mm256_or_si256(packed8, _mm256_srli_epi16(packed8, 4));
+
+            // Keep the even bytes (which now hold two nibbles each) and zero the rest.
+            const __m256i shuffled = _mm256_shuffle_epi8(combined, _mm256_setr_epi8(
+                                                             0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1,
+                                                             0, 2,
+                                                             4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1));
+
+            return shuffled;
+        }
+
+        /**
+* @brief Pack four 32-bit values for bit widths 9 through 16.
+*/
+        template
+        requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16)
+        __always_inline auto mm_pack_epi32_avx2_9to16(__m128i &input) -> __m128i {
+            // Shuffle and shift patterns come from the pack_tables_avx2_16 lookup tables (defined elsewhere).
+            using tables = pack_tables_avx2_16;
+            constexpr u16 bitmask = (1 << BIT_WIDTH) - 1;
+            const __m128i masked = _mm_and_si128(input, _mm_set1_epi16(bitmask));
+            __m128i combined;
+
+            // NOTE(review): the 256-bit sibling also routes BIT_WIDTH == 16 through the two-table path — confirm which is intended.
+            if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) {
+                // Two-table path: one left-shifted and one right-shifted contribution per output word.
+                const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1());
+                const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2());
+
+                const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1());
+                const __m128i shifted2 = _mm_srlv_epi16(shuffled2, tables::get_shift2());
+
+                combined = _mm_or_si128(shifted1, shifted2);
+            } else {
+                // Three-table path: two left-shifted and one right-shifted contribution per output word.
+                const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1());
+                const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2());
+                const __m128i shuffled3 = _mm_shuffle_epi8(masked, tables::get_permute3());
+
+                const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1());
+                const __m128i shifted2 = _mm_sllv_epi16(shuffled2, tables::get_shift2());
+                const __m128i shifted3 = _mm_srlv_epi16(shuffled3, tables::get_shift3());
+
+                combined = _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3);
+            }
+            return combined;
+        }
+
+        /**
+* @brief Pack eight 32-bit values for bit widths 9 through 16.
+*/
+        template
+        requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16)
+        __always_inline auto mm256_pack_epi32_avx2_9to16(const __m256i &input) -> __m256i {
+            // Shuffle and shift patterns come from the pack_tables_avx2_16 lookup tables (defined elsewhere).
+            using tables = pack_tables_avx2_16;
+            constexpr u16 bitmask = (1 << BIT_WIDTH) - 1;
+            // Narrow the eight 32-bit lanes to eight 16-bit lanes (signed saturation), then keep BIT_WIDTH bits.
+            const __m128i packed = _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1));
+            const __m128i masked = _mm_and_si128(packed, _mm_set1_epi16(bitmask));
+            __m128i combined;
+
+            if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15 || BIT_WIDTH == 16) {
+                // Two-table path: one left-shifted and one right-shifted contribution per output word.
+                const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1());
+                const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2());
+
+                const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1());
+                const __m128i shifted2 = _mm_srlv_epi16(shuffled2, tables::get_shift2());
+
+                combined = _mm_or_si128(shifted1, shifted2);
+            } else {
+                // Three-table path: two left-shifted and one right-shifted contribution per output word.
+                const __m128i shuffled1 = _mm_shuffle_epi8(masked, tables::get_permute1());
+                const __m128i shuffled2 = _mm_shuffle_epi8(masked, tables::get_permute2());
+                const __m128i shuffled3 = _mm_shuffle_epi8(masked, tables::get_permute3());
+
+                const __m128i shifted1 = _mm_sllv_epi16(shuffled1, tables::get_shift1());
+                const __m128i shifted2 = _mm_sllv_epi16(shuffled2, tables::get_shift2());
+                const __m128i shifted3 = _mm_srlv_epi16(shuffled3, tables::get_shift3());
+
+                combined = _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3);
+            }
+            return _mm256_castsi128_si256(combined);
+        }
+
+        /**
+* @brief Pack eight 32-bit values for bit widths 5 through 7.
+*
+* Neighbouring values are first fused into four fields of 2 * BIT_WIDTH bits
+* (even value in the low bits), which are then handed to the 9-to-16-bit packer.
+* Each lane is reduced to its low BIT_WIDTH bits first, so negative inputs keep
+* their two's complement bit pattern.
+*
+* @param input eight 32-bit lanes to pack.
+* @return __m256i whose low bytes hold the packed output.
+*/
+        template
+        requires(BIT_WIDTH >= 5 && BIT_WIDTH <= 7)
+        __always_inline auto mm256_pack_epi32_avx2_5to7(const __m256i &input) -> __m256i {
+            const __m256i zero = _mm256_setzero_si256();
+
+            // Mask first: the unsigned-saturating packs below would turn negative lanes into 0
+            // and out-of-range lanes would overlap their neighbour when the pairs are fused.
+            const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32((1 << BIT_WIDTH) - 1));
+
+            const __m256i packed16 = _mm256_packus_epi32(masked, zero);
+            // packus works per 128-bit lane; bring both groups of four values into the low lane.
+            const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0));
+            const __m256i packed8 = _mm256_packus_epi16(permuted, zero);
+
+            const __m256i even = _mm256_and_si256(packed8, _mm256_set1_epi16(0x00FF));
+            const __m256i odd =
+                _mm256_and_si256(packed8, _mm256_set1_epi16(0xFF00));
+
+            // The odd value sits at bit 8; shifting right by 8 - BIT_WIDTH places it directly above the even value.
+            const __m256i pair16 = _mm256_or_si256(even, _mm256_srli_epi16(odd, 8 - BIT_WIDTH));
+            const __m256i extended = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(pair16));
+
+            return mm256_pack_epi32_avx2_9to16<2 * BIT_WIDTH>(extended);
+        }
+
+        /**
+* @brief Pack eight 32-bit values for bit widths 17 through 24.
+*
+* @param input eight 32-bit lanes to pack.
+* @return __m256i packed output as produced by the table-driven permute/shift steps.
+*/
+        template
+        requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24)
+        __always_inline auto mm256_pack_epi32_avx2_17to24(const __m256i &input) -> __m256i {
+            // Permute and shift patterns come from the pack_tables_avx2_24 lookup tables (defined elsewhere).
+            using tables = pack_tables_avx2_24;
+            constexpr u32 bitmask = (1 << BIT_WIDTH) - 1;
+            const __m256i masked = _mm256_and_si256(input, _mm256_set1_epi32(bitmask));
+            __m256i combined;
+
+            if constexpr (BIT_WIDTH == 24) {
+                // Two-table path: one left-shifted and one right-shifted contribution per output word.
+                const __m256i shuffled1 = _mm256_permutevar8x32_epi32(masked, tables::get_permute1());
+                const __m256i shuffled2 = _mm256_permutevar8x32_epi32(masked, tables::get_permute2());
+
+                const __m256i shifted1 = _mm256_sllv_epi32(shuffled1, tables::get_shift1());
+                const __m256i shifted2 = _mm256_srlv_epi32(shuffled2, tables::get_shift2());
+
+                combined = _mm256_or_si256(shifted1, shifted2);
+            } else {
+                // Three-table path: two left-shifted and one right-shifted contribution per output word.
+                const __m256i shuffled1 = _mm256_permutevar8x32_epi32(masked, tables::get_permute1());
+                const __m256i shuffled2 = _mm256_permutevar8x32_epi32(masked, tables::get_permute2());
+                const __m256i shuffled3 = _mm256_permutevar8x32_epi32(masked, tables::get_permute3());
+
+                const __m256i shifted1 = _mm256_sllv_epi32(shuffled1, tables::get_shift1());
+                const __m256i shifted2 = _mm256_sllv_epi32(shuffled2, tables::get_shift2());
+                const __m256i shifted3 = _mm256_srlv_epi32(shuffled3, tables::get_shift3());
+
+                combined = _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3);
+            }
+
+            return combined;
+        }
+
+        /**
+* @brief Pack aligned 8-bit or 16-bit values from four 32-bit lanes.
+*/
+        template
+        requires(BIT_WIDTH == 8 || BIT_WIDTH == 16)
+        auto mm_pack_aligned_epi32_avx2(__m128i &input) -> __m128i {
+            if constexpr (BIT_WIDTH == 8) {
+                // NOTE(review): the final step saturates unsigned (negative values become 0), while the
+                // 256-bit sibling uses signed saturation (_mm_packs_epi16) — confirm which is intended.
+                return _mm_packus_epi16(_mm_packs_epi32(input, _mm_setzero_si128()), _mm_setzero_si128());
+            } else {
+                return _mm_packs_epi32(input, _mm_setzero_si128());
+            }
+        }
+
+        /**
+* @brief Dispatch to the appropriate 128-bit AVX2 packer for the selected bit width.
+*
+* Only widths 1-3 and 9-16 are implemented; widths 4-8 currently return an all-zero vector.
+*
+* @param input four 32-bit lanes to pack.
+* @return __m128i packed output (all zero for the unimplemented widths).
+*/
+        template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 16)
+        auto mm_pack_epi32_avx2(__m128i &input) -> __m128i {
+            if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 3) {
+                return internal::mm_pack_epi32_avx2_1to3(input);
+            } else if constexpr (BIT_WIDTH >= 4 && BIT_WIDTH <= 8) {
+                // TODO: implementation for 4-8 bits
+                return _mm_setzero_si128();
+            } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) {
+                return internal::mm_pack_epi32_avx2_9to16(input);
+            } else {
+                // Not reachable under the requires-clause above.
+                return _mm_setzero_si128();
+            }
+        }
+
+        /**
+* @brief Pack aligned 8-bit or 16-bit values from eight 32-bit lanes.
+*
+* Uses signed saturation at every narrowing step.
+*
+* @param input eight 32-bit lanes to pack.
+* @return __m256i with the packed values in the low 128-bit lane.
+*/
+        template
+        requires(BIT_WIDTH == 8 || BIT_WIDTH == 16)
+        __m256i mm256_pack_aligned_epi32_avx2(const __m256i &input) {
+            if constexpr (BIT_WIDTH == 8) {
+                const __m128i packed16 = _mm_packs_epi32(_mm256_castsi256_si128(input),
+                                                         _mm256_extracti128_si256(input, 1));
+                const __m128i packed8 = _mm_packs_epi16(packed16, _mm_setzero_si128());
+                return _mm256_castsi128_si256(packed8);
+            } else {
+                return _mm256_castsi128_si256(
+                    _mm_packs_epi32(_mm256_castsi256_si128(input), _mm256_extracti128_si256(input, 1)));
+            }
+        }
+
+        /**
+* @brief Dispatch to the appropriate 256-bit AVX2 packer for the selected bit width.
+*/
+        template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24)
+        __m256i mm256_pack_epi32_avx2(const __m256i &input) {
+            // Byte-aligned widths use the dedicated saturating packers.
+            if constexpr (BIT_WIDTH == 8 || BIT_WIDTH == 16) {
+                return internal::mm256_pack_aligned_epi32_avx2(input);
+            } else {
+                // Remaining widths are split into ranges, each with its own packer.
+                if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 3) {
+                    return internal::mm256_pack_epi32_avx2_1to3(input);
+                } else if constexpr (BIT_WIDTH == 4) {
+                    return mm256_pack_epi32_avx2_4(input);
+                } else if constexpr (BIT_WIDTH >= 5 && BIT_WIDTH <= 7) {
+                    return internal::mm256_pack_epi32_avx2_5to7(input);
+                } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 15) {
+                    return internal::mm256_pack_epi32_avx2_9to16(input);
+                } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) {
+                    return internal::mm256_pack_epi32_avx2_17to24(input);
+                }
+            }
+            // Every width allowed by the requires-clause returns above; this is a fallthrough guard.
+            return _mm256_setzero_si256();
+        }
+    } // namespace internal
+
+    /**
+* @brief Compress a single block of float using AVX2 instructions.
+*
+* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+* @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32).
+*
+* @param input pointer to the start of the input float values.
+* @param scale scaling factor used during quantization.
+* @param output pointer to the output buffer where compressed bytes will be stored.
+* @return int status code (0 for success).
+*
+* @note This function requires AVX2 support.
+*/
+    template
+    requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0)
+    int mm256_compress_block_avx2(const void * __restrict__ input_ptr, const f32 scale,
+                                  void * __restrict__ output_ptr) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        // Values per block, number of full groups of eight, and leftover values after those groups.
+        constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH;
+        constexpr u32 iterations_8 = elements_per_block / 8;
+        constexpr u8 remaining = elements_per_block - iterations_8 * 8;
+
+        // Zero the whole block so bytes that are not written below stay 0.
+        std::memset(output, 0, BLOCK_SIZE);
+
+        const __m256 scale_v = _mm256_set1_ps(scale);
+#pragma GCC unroll 8
+        for (u32 iter = 0; iter < iterations_8; iter++) {
+            const __m256 source = _mm256_loadu_ps(input);
+            const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v);
+            const __m256i packed_input = internal::mm256_clamp_signed_epi32(quantized);
+            const __m256i packed = internal::mm256_pack_epi32_avx2(packed_input);
+            // Eight values of BIT_WIDTH bits occupy exactly BIT_WIDTH bytes.
+            std::memcpy(output, &packed, BIT_WIDTH);
+
+            input += 8;
+            output += BIT_WIDTH;
+        }
+
+        // Scalar tail for the values that do not fill a group of eight.
+        if constexpr (remaining) {
+            std::vector block_values(remaining);
+#pragma GCC unroll 8
+            for (u32 i = 0; i < remaining; i++) {
+                block_values[i] =
+                        static_cast(internal::clamp_signed_quantized < BIT_WIDTH > (
+                            internal::quantize_ps_epi32(input[i], scale)));
+            }
+
+            internal::pack_epi32_fallback < BIT_WIDTH > (block_values, output);
+        }
+
+        return 0;
+    }
+
+    /**
+* @brief Compress a single block of double using AVX2 instructions.
+*
+* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+* @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32).
+*
+* @param input pointer to the start of the input double values.
+* @param scale scaling factor used during quantization.
+* @param output pointer to the output buffer where compressed bytes will be stored.
+* @return int status code (0 for success).
+*
+* @note This function requires AVX2 support.
+*/
+    template
+    requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0)
+    int mm256_compress_block_avx2(const void * __restrict__ input_ptr, const f64 scale,
+                                  void * __restrict__ output_ptr) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        // Values per block, number of full groups of eight, and leftover values after those groups.
+        constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH;
+        constexpr u32 iterations_8 = elements_per_block / 8;
+        constexpr u8 remaining = elements_per_block - iterations_8 * 8;
+
+        // Zero the whole block so bytes that are not written below stay 0.
+        std::memset(output, 0, BLOCK_SIZE);
+
+        const __m256d scale_v = _mm256_set1_pd(scale);
+#pragma GCC unroll 8
+        for (u32 iter = 0; iter < iterations_8; iter++) {
+            // Two loads of four doubles each make up one group of eight values.
+            const __m256d source1 = _mm256_loadu_pd(input);
+            const __m256d source2 = _mm256_loadu_pd(input + 4);
+            const __m128i quantized1 = internal::mm256_quantize_pd_epi32(source1, scale_v);
+            const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v);
+            // Join the two 128-bit halves into one vector of eight 32-bit lanes.
+            __m256i combined = _mm256_castsi128_si256(quantized1);
+            combined = _mm256_inserti128_si256(combined, quantized2, 1);
+            const __m256i packed = internal::mm256_pack_epi32_avx2(
+                internal::mm256_clamp_signed_epi32(combined));
+            // _mm_storeu_si128(reinterpret_cast<__m128i*>(output), _mm256_castsi256_si128(packed));
+            // Eight values of BIT_WIDTH bits occupy exactly BIT_WIDTH bytes.
+            std::memcpy(output, &packed, BIT_WIDTH);
+            input += 8;
+            output += BIT_WIDTH;
+        }
+
+        // Scalar tail for the values that do not fill a group of eight.
+        if constexpr (remaining) {
+            std::vector block_values(remaining);
+#pragma GCC unroll 8
+            for (u32 i = 0; i < remaining; i++) {
+                block_values[i] =
+                        static_cast(internal::clamp_signed_quantized < BIT_WIDTH > (
+                            internal::quantize_pd_epi64(input[i], scale)));
+            }
+
+            internal::pack_epi32_fallback < BIT_WIDTH > (block_values, output);
+        }
+
+        return 0;
+    }
+
+    /**
+* @brief Compress multiple blocks using AVX2 instructions.
+*
+* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+* @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32).
+*
+* @param input pointer to the start of the input float values.
+* @param scale scaling factor used during quantization.
+* @param output pointer to the output buffer where compressed bytes will be stored.
+* @param blocks number of blocks to compress.
+* @return int status code (0 for success).
+*
+* @note This function requires AVX2 support.
+*/
+    template
+    requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0)
+    int mm256_compress_blocks_avx2(const void * __restrict__ input_ptr, const f32 scale,
+                                   void * __restrict__ output_ptr,
+                                   const u32 blocks) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        const f32 *block_input = input;
+        u8 *block_output = output;
+
+        for (u32 block = 0; block < blocks; block++) {
+            // The per-block status is not inspected here.
+            mm256_compress_block_avx2(block_input, scale, block_output);
+            // Each block consumes (BLOCK_SIZE * 8) / BIT_WIDTH input values and emits BLOCK_SIZE bytes.
+            block_input += (BLOCK_SIZE * 8) / BIT_WIDTH;
+            block_output += BLOCK_SIZE;
+        }
+
+        return 0;
+    }
+
+    /**
+* @brief Compress multiple blocks using AVX2 instructions.
+*
+* @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+* @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32).
+*
+* @param input pointer to the start of the input double values.
+* @param scale scaling factor used during quantization.
+* @param output pointer to the output buffer where compressed bytes will be stored.
+* @param blocks number of blocks to compress.
+* @return int status code (0 for success).
+*
+* @note This function requires AVX2 support.
+*/
+    template
+    requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0)
+    int mm256_compress_blocks_avx2(const void * __restrict__ input_ptr, const f64 scale,
+                                   void * __restrict__ output_ptr,
+                                   const u32 blocks) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        const f64 *block_input = input;
+        u8 *block_output = output;
+
+        for (u32 block = 0; block < blocks; block++) {
+            // The per-block status is not inspected here.
+            mm256_compress_block_avx2(block_input, scale, block_output);
+            // Each block consumes (BLOCK_SIZE * 8) / BIT_WIDTH input values and emits BLOCK_SIZE bytes.
+            block_input += (BLOCK_SIZE * 8) / BIT_WIDTH;
+            block_output += BLOCK_SIZE;
+        }
+
+        return 0;
+    }
+} // namespace pernix
+
+#endif // PERNIX_AVX2_COMPRESSION_H
diff --git a/src/internal/pernix/x86/avx2/avx2_decompression.h b/src/internal/pernix/x86/avx2/avx2_decompression.h
new file mode 100644
index 0000000..2f50018
--- /dev/null
+++ b/src/internal/pernix/x86/avx2/avx2_decompression.h
@@ -0,0 +1,349 @@
+#ifndef PERNIX_AVX2_DECOMPRESSION_H
+#define PERNIX_AVX2_DECOMPRESSION_H
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace pernix {
+    namespace internal {
+        /**
+         * @brief Convert an 8-lane mask to the lane representation used by AVX2 float masked stores.
+         *
+         * Bit i of mask8 becomes an all-ones (-1) or all-zero 32-bit lane i.
+         */
+__always_inline __m256i mm256_convert_vmask_epi32(const __mmask8 mask8) {
+            return _mm256_setr_epi32((mask8 & 0x1) ? -1 : 0, (mask8 & 0x2) ? -1 : 0, (mask8 & 0x4) ? -1 : 0,
+                                     (mask8 & 0x8) ? -1 : 0,
+                                     (mask8 & 0x10) ? -1 : 0, (mask8 & 0x20) ? -1 : 0, (mask8 & 0x40) ? -1 : 0,
+                                     (mask8 & 0x80) ? -1 : 0);
+        }
+
+        /**
+         * @brief Convert a 4-lane mask to the lane representation used by AVX2 double masked stores.
+         *
+         * Only the low four bits of mask8 are used; bit i becomes an all-ones or all-zero 64-bit lane i.
+         */
+__always_inline __m256i mm256_convert_vmask_epi64(const __mmask8 mask8) {
+            return _mm256_setr_epi64x((mask8 & 0x1) ? -1 : 0, (mask8 & 0x2) ? -1 : 0, (mask8 & 0x4) ? -1 : 0,
+                                      (mask8 & 0x8) ? -1 : 0);
+        }
+
+        /**
+         * @brief Dequantize four 32-bit integers to floats.
+         */
+__always_inline __m128 mm_dequantize_epi32(const __m128i &input, const __m128 &scale) {
+            // Convert to float first, then apply the scale.
+            const __m128 converted = _mm_cvtepi32_ps(input);
+            return _mm_mul_ps(converted, scale);
+        }
+
+        /* https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx */
+        /* Converts two signed 64-bit integers to doubles using the magic-constant technique from the link above. */
+__always_inline __m128d convert_epi64_pd(const __m128i v) {
+            __m128i xH = _mm_srai_epi32(v, 16);
+            xH = _mm_blend_epi16(xH, _mm_setzero_si128(), 0x33);
+            xH = _mm_add_epi64(xH, _mm_castpd_si128(_mm_set1_pd(442721857769029238784.))); // 3*2^67
+            const __m128i xL = _mm_blend_epi16(v, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0x88); // 2^52
+            const __m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(442726361368656609280.)); // 3*2^67 + 2^52
+            return _mm_add_pd(f, _mm_castsi128_pd(xL));
+        }
+
+        /**
+         * @brief Dequantize two 64-bit integers to doubles.
+         *
+         * @param input two signed 64-bit lanes.
+         * @param scale per-lane scale factor.
+         * @return __m128d converted values multiplied by the scale.
+         */
+__always_inline __m128d mm_dequantize_epi64_pd(const __m128i &input, const __m128d &scale) {
+            const __m128d converted = convert_epi64_pd(input);
+            return _mm_mul_pd(converted, scale);
+        }
+
+        /**
+         * @brief Dequantize eight 32-bit integers to floats.
+         */
+__always_inline __m256 mm256_dequantize_epi32(const __m256i &input, const __m256 &scale) {
+            // Convert to float first, then apply the scale.
+            const __m256 converted = _mm256_cvtepi32_ps(input);
+            return _mm256_mul_ps(converted, scale);
+        }
+
+        /* https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx */
+        /* 256-bit variant: converts four signed 64-bit integers to doubles with the same magic constants. */
+__always_inline __m256d convert_epi64_pd(__m256i v) {
+            __m256i xH = _mm256_srai_epi32(v, 16);
+            xH = _mm256_blend_epi16(xH, _mm256_setzero_si256(), 0x33);
+            xH = _mm256_add_epi64(xH, _mm256_castpd_si256(_mm256_set1_pd(442721857769029238784.)));
+            const __m256i xL = _mm256_blend_epi16(v, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)), 0x88);
+            const __m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(442726361368656609280.));
+            return _mm256_add_pd(f, _mm256_castsi256_pd(xL));
+        }
+
+        /**
+         * @brief Dequantize four 64-bit integers to doubles.
+         *
+         * @param input four signed 64-bit lanes.
+         * @param scale per-lane scale factor.
+         * @return __m256d converted values multiplied by the scale.
+         */
+__always_inline __m256d mm256_dequantize_epi64_pd(const __m256i &input, const __m256d &scale) {
+            const __m256d converted = convert_epi64_pd(input);
+            return _mm256_mul_pd(converted, scale);
+        }
+
+        /**
+         * @brief Unpack four aligned 8-bit or 16-bit values directly from the input buffer.
+         *
+         * Loads 4 bytes (8-bit case) or 8 bytes (16-bit case) and widens each value to
+         * 32 bits, sign- or zero-extending according to SIGN_VALUES.
+         */
+        template
+        requires(BIT_WIDTH == 8 || BIT_WIDTH == 16)
+        __m128i mm_unpack_aligned_epi32_avx2(const u8 * __restrict__ input) {
+            if constexpr (BIT_WIDTH == 8) {
+                const __m128i source = _mm_loadu_si32(input);
+                if constexpr (SIGN_VALUES) {
+                    return _mm_cvtepi8_epi32(source);
+                } else {
+                    return _mm_cvtepu8_epi32(source);
+                }
+            } else {
+                const __m128i source = _mm_loadu_si64(input);
+                if constexpr (SIGN_VALUES) {
+                    return _mm_cvtepi16_epi32(source);
+                } else {
+                    return _mm_cvtepu16_epi32(source);
+                }
+            }
+        }
+
+        /**
+         * @brief Unpack four values using the table-driven AVX2 shuffle path.
+ */ + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) + __m128i mm_unpack_epi32_avx2(const u8 * __restrict__ input) { + using unpack_table = unpack_tables_avx2; + constexpr std::size_t packed_bytes = (4 * BIT_WIDTH + 7) / 8; /* bytes spanned by four packed values, rounded up */ + + __m128i source = _mm_setzero_si128(); + std::memcpy(&source, input, packed_bytes); /* read only the packed bytes; the rest of the register stays zero */ + + const __m128i shuffled = _mm_shuffle_epi8(source, unpack_table::get_shuffle()); + + constexpr u16 shift = 32 - BIT_WIDTH; + __m128i shifted = _mm_sllv_epi32(shuffled, unpack_table::get_shift()); + if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { + shifted = _mm_srai_epi32(shifted, shift); /* arithmetic shift: sign-extend */ + } else { + shifted = _mm_srli_epi32(shifted, shift); /* logical shift: zero-extend */ + } + + return shifted; + } + + /** + * @brief Unpack eight byte-aligned (8-bit or 16-bit) values from the input buffer into 32-bit lanes, sign- or zero-extended per SIGN_VALUES; the loads are unaligned, so the pointer needs no particular alignment. + */ + template + requires(BIT_WIDTH == 8 || BIT_WIDTH == 16) + __m256i mm256_unpack_aligned_epi32_avx2(const u8 * __restrict__ input) { + if constexpr (BIT_WIDTH == 8) { + const __m128i source = _mm_loadu_si64(input); /* eight 8-bit values = 8 bytes */ + if constexpr (SIGN_VALUES) { + return _mm256_cvtepi8_epi32(source); + } else { + return _mm256_cvtepu8_epi32(source); + } + } else { + const __m128i source = _mm_loadu_si128(reinterpret_cast(input)); /* eight 16-bit values = 16 bytes */ + if constexpr (SIGN_VALUES) { + return _mm256_cvtepi16_epi32(source); + } else { + return _mm256_cvtepu16_epi32(source); + } + } + } + + /** + * @brief Unpack eight values using the table-driven AVX2 shuffle path.
+ */ + template + requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24) + __m256i mm256_unpack_epi32_avx2(const u8 * __restrict__ input) { + using unpack_table = unpack_tables_avx2; + constexpr std::size_t packed_bytes = BIT_WIDTH; /* eight values of BIT_WIDTH bits each span BIT_WIDTH bytes */ + + __m256i source = _mm256_setzero_si256(); + std::memcpy(&source, input, packed_bytes); /* read only the packed bytes; the rest of the register stays zero */ + + const __m256i permuted = _mm256_permutevar8x32_epi32(source, unpack_table::get_permute()); + const __m256i shuffled = _mm256_shuffle_epi8(permuted, unpack_table::get_shuffle()); + + constexpr u16 shift = 32 - BIT_WIDTH; + __m256i shifted = _mm256_sllv_epi32(shuffled, unpack_table::get_shift()); + if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { + shifted = _mm256_srai_epi32(shifted, shift); /* arithmetic shift: sign-extend */ + } else { + shifted = _mm256_srli_epi32(shifted, shift); /* logical shift: zero-extend */ + } + + return shifted; + } + } // namespace internal + + /** + * @brief Decompress a single block to float using AVX2 instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * + * @param input_ptr pointer to the start of the compressed block. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where decompressed float values will be stored. + * @return int status code (0 for success). + * + * @note This function requires AVX2 support.
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_block_avx2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; /* values left over after the 8-wide loop */ + + const __m256 scale_v = _mm256_set1_ps(scale); +#pragma GCC unroll 4 + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256i unpacked = internal::mm256_unpack_epi32_avx2(input); + const __m256 dequantized = internal::mm256_dequantize_epi32(unpacked, scale_v); + _mm256_storeu_ps(output, dequantized); + input += BIT_WIDTH; /* eight packed values occupy BIT_WIDTH bytes */ + output += 8; + } + + if constexpr (remaining > 0) { /* tail: leftover values go through the fallback unpack helper */ + const std::vector tail_values = internal::unpack_epi32_fallback < BIT_WIDTH, SIGN_VALUES + > + (input, remaining); + for (u32 i = 0; i < remaining; i++) { + output[i] = internal::dequantize_epi32(tail_values[i], scale); + } + } + + return 0; + } + + /** + * @brief Decompress a single block to double using AVX2 instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * + * @param input_ptr pointer to the start of the compressed block. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where decompressed double values will be stored. + * @return int status code (0 for success). + * + * @note This function requires AVX2 support.
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_block_avx2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; /* values left over after the 8-wide loop */ + const __m256d scale_v = _mm256_set1_pd(scale); +#pragma GCC unroll 4 + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256i unpacked = internal::mm256_unpack_epi32_avx2(input); + const __m256i extend1 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(unpacked)); /* lower four lanes, sign-extended to 64-bit */ + const __m256i extend2 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(unpacked, 1)); /* upper four lanes, sign-extended to 64-bit */ + + const __m256d dequantized1 = internal::mm256_dequantize_epi64_pd(extend1, scale_v); + const __m256d dequantized2 = internal::mm256_dequantize_epi64_pd(extend2, scale_v); + + _mm256_storeu_pd(output, dequantized1); + _mm256_storeu_pd(output + 4, dequantized2); + + input += BIT_WIDTH; /* eight packed values occupy BIT_WIDTH bytes */ + output += 8; + } + + if constexpr (remaining > 0) { /* tail: leftover values go through the fallback unpack helper */ + const std::vector tail_values = internal::unpack_epi32_fallback < BIT_WIDTH, SIGN_VALUES + > + (input, remaining); + for (u32 i = 0; i < remaining; i++) { + output[i] = internal::dequantize_epi64(tail_values[i], scale); + } + } + return 0; + } + + /** + * @brief Decompress multiple blocks to float using AVX2 instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * + * @param input_ptr pointer to the start of the compressed data. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where decompressed float values will be stored. + * @param blocks number of blocks to decompress. + * @return int status code (0 for success).
+ * + * @note This function requires AVX2 support. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_blocks_avx2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f32 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_decompress_block_avx2(block_input, scale, block_output); + block_input += BLOCK_SIZE; /* each compressed block is BLOCK_SIZE bytes */ + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; /* values produced per block */ + } + + return 0; + } + + /** + * @brief Decompress multiple blocks to double using AVX2 instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * + * @param input_ptr pointer to the start of the compressed data. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where decompressed double values will be stored. + * @param blocks number of blocks to decompress. + * @return int status code (0 for success). + * + * @note This function requires AVX2 support.
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_blocks_avx2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f64 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_decompress_block_avx2(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } +} // namespace pernix + +#endif // PERNIX_AVX2_DECOMPRESSION_H diff --git a/include/pernix/x86/avx2/tables.h b/src/internal/pernix/x86/avx2/avx2_tables.h similarity index 82% rename from include/pernix/x86/avx2/tables.h rename to src/internal/pernix/x86/avx2/avx2_tables.h index f4f374b..f62250f 100644 --- a/include/pernix/x86/avx2/tables.h +++ b/src/internal/pernix/x86/avx2/avx2_tables.h @@ -7,13 +7,13 @@ #include namespace pernix::internal { -template <__uint8_t BIT_WIDTH, typename T> +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && (std::is_same_v || std::is_same_v)) struct pack_tables_avx2_16 { - alignas(64) inline static constexpr std::array permute1 = [] { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 0, 1, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, @@ -25,7 +25,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 0, 1, 4, 5, -1, -1, 10, 11, 14, 15, -1, -1, @@ -37,7 +37,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, @@ -49,7 +49,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 12) { - return std::array{ + return 
std::array{ 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, @@ -61,7 +61,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 0, 1, -1, -1, 6, 7, 8, 9, 10, 11, -1, -1, @@ -73,7 +73,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 14) { - return std::array{ + return std::array{ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, @@ -85,7 +85,7 @@ struct pack_tables_avx2_16 { 14, 15, -1, -1 }; } else if constexpr (BIT_WIDTH == 15) { - return std::array{ + return std::array{ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, @@ -97,14 +97,14 @@ struct pack_tables_avx2_16 { 14, 15, -1, -1 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, @@ -116,7 +116,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, @@ -128,7 +128,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 2, 3, -1, -1, 8, 9, -1, -1, 14, 15, -1, -1, @@ -140,7 +140,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 12) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, @@ -152,7 +152,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 2, 3, 4, 5, -1, -1, -1, -1, 12, 13, 14, 15, @@ -164,7 +164,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 14) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -176,7 +176,7 @@ 
struct pack_tables_avx2_16 { 12, 13, -1, -1 }; } else if constexpr (BIT_WIDTH == 15) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -188,14 +188,14 @@ struct pack_tables_avx2_16 { 12, 13, 14, 15 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ -1, -1, 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, @@ -207,7 +207,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ -1, -1, 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, @@ -219,7 +219,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ -1, -1, 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, @@ -231,7 +231,7 @@ struct pack_tables_avx2_16 { -1, -1, -1, -1 }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ -1, -1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, @@ -243,118 +243,118 @@ struct pack_tables_avx2_16 { 14, 15, -1, -1 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 0, 2, 4, 6, 0, 1, 3, 5, 0, 2, 4, 6, 0, 1, 3, 5 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 0, 4, 0, 2, 6, 0, 4, 0, 0, 4, 0, 2, 6, 0, 4, 0 }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 0, 6, 1, 7, 2, 0, 3, 0, 0, 6, 1, 7, 2, 0, 3, 0 }; } else if constexpr (BIT_WIDTH == 12) { - return std::array{ + return std::array{ 12, 8, 4, 12, 8, 4, 12, 8, 12, 8, 4, 12, 8, 4, 12, 8 }; } else if constexpr (BIT_WIDTH == 13) 
{ - return std::array{ + return std::array{ 0, 0, 7, 4, 1, 0, 0, 5, 0, 0, 7, 4, 1, 0, 0, 5 }; } else if constexpr (BIT_WIDTH == 14) { - return std::array{ + return std::array{ 14, 12, 10, 8, 6, 4, 2, 14, 14, 12, 10, 8, 6, 4, 2, 14 }; } else if constexpr (BIT_WIDTH == 15) { - return std::array{ + return std::array{ 15, 14, 13, 12, 11, 10, 9, 8, 15, 14, 13, 12, 11, 10, 9, 8 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 9, 11, 13, 15, 8, 10, 12, 14, 9, 11, 13, 15, 8, 10, 12, 14, }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 10, 14, 8, 12, 0, 10, 14, 8, 10, 14, 8, 12, 0, 10, 14, 8, }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 11, 8, 12, 8, 13, 8, 14, 9, 11, 8, 12, 8, 13, 8, 14, 9, }; } else if constexpr (BIT_WIDTH == 12) { - return std::array{ + return std::array{ 0, 4, 8, 0, 4, 8, 0, 4, 0, 4, 8, 0, 4, 8, 0, 4, }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 13, 10, 0, 0, 14, 11, 8, 0, 13, 10, 0, 0, 14, 11, 8, 0, }; } else if constexpr (BIT_WIDTH == 14) { - return std::array{ + return std::array{ 0, 2, 4, 6, 8, 10, 12, 0, 0, 2, 4, 6, 8, 10, 12, 0, }; } else if constexpr (BIT_WIDTH == 15) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift3 = [] { + alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (BIT_WIDTH == 9) { - return std::array{ + return std::array{ 0, 7, 5, 3, 1, 8, 6, 4, 0, 7, 5, 3, 1, 8, 6, 4 }; } else if constexpr (BIT_WIDTH == 10) { - return std::array{ + return std::array{ 0, 6, 2, 8, 4, 
0, 6, 2, 0, 6, 2, 8, 4, 0, 6, 2, }; } else if constexpr (BIT_WIDTH == 11) { - return std::array{ + return std::array{ 0, 5, 10, 4, 9, 3, 8, 2, 0, 5, 10, 4, 9, 3, 8, 2, }; } else if constexpr (BIT_WIDTH == 13) { - return std::array{ + return std::array{ 0, 3, 6, 9, 12, 2, 5, 8, 0, 3, 6, 9, 12, 2, 5, 8, }; } - return std::array{}; + return std::array{}; // clang-format on }(); @@ -416,10 +416,10 @@ struct pack_tables_avx2_16 { #pragma GCC diagnostic pop }; -template <__uint8_t BIT_WIDTH, typename T> +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) struct pack_tables_avx2_24 { - alignas(64) inline static constexpr std::array permute1 = [] { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -454,11 +454,11 @@ struct pack_tables_avx2_24 { 1, 2, 3, 5, 6, 7, -1, -1, }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -493,11 +493,11 @@ struct pack_tables_avx2_24 { 0, 1, 2, 4, 5, 6, -1, -1, }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { return std::array{ @@ -528,120 +528,120 @@ struct pack_tables_avx2_24 { 0, 1, 2, 4, 5, 6, 0, 0 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { - return std::array{ + return std::array{ 0, 2, 4, 6, 8, 10, 12, 14 }; } else if constexpr (BIT_WIDTH == 18) { - return std::array{ + 
return std::array{ 0, 4, 8, 12, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 19) { - return std::array{ + return std::array{ 0, 6, 12, 18, 5, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 20) { - return std::array{ + return std::array{ 0, 8, 16, 4, 12, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 21) { - return std::array{ + return std::array{ 0, 10, 20, 9, 19, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 22) { - return std::array{ + return std::array{ 0, 12, 2, 14, 4, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 23) { - return std::array{ + return std::array{ 0, 14, 5, 19, 10, 1, 32, 32 }; } else if constexpr (BIT_WIDTH == 24) { - return std::array{ + return std::array{ 24, 16, 8, 24, 16, 8, 24, 16 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { - return std::array{ + return std::array{ 17, 19, 21, 23, 25, 27, 29, 31 }; } else if constexpr (BIT_WIDTH == 18) { - return std::array{ + return std::array{ 18, 22, 26, 30, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 19) { - return std::array{ + return std::array{ 19, 25, 31, 32, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 20) { - return std::array{ + return std::array{ 20, 28, 32, 24, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 21) { - return std::array{ + return std::array{ 21, 31, 32, 30, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 22) { - return std::array{ + return std::array{ 22, 32, 24, 32, 26, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 23) { - return std::array{ + return std::array{ 23, 32, 28, 32, 32, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 24) { - return std::array{ + return std::array{ 0, 8, 16, 0, 8, 16, 0, 8, }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift3 = [] { 
+ alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (BIT_WIDTH == 17) { - return std::array{ + return std::array{ 0, 15, 13, 11, 9, 7, 5, 3 }; } else if constexpr (BIT_WIDTH == 18) { - return std::array{ + return std::array{ 32, 14, 10, 6, 2, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 19) { - return std::array{ + return std::array{ 32, 13, 7, 1, 14, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 20) { - return std::array{ + return std::array{ 32, 12, 4, 16, 8, 32, 32, 32 }; } else if constexpr (BIT_WIDTH == 21) { - return std::array{ + return std::array{ 32, 11, 1, 12, 2, 13, 32, 32 }; } else if constexpr (BIT_WIDTH == 22) { - return std::array{ + return std::array{ 32, 10, 20, 8, 18, 6, 32, 32 }; } else if constexpr (BIT_WIDTH == 23) { - return std::array{ + return std::array{ 32, 9, 18, 4, 13, 22, 32, 32 }; } - return std::array{}; + return std::array{}; // clang-format on }(); @@ -703,10 +703,10 @@ struct pack_tables_avx2_24 { #pragma GCC diagnostic pop }; -template +template requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v)) struct unpack_tables_avx2 { - alignas(32) inline static constexpr std::array permute = [] { + alignas(32) inline static constexpr std::array permute = [] { // clang-format off if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return std::array{0, -1, -1, -1, 0, 1, -1, -1}; @@ -718,8 +718,8 @@ struct unpack_tables_avx2 { // clang-format on }(); - alignas(32) inline static constexpr std::array shuffle = [] { - std::array shuffles{}; + alignas(32) inline static constexpr std::array shuffle = [] { + std::array shuffles{}; shuffles.fill(-1); constexpr std::size_t rebase_second_half = 4 * ((BIT_WIDTH - 1) / 8); @@ -737,7 +737,7 @@ struct unpack_tables_avx2 { const std::size_t dst = (lane * 4 + i) * 4; for (std::size_t k = 0; k < byte_count; ++k) { - shuffles[dst + k] = static_cast(rel_byte_start + k); + shuffles[dst + k] = static_cast(rel_byte_start + k); } } } @@ 
-745,8 +745,8 @@ struct unpack_tables_avx2 { return shuffles; }(); - alignas(64) inline static constexpr std::array shift = [] { - std::array shifts{}; + alignas(64) inline static constexpr std::array shift = [] { + std::array shifts{}; for (std::size_t lane = 0; lane < 8; ++lane) { const int bit_offset = lane * BIT_WIDTH; @@ -760,7 +760,9 @@ struct unpack_tables_avx2 { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-attributes" - __always_inline static __m256i get_permute() { return _mm256_load_si256(reinterpret_cast(permute.data())); } + __always_inline static __m256i get_permute() { + return _mm256_load_si256(reinterpret_cast(permute.data())); + } __always_inline static T get_shuffle() { if constexpr (std::is_same_v) { @@ -779,7 +781,6 @@ struct unpack_tables_avx2 { } #pragma GCC diagnostic pop }; - -} // namespace pernix::internal +} // namespace pernix::internal #endif // PERNIX_AVX2_TABLES_H diff --git a/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h new file mode 100644 index 0000000..dab8472 --- /dev/null +++ b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_compression.h @@ -0,0 +1,735 @@ +#ifndef PERNIX_AVX512VBMI_COMPRESSION_H +#define PERNIX_AVX512VBMI_COMPRESSION_H + +#include +#include +#include +#include +#include + +#include + +using namespace pernix::x86::internal; + +namespace pernix { + namespace internal { + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + static __always_inline __m512i mm512_clamp_signed_epi32(__m512i input) { + constexpr i32 min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); + constexpr i32 max_value = BIT_WIDTH == 1 ? 
1 : ((1 << (BIT_WIDTH - 1)) - 1); + return _mm512_min_epi32(_mm512_max_epi32(input, _mm512_set1_epi32(min_value)), + _mm512_set1_epi32(max_value)); + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + static __always_inline __m256i mm256_clamp_signed_epi32_avx512(__m256i input) { + constexpr i32 min_value = BIT_WIDTH == 1 ? 0 : -(1 << (BIT_WIDTH - 1)); + constexpr i32 max_value = BIT_WIDTH == 1 ? 1 : ((1 << (BIT_WIDTH - 1)) - 1); + return _mm256_min_epi32(_mm256_max_epi32(input, _mm256_set1_epi32(min_value)), + _mm256_set1_epi32(max_value)); + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + static __always_inline __m512i mm512_clamp_signed_epi64(__m512i input) { + constexpr i64 min_value = BIT_WIDTH == 1 ? 0 : -(i64{1} << (BIT_WIDTH - 1)); + constexpr i64 max_value = BIT_WIDTH == 1 ? 1 : ((i64{1} << (BIT_WIDTH - 1)) - 1); + return _mm512_min_epi64(_mm512_max_epi64(input, _mm512_set1_epi64(min_value)), + _mm512_set1_epi64(max_value)); + } + + /** + * @brief Quantize sixteen float values to 32-bit integers. 
+ */ + + static __always_inline __m512i mm512_quantize_ps_epi32(const __m512 &input, const __m512 &scale) { + const __m512 scaled = _mm512_mul_ps(input, scale); + return _mm512_cvtps_epi32(scaled); + } + + static __always_inline __m512i mm512_quantize_pd_epi64(const __m512d &input, const __m512d &scale) { + const __m512d scaled = _mm512_mul_pd(input, scale); + return _mm512_cvtpd_epi64(scaled); + } + + static __always_inline __m256i mm512_quantize_pd_epi32(const __m512d &input, const __m512d &scale) { + const __m512d scaled = _mm512_mul_pd(input, scale); + return _mm512_cvtpd_epi32(scaled); + } + + static __always_inline __m512i make_m512i_from_2x256(const __m256i a, const __m256i b) { + __m512i result = _mm512_castsi256_si512(a); + result = _mm512_inserti64x4(result, b, 1); + return result; + } + + static __always_inline __m512i make_m512i_from_4x128(const __m128i a, const __m128i b, const __m128i c, + const __m128i d) { + __m512i result = _mm512_castsi128_si512(a); + result = _mm512_inserti64x2(result, b, 1); + result = _mm512_inserti64x2(result, c, 2); + result = _mm512_inserti64x2(result, d, 3); + return result; + } + + static __always_inline __m512i make_m512i_from_8x64(const __m128i a, const __m128i b, const __m128i c, + const __m128i d, const __m128i e, + const __m128i f, const __m128i g, const __m128i h) { + const __m128i ab = _mm_unpacklo_epi64(a, b); + const __m128i cd = _mm_unpacklo_epi64(c, d); + const __m128i ef = _mm_unpacklo_epi64(e, f); + const __m128i gh = _mm_unpacklo_epi64(g, h); + + __m512i x = _mm512_castsi128_si512(ab); + x = _mm512_inserti32x4(x, cd, 1); + x = _mm512_inserti32x4(x, ef, 2); + x = _mm512_inserti32x4(x, gh, 3); + return x; + } + + static __always_inline __m256i make_m256i_from_2x128(const __m128i a, const __m128i b) { + __m256i result = _mm256_castsi128_si256(a); + result = _mm256_inserti128_si256(result, b, 1); + return result; + } + + static __always_inline __m256i make_m256i_from_4x64(const __m128i a, const __m128i b, const 
__m128i c, + const __m128i d) { + const __m128i ab = _mm_unpacklo_epi64(a, b); + const __m128i cd = _mm_unpacklo_epi64(c, d); + + __m256i x = _mm256_castsi128_si256(ab); + x = _mm256_inserti128_si256(x, cd, 1); + return x; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_1to8(const f32 * __restrict__ input, const f32 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_64 = elements_per_block / 64; + constexpr u32 iterations_32 = (elements_per_block % 64) / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 remaining_elements = + elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; + + const __m512 scale_v = _mm512_set1_ps(scale); + + if constexpr (iterations_64 > 0) { +#pragma GCC unroll 8 + for (u32 iter = 0; iter < iterations_64; ++iter) { + const __m512 source1 = _mm512_loadu_ps(input); + const __m512 source2 = _mm512_loadu_ps(input + 16); + const __m512 source3 = _mm512_loadu_ps(input + 32); + const __m512 source4 = _mm512_loadu_ps(input + 48); + + const __m512i quantized1 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source2, scale_v)); + const __m512i quantized3 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source3, scale_v)); + const __m512i quantized4 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source4, scale_v)); + + const __m128i converted1 = _mm512_cvtepi32_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi32_epi8(quantized2); + const __m128i converted3 = _mm512_cvtepi32_epi8(quantized3); + const __m128i converted4 = _mm512_cvtepi32_epi8(quantized4); + + const __m512i packed = + m512::mm512_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (make_m512i_from_4x128( + converted1, converted2, 
converted3, converted4)); + + mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); + + input += 64; + output += 8 * BIT_WIDTH; + } + } + + if constexpr (iterations_32 > 0) { + const __m512 source1 = _mm512_loadu_ps(input); + const __m512 source2 = _mm512_loadu_ps(input + 16); + + const __m512i quantized1 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source2, scale_v)); + + const __m128i converted1 = _mm512_cvtepi32_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi32_epi8(quantized2); + + const __m256i packed = m256::mm256_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (make_m256i_from_2x128( + converted1, converted2)); + mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); + + input += 32; + output += 4 * BIT_WIDTH; + } + + if constexpr (iterations_16 > 0) { + const __m512 source = _mm512_loadu_ps(input); + const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); + const __m128i converted = _mm512_cvtepi32_epi8(quantized); + + const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (converted); + mm_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m512 source = mm512_loadu_elements_ps(remaining_elements, input); + const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); + const __m128i converted = _mm512_cvtepi32_epi8(quantized); + const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (converted); + + mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_9to16(const f32 * __restrict__ input, const f32 scale, + u8 * __restrict__ output) { + constexpr u32 
elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_32 = elements_per_block / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = + elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; + + const __m512 scale_v = _mm512_set1_ps(scale); + const __m256 scale_v256 = _mm256_set1_ps(scale); + + if constexpr (iterations_32 > 0) { +#pragma GCC unroll 4 + for (u32 iter = 0; iter < iterations_32; ++iter) { + const __m512 source1 = _mm512_loadu_ps(input); + const __m512 source2 = _mm512_loadu_ps(input + 16); + + const __m512i quantized1 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source2, scale_v)); + + const __m256i converted1 = _mm512_cvtepi32_epi16(quantized1); + const __m256i converted2 = _mm512_cvtepi32_epi16(quantized2); + + const __m512i packed = m512::mm512_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (make_m512i_from_2x256( + converted1, converted2)); + mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); + + input += 32; + output += 4 * BIT_WIDTH; + } + } + + if constexpr (iterations_16 > 0) { + const __m512 source = _mm512_loadu_ps(input); + const __m512i quantized = mm512_clamp_signed_epi32(mm512_quantize_ps_epi32(source, scale_v)); + const __m256i converted = _mm512_cvtepi32_epi16(quantized); + + const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + + if constexpr (iterations_8 > 0) { + const __m256 source = _mm256_loadu_ps(input); + const __m256i quantized = mm256_clamp_signed_epi32_avx512( + mm256_quantize_ps_epi32(source, scale_v256)); + const __m128i converted = _mm256_cvtepi32_epi16(quantized); + + const __m128i packed = 
m128::mm_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + mm_storeu_elements_epi8(output, BIT_WIDTH, packed); + + input += 8; + output += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); + const __m256i quantized = mm256_clamp_signed_epi32_avx512( + mm256_quantize_ps_epi32(source, scale_v256)); + const __m128i converted = _mm256_cvtepi32_epi16(quantized); + const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + + mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_17to24(const f32 * __restrict__ input, const f32 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_16 = elements_per_block / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; + + const __m512 scale_v = _mm512_set1_ps(scale); + const __m256 scale_v256 = _mm256_set1_ps(scale); + + if constexpr (iterations_16 > 0) { +#pragma GCC unroll 2 + for (u32 i = 0; i < iterations_16; ++i) { + const __m512 source = _mm512_loadu_ps(input); + const __m512i packed_input = mm512_clamp_signed_epi32( + mm512_quantize_ps_epi32(source, scale_v)); + + const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (packed_input); + mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); + input += 16; + output += 2 * BIT_WIDTH; + } + } + + if constexpr (iterations_8 > 0) { + const __m256 source = _mm256_loadu_ps(input); + const __m256i packed_input = mm256_clamp_signed_epi32_avx512( + mm256_quantize_ps_epi32(source, scale_v256)); + + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24 < 
BIT_WIDTH > (packed_input); + mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); + + input += 8; + output += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); + const __m256i packed_input = mm256_clamp_signed_epi32_avx512( + mm256_quantize_ps_epi32(source, scale_v256)); + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (packed_input); + + mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_1to8(const f64 * __restrict__ input, const f64 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_64 = elements_per_block / 64; + constexpr u32 iterations_32 = (elements_per_block % 64) / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 remaining_elements = + elements_per_block - iterations_64 * 64 - iterations_32 * 32 - iterations_16 * 16; + + const __m512d scale_v = _mm512_set1_pd(scale); + + if constexpr (iterations_64 > 0) { +#pragma GCC unroll 8 + for (u32 iter = 0; iter < iterations_64; ++iter) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512d source3 = _mm512_loadu_pd(input + 16); + const __m512d source4 = _mm512_loadu_pd(input + 24); + const __m512d source5 = _mm512_loadu_pd(input + 32); + const __m512d source6 = _mm512_loadu_pd(input + 40); + const __m512d source7 = _mm512_loadu_pd(input + 48); + const __m512d source8 = _mm512_loadu_pd(input + 56); + + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + const __m512i quantized3 = 
mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source3, scale_v)); + const __m512i quantized4 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source4, scale_v)); + const __m512i quantized5 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source5, scale_v)); + const __m512i quantized6 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source6, scale_v)); + const __m512i quantized7 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source7, scale_v)); + const __m512i quantized8 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source8, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + const __m128i converted3 = _mm512_cvtepi64_epi8(quantized3); + const __m128i converted4 = _mm512_cvtepi64_epi8(quantized4); + const __m128i converted5 = _mm512_cvtepi64_epi8(quantized5); + const __m128i converted6 = _mm512_cvtepi64_epi8(quantized6); + const __m128i converted7 = _mm512_cvtepi64_epi8(quantized7); + const __m128i converted8 = _mm512_cvtepi64_epi8(quantized8); + + const __m512i packed = m512::mm512_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > ( + make_m512i_from_8x64(converted1, converted2, converted3, converted4, + converted5, converted6, converted7, converted8)); + + mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); + + input += 64; + output += 8 * BIT_WIDTH; + } + } + + if constexpr (iterations_32 > 0) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512d source3 = _mm512_loadu_pd(input + 16); + const __m512d source4 = _mm512_loadu_pd(input + 24); + + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + const __m512i quantized3 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source3, scale_v)); + const __m512i quantized4 = 
mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source4, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + const __m128i converted3 = _mm512_cvtepi64_epi8(quantized3); + const __m128i converted4 = _mm512_cvtepi64_epi8(quantized4); + + const __m256i packed = + m256::mm256_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (make_m256i_from_4x64( + converted1, converted2, converted3, converted4)); + + mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); + + input += 32; + output += 4 * BIT_WIDTH; + } + + if constexpr (iterations_16 > 0) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + + const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (_mm_unpacklo_epi64( + converted1, converted2)); + + mm_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + constexpr u32 source1_elements = remaining_elements > 8 ? 8 : remaining_elements; + constexpr u32 source2_elements = remaining_elements > 8 ? remaining_elements - 8 : 0; + + const __m512d source1 = mm512_loadu_elements_pd(source1_elements, input); + const __m512d source2 = source2_elements > 0 + ? 
mm512_loadu_elements_pd(source2_elements, input + 8) + : _mm512_setzero_pd(); + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi8(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi8(quantized2); + + const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8 < BIT_WIDTH > (_mm_unpacklo_epi64( + converted1, converted2)); + + mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_9to16(const f64 * __restrict__ input, const f64 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_32 = elements_per_block / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = + elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; + + const __m512d scale_v = _mm512_set1_pd(scale); + + if constexpr (iterations_32 > 0) { +#pragma GCC unroll 4 + for (u32 iter = 0; iter < iterations_32; ++iter) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512d source3 = _mm512_loadu_pd(input + 16); + const __m512d source4 = _mm512_loadu_pd(input + 24); + + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + const __m512i quantized3 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source3, scale_v)); + const __m512i quantized4 = mm512_clamp_signed_epi64( + 
mm512_quantize_pd_epi64(source4, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi16(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi16(quantized2); + const __m128i converted3 = _mm512_cvtepi64_epi16(quantized3); + const __m128i converted4 = _mm512_cvtepi64_epi16(quantized4); + + const __m512i packed = + m512::mm512_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (make_m512i_from_4x128( + converted1, converted2, converted3, converted4)); + + mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); + + input += 32; + output += 4 * BIT_WIDTH; + } + } + + if constexpr (iterations_16 > 0) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + const __m512i quantized1 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source1, scale_v)); + const __m512i quantized2 = mm512_clamp_signed_epi64( + mm512_quantize_pd_epi64(source2, scale_v)); + + const __m128i converted1 = _mm512_cvtepi64_epi16(quantized1); + const __m128i converted2 = _mm512_cvtepi64_epi16(quantized2); + + const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (make_m256i_from_2x128( + converted1, converted2)); + + mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + + if constexpr (iterations_8 > 0) { + const __m512d source = _mm512_loadu_pd(input); + const __m512i quantized = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source, scale_v)); + const __m128i converted = _mm512_cvtepi64_epi16(quantized); + + const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + mm_storeu_elements_epi8(output, BIT_WIDTH, packed); + + input += 8; + output += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m512d source = mm512_loadu_elements_pd(remaining_elements, input); + const __m512i quantized = mm512_clamp_signed_epi64(mm512_quantize_pd_epi64(source, scale_v)); + const __m128i converted = 
_mm512_cvtepi64_epi16(quantized); + + const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16 < BIT_WIDTH > (converted); + mm_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_compress_block_avx512vbmi_17to24(const f64 * __restrict__ input, const f64 scale, + u8 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_16 = elements_per_block / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; + + const __m512d scale_v = _mm512_set1_pd(scale); + + if constexpr (iterations_16 > 0) { +#pragma GCC unroll 2 + for (u32 i = 0; i < iterations_16; ++i) { + const __m512d source1 = _mm512_loadu_pd(input); + const __m512d source2 = _mm512_loadu_pd(input + 8); + + const __m256i quantized1 = mm256_clamp_signed_epi32_avx512( + mm512_quantize_pd_epi32(source1, scale_v)); + const __m256i quantized2 = mm256_clamp_signed_epi32_avx512( + mm512_quantize_pd_epi32(source2, scale_v)); + + const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > ( + make_m512i_from_2x256(quantized1, quantized2)); + mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); + + input += 16; + output += 2 * BIT_WIDTH; + } + } + + if constexpr (iterations_8 > 0) { + const __m512d source = _mm512_loadu_pd(input); + const __m256i quantized = mm256_clamp_signed_epi32_avx512( + mm512_quantize_pd_epi32(source, scale_v)); + + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (quantized); + mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); + + input += 8; + output += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m512d source = mm512_loadu_elements_pd(remaining_elements, input); + const __m256i quantized 
= mm256_clamp_signed_epi32_avx512( + mm512_quantize_pd_epi32(source, scale_v)); + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24 < BIT_WIDTH > (quantized); + + mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); + } + + return 0; + } + } // namespace internal + + /** + * @brief Compress a single BLOCK_SIZE-byte block of float values using AVX-512 and AVX-512-VBMI instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * + * @param input_ptr pointer to the start of the input float values. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where compressed bytes will be stored; its first BLOCK_SIZE bytes are zeroed before packing. + * @return int status code (0 for success). + * + * @note This function requires AVX-512 and AVX-512-VBMI support. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_compress_block_avx512vbmi(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + std::memset(output, 0, BLOCK_SIZE); /* zero the whole output block before the kernel writes into it */ + + if constexpr (BIT_WIDTH <= 8) { /* compile-time selection of the kernel by BIT_WIDTH range */ + return internal::mm512_compress_block_avx512vbmi_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH <= 16) { + return internal::mm512_compress_block_avx512vbmi_9to16(input, scale, output); + } else { + return internal::mm512_compress_block_avx512vbmi_17to24(input, scale, output); + } + } + + /** + * @brief Compress a single BLOCK_SIZE-byte block of double values using AVX-512 and AVX-512-VBMI instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @param input_ptr pointer to the start of the input double values. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where compressed bytes will be stored. + * @return int status code.
+ * + * @note This overload is declared for parity with the float path. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_compress_block_avx512vbmi(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + std::memset(output, 0, BLOCK_SIZE); /* zero the whole output block before the kernel writes into it */ + + if constexpr (BIT_WIDTH <= 8) { /* compile-time selection of the f64 kernel by BIT_WIDTH range */ + return internal::mm512_compress_block_avx512vbmi_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH <= 16) { + return internal::mm512_compress_block_avx512vbmi_9to16(input, scale, output); + } else { + return internal::mm512_compress_block_avx512vbmi_17to24(input, scale, output); + } + } + + /** + * @brief Compress multiple consecutive BLOCK_SIZE-byte blocks of float values using AVX-512 and AVX-512-VBMI instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * + * @param input_ptr pointer to the start of the input float values. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where compressed bytes will be stored. + * @param blocks number of blocks to compress. + * @return int status code (0 for success). + * + * @note This function requires AVX-512 and AVX-512-VBMI support.
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_compress_blocks_avx512vbmi(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f32 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + mm512_compress_block_avx512vbmi(block_input, scale, block_output); /* per-block return value is not checked */ + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; /* advance by the number of values one block holds */ + block_output += BLOCK_SIZE; /* advance by one packed block (BLOCK_SIZE bytes) */ + } + + return 0; + } + + /** + * @brief Compress multiple consecutive blocks of double values using AVX-512 and AVX-512-VBMI instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @param input_ptr pointer to the start of the input double values. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where compressed bytes will be stored. + * @param blocks number of blocks to compress. + * @return int status code.
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_compress_blocks_avx512vbmi(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const f64 *block_input = input; + u8 *block_output = output; + + for (u32 block = 0; block < blocks; ++block) { + mm512_compress_block_avx512vbmi(block_input, scale, block_output); /* per-block return value is not checked */ + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; /* advance by the number of values one block holds */ + block_output += BLOCK_SIZE; /* advance by one packed block (BLOCK_SIZE bytes) */ + } + + return 0; + } +} // namespace pernix + +#endif // PERNIX_AVX512VBMI_COMPRESSION_H diff --git a/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h new file mode 100644 index 0000000..a004be5 --- /dev/null +++ b/src/internal/pernix/x86/avx512vbmi/avx512vbmi_decompression.h @@ -0,0 +1,677 @@ +#ifndef PERNIX_AVX512VBMI_DECOMPRESSION_H +#define PERNIX_AVX512VBMI_DECOMPRESSION_H + +#include +#include +#include +#include +#include + +#include + +using namespace pernix::x86::internal; + +namespace pernix { + namespace internal { + /** + * @brief Dequantize sixteen 32-bit integer values to floats: convert each lane to f32, then multiply by scale.
+ */ +__always_inline __m512 mm512_dequantize_epi32(const __m512i &input, const __m512 &scale) { + const __m512 converted = _mm512_cvtepi32_ps(input); + return _mm512_mul_ps(converted, scale); + } + +__always_inline __m512d mm512_dequantize_epi64(const __m512i &input, const __m512d &scale) { + const __m512d converted = _mm512_cvtepi64_pd(input); + return _mm512_mul_pd(converted, scale); + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const u32 iterations_64 = elements_per_block / 64; + const u32 iterations_32 = (elements_per_block % 64) / 32; + const u32 iterations_16 = (elements_per_block % 32) / 16; + const u32 remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - + iterations_16 * 16; + + const __m512 scale_v = _mm512_set1_ps(scale); + + if constexpr (iterations_64 > 0) { +#pragma GCC unroll 8 + for (u32 i = 0; i < iterations_64; ++i) { + const __m512i source = mm512_loadu_elements_epi64(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi8_epi32(_mm512_castsi512_si128(unpacked)); + const __m512i converted2 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 1)); + const __m512i converted3 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 2)); + const __m512i converted4 = _mm512_cvtepi8_epi32(_mm512_extracti64x2_epi64(unpacked, 3)); + + const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); + const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); + const __m512 dequantized3 = mm512_dequantize_epi32(converted3, scale_v); + const __m512 dequantized4 = mm512_dequantize_epi32(converted4, scale_v); + + 
_mm512_storeu_ps(output, dequantized1); + _mm512_storeu_ps(output + 16, dequantized2); + _mm512_storeu_ps(output + 32, dequantized3); + _mm512_storeu_ps(output + 48, dequantized4); + + output += 64; + input += 8 * BIT_WIDTH; + } + } + + if constexpr (iterations_32 > 0) { + const __m256i source = mm256_loadu_elements_epi32(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi8_epi32(_mm256_castsi256_si128(unpacked)); + const __m512i converted2 = _mm512_cvtepi8_epi32(_mm256_extracti128_si256(unpacked, 1)); + + const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); + const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); + + _mm512_storeu_ps(output, dequantized1); + _mm512_storeu_ps(output + 16, dequantized2); + + output += 32; + input += 4 * BIT_WIDTH; + } + + if constexpr (iterations_16 > 0) { + const __m128i source = mm_loadu_elements_epi16(BIT_WIDTH, input); + const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted = _mm512_cvtepi8_epi32(unpacked); + + const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); + + _mm512_storeu_ps(output, dequantized); + + output += 16; + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted = _mm512_cvtepi8_epi32(unpacked); + + const __m512 dequantized = mm512_dequantize_epi32(converted, scale_v); + + mm512_storeu_elements_ps(output, remaining_elements, dequantized); + } + + return 0; + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const u8 * 
__restrict__ input, const f64 scale, + f64 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const u32 iterations_64 = elements_per_block / 64; + const u32 iterations_32 = (elements_per_block % 64) / 32; + const u32 iterations_16 = (elements_per_block % 32) / 16; + const u32 remaining_elements = elements_per_block - iterations_64 * 64 - iterations_32 * 32 - + iterations_16 * 16; + + const __m512d scale_v = _mm512_set1_pd(scale); + + if constexpr (iterations_64 > 0) { +#pragma GCC unroll 8 + for (u32 i = 0; i < iterations_64; ++i) { + const __m512i source = mm512_loadu_elements_epi64(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m128i extracted1 = _mm512_castsi512_si128(unpacked); + const __m128i extracted2 = _mm512_extracti64x2_epi64(unpacked, 1); + const __m128i extracted3 = _mm512_extracti64x2_epi64(unpacked, 2); + const __m128i extracted4 = _mm512_extracti64x2_epi64(unpacked, 3); + + const __m512i converted1 = _mm512_cvtepi8_epi64(extracted1); + const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted1, 8)); + const __m512i converted3 = _mm512_cvtepi8_epi64(extracted2); + const __m512i converted4 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted2, 8)); + const __m512i converted5 = _mm512_cvtepi8_epi64(extracted3); + const __m512i converted6 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted3, 8)); + const __m512i converted7 = _mm512_cvtepi8_epi64(extracted4); + const __m512i converted8 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted4, 8)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); + const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); + const __m512d dequantized5 = mm512_dequantize_epi64(converted5, scale_v); 
+ const __m512d dequantized6 = mm512_dequantize_epi64(converted6, scale_v); + const __m512d dequantized7 = mm512_dequantize_epi64(converted7, scale_v); + const __m512d dequantized8 = mm512_dequantize_epi64(converted8, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + _mm512_storeu_pd(output + 16, dequantized3); + _mm512_storeu_pd(output + 24, dequantized4); + _mm512_storeu_pd(output + 32, dequantized5); + _mm512_storeu_pd(output + 40, dequantized6); + _mm512_storeu_pd(output + 48, dequantized7); + _mm512_storeu_pd(output + 56, dequantized8); + + output += 64; + input += 8 * BIT_WIDTH; + } + } /* close the iterations_64 stage here, as the f32 overload does, so the 32-, 16- and tail stages below also run when iterations_64 == 0 */ + + if constexpr (iterations_32 > 0) { /* one group of 32 values: 4 * BIT_WIDTH packed bytes in, four 8-double stores out */ + const __m256i source = mm256_loadu_elements_epi32(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m128i extracted1 = _mm256_castsi256_si128(unpacked); + const __m128i extracted2 = _mm256_extracti64x2_epi64(unpacked, 1); + + const __m512i converted1 = _mm512_cvtepi8_epi64(extracted1); + const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted1, 8)); /* upper 8 bytes shifted down: the widening conversion reads only the low 8 bytes */ + const __m512i converted3 = _mm512_cvtepi8_epi64(extracted2); + const __m512i converted4 = _mm512_cvtepi8_epi64(_mm_srli_si128(extracted2, 8)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); + const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + _mm512_storeu_pd(output + 16, dequantized3); + _mm512_storeu_pd(output + 24, dequantized4); + + output += 32; + input += 4 * BIT_WIDTH; + } + + if constexpr (iterations_16 > 0) { /* one group of 16 values: 2 * BIT_WIDTH packed bytes in, two 8-double stores out */ + const __m128i source = mm_loadu_elements_epi16(BIT_WIDTH, input); + const __m128i unpacked =
m128::mm_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi8_epi64(unpacked); + const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(unpacked, 8)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + + output += 16; + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { /* tail of fewer than 16 values: partial load, then one or two partial stores */ + const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m128i unpacked = m128::mm_unpack_epi8_avx512vbmi_1to8 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi8_epi64(unpacked); + const __m512i converted2 = _mm512_cvtepi8_epi64(_mm_srli_si128(unpacked, 8)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + + mm512_storeu_elements_pd(output, remaining_elements < 8 ?
remaining_elements : 8, dequantized1); + if constexpr (remaining_elements > 8) { + mm512_storeu_elements_pd(output + 8, remaining_elements - 8, dequantized2); + } + } + + return 0; + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_32 = elements_per_block / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = + elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; + + const __m512 scale_v = _mm512_set1_ps(scale); + const __m256 scale_v256 = _mm256_set1_ps(scale); + + if constexpr (iterations_32 > 0) { +#pragma GCC unroll 4 + for (u32 i = 0; i < iterations_32; ++i) { + const __m512i source = mm512_loadu_elements_epi32(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi16_epi32(_mm512_castsi512_si256(unpacked)); + const __m512i converted2 = _mm512_cvtepi16_epi32(_mm512_extracti32x8_epi32(unpacked, 1)); + + const __m512 dequantized1 = mm512_dequantize_epi32(converted1, scale_v); + const __m512 dequantized2 = mm512_dequantize_epi32(converted2, scale_v); + + _mm512_storeu_ps(output, dequantized1); + _mm512_storeu_ps(output + 16, dequantized2); + + output += 32; + input += 4 * BIT_WIDTH; + } + } + + if constexpr (iterations_16 > 0) { + const __m256i source = mm256_loadu_elements_epi16(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted = _mm512_cvtepi16_epi32(unpacked); + const __m512 dequantized = mm512_dequantize_epi32(converted,
scale_v); + + _mm512_storeu_ps(output, dequantized); + + output += 16; + input += 2 * BIT_WIDTH; + } + + if constexpr (iterations_8 > 0) { + const __m128i source = mm_loadu_elements_epi8(BIT_WIDTH, input); + const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m256i converted = _mm256_cvtepi16_epi32(unpacked); + const __m256 dequantized = mm256_dequantize_epi32(converted, scale_v256); + + _mm256_storeu_ps(output, dequantized); + + output += 8; + input += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m256i converted = _mm256_cvtepi16_epi32(unpacked); + const __m256 dequantized = mm256_dequantize_epi32(converted, scale_v256); + + mm256_storeu_elements_ps(output, remaining_elements, dequantized); + } + + return 0; + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_32 = elements_per_block / 32; + constexpr u32 iterations_16 = (elements_per_block % 32) / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = + elements_per_block - iterations_32 * 32 - iterations_16 * 16 - iterations_8 * 8; + + const __m512d scale_v = _mm512_set1_pd(scale); + + if constexpr (iterations_32 > 0) { +#pragma GCC unroll 4 + for (u32 i = 0; i < iterations_32; ++i) { + const __m512i source = mm512_loadu_elements_epi32(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = 
_mm512_cvtepi16_epi64(_mm512_castsi512_si128(unpacked)); + const __m512i converted2 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 1)); + const __m512i converted3 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 2)); + const __m512i converted4 = _mm512_cvtepi16_epi64(_mm512_extracti64x2_epi64(unpacked, 3)); + /* converted1..4 each hold the 8 values taken from one 128-bit quarter of the unpacked vector, in order. */ + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + const __m512d dequantized3 = mm512_dequantize_epi64(converted3, scale_v); + const __m512d dequantized4 = mm512_dequantize_epi64(converted4, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + _mm512_storeu_pd(output + 16, dequantized3); + _mm512_storeu_pd(output + 24, dequantized4); + + output += 32; + input += 4 * BIT_WIDTH; + } + } + /* (elements_per_block % 32) / 16 is 0 or 1, so this chunk runs at most once: 16 lanes split into two groups of 8. */ + if constexpr (iterations_16 > 0) { + const __m256i source = mm256_loadu_elements_epi16(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi16_epi64(_mm256_castsi256_si128(unpacked)); + const __m512i converted2 = _mm512_cvtepi16_epi64(_mm256_extracti64x2_epi64(unpacked, 1)); + + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + + output += 16; + input += 2 * BIT_WIDTH; + } + /* At most one 8-element chunk: BIT_WIDTH input bytes, a single 512-bit result. */ + if constexpr (iterations_8 > 0) { + const __m128i source = mm_loadu_elements_epi8(BIT_WIDTH, input); + const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted = _mm512_cvtepi16_epi64(unpacked); + + const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); + + _mm512_storeu_pd(output, dequantized); + + output += 8; + input += BIT_WIDTH; 
+ } + /* Tail: fewer than 8 elements are left; the load is limited to tail_bytes(...) bytes and the store to remaining_elements lanes. */ + if constexpr (remaining_elements > 0) { + const __m128i source = mm_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m128i unpacked = m128::mm_unpack_epi16_avx512vbmi_9to16 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted = _mm512_cvtepi16_epi64(unpacked); + + const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); + + mm512_storeu_elements_pd(output, remaining_elements, dequantized); + } + + return 0; + } + /* f32 block decompressor for 17..24-bit values. The unpack step already yields 32-bit lanes, so no widening is needed; the block is consumed in chunks of 16 and 8 elements followed by a masked tail. Always returns 0. */ + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const u8 * __restrict__ input, const f32 scale, + f32 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_16 = elements_per_block / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; + + const __m512 scale_v = _mm512_set1_ps(scale); + const __m256 scale_v256 = _mm256_set1_ps(scale); + + if constexpr (iterations_16 > 0) { +#pragma GCC unroll 2 + for (u32 i = 0; i < iterations_16; ++i) { /* one iteration = 16 elements = 2 * BIT_WIDTH input bytes */ + const __m512i source = mm512_loadu_elements_epi16(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512 dequantized = mm512_dequantize_epi32(unpacked, scale_v); + + _mm512_storeu_ps(output, dequantized); + + output += 16; + input += 2 * BIT_WIDTH; + } + } + /* At most one 8-element chunk: BIT_WIDTH input bytes. */ + if constexpr (iterations_8 > 0) { + const __m256i source = mm256_loadu_elements_epi8(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m256 dequantized = mm256_dequantize_epi32(unpacked, scale_v256); + + _mm256_storeu_ps(output, dequantized); + + output += 8; + input += BIT_WIDTH; + } + /* Tail: fewer than 8 elements are left. */ + if constexpr (remaining_elements > 0) { + const __m256i 
source = mm256_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m256 dequantized = mm256_dequantize_epi32(unpacked, scale_v256); + + mm256_storeu_elements_ps(output, remaining_elements, dequantized); + } + + return 0; + } + /* f64 overload of the 17..24-bit block decompressor. Same 16/8/tail chunking as the f32 overload, but every group of 8 unpacked 32-bit lanes is widened to 64-bit with the sign-extending _mm512_cvtepi32_epi64 before it reaches the dequantize helper. Always returns 0. */ + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const u8 * __restrict__ input, const f64 scale, + f64 * __restrict__ output) { + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr u32 iterations_16 = elements_per_block / 16; + constexpr u32 iterations_8 = (elements_per_block % 16) / 8; + constexpr u32 remaining_elements = elements_per_block - iterations_16 * 16 - iterations_8 * 8; + + const __m512d scale_v = _mm512_set1_pd(scale); + + if constexpr (iterations_16 > 0) { +#pragma GCC unroll 2 + for (u32 i = 0; i < iterations_16; ++i) { /* one iteration = 16 elements = 2 * BIT_WIDTH input bytes */ + const __m512i source = mm512_loadu_elements_epi16(BIT_WIDTH, input); + const __m512i unpacked = m512::mm512_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted1 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(unpacked)); + const __m512i converted2 = _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(unpacked, 1)); + /* converted1 / converted2 hold the low and high 8 values of the unpacked vector. */ + const __m512d dequantized1 = mm512_dequantize_epi64(converted1, scale_v); + const __m512d dequantized2 = mm512_dequantize_epi64(converted2, scale_v); + + _mm512_storeu_pd(output, dequantized1); + _mm512_storeu_pd(output + 8, dequantized2); + + output += 16; + input += 2 * BIT_WIDTH; + } + } + /* At most one 8-element chunk: BIT_WIDTH input bytes, a single 512-bit result. */ + if constexpr (iterations_8 > 0) { + const __m256i source = mm256_loadu_elements_epi8(BIT_WIDTH, input); + const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted = _mm512_cvtepi32_epi64(unpacked); + + const __m512d 
dequantized = mm512_dequantize_epi64(converted, scale_v); + + _mm512_storeu_pd(output, dequantized); + + output += 8; + input += BIT_WIDTH; + } + /* Tail: fewer than 8 elements are left; the load is limited to tail_bytes(...) bytes and the store to remaining_elements lanes. */ + if constexpr (remaining_elements > 0) { + const __m256i source = mm256_loadu_elements_epi8(tail_bytes(BIT_WIDTH, remaining_elements), input); + const __m256i unpacked = m256::mm256_unpack_epi32_avx512vbmi_17to24 < BIT_WIDTH, SIGN_VALUES + > + (source); + + const __m512i converted = _mm512_cvtepi32_epi64(unpacked); + + const __m512d dequantized = mm512_dequantize_epi64(converted, scale_v); + + mm512_storeu_elements_pd(output, remaining_elements, dequantized); + } + + return 0; + } + } // namespace internal + + /** + * @brief Decompress a single block to float values using AVX-512 and AVX-512-VBMI instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * @tparam BLOCK_SIZE size of one compressed block in bytes; the requires-clause demands a multiple of 32. + * + * @param input_ptr pointer to the start of the compressed block. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where decompressed float values will be stored. + * @return int status code (0 for success). + * + * @note This function requires AVX-512 and AVX-512-VBMI support. 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + /* Compile-time dispatch on the bit-width range to the matching internal f32 implementation. */ + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::mm512_decompress_block_avx512vbmi_1to8( + input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::mm512_decompress_block_avx512vbmi_9to16( + input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::mm512_decompress_block_avx512vbmi_17to24( + input, scale, output); + } /* the three ranges cover every BIT_WIDTH the requires-clause admits, so the return below is only a fallback */ + return 0; + } + + /** + * @brief Decompress a single block to double values using AVX-512 and AVX-512-VBMI instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * @tparam BLOCK_SIZE size of one compressed block in bytes; the requires-clause demands a multiple of 32. + * @param input_ptr pointer to the start of the compressed block. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where decompressed double values will be stored. + * @return int status code. 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int mm512_decompress_block_avx512vbmi(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + /* Compile-time dispatch on the bit-width range to the matching internal f64 implementation. */ + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::mm512_decompress_block_avx512vbmi_1to8( + input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::mm512_decompress_block_avx512vbmi_9to16( + input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::mm512_decompress_block_avx512vbmi_17to24( + input, scale, output); + } /* the three ranges cover every BIT_WIDTH the requires-clause admits, so the return below is only a fallback */ + return 0; + } + + /** + * @brief Decompress multiple consecutive blocks to float values using AVX-512 and AVX-512-VBMI instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * @tparam BLOCK_SIZE size of one compressed block in bytes; the requires-clause demands a multiple of 32. + * + * @param input_ptr pointer to the start of the compressed data. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where decompressed float values will be stored. + * @param blocks number of BLOCK_SIZE-byte blocks to decompress. + * @return int status code (0 for success). + * + * @note This function requires AVX-512 and AVX-512-VBMI support. 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_decompress_blocks_avx512vbmi(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f32 *block_output = output; + /* Blocks are processed back to back: per block the input advances by BLOCK_SIZE bytes and the output by (BLOCK_SIZE * 8) / BIT_WIDTH values. The per-block status is not inspected. */ + for (u32 block = 0; block < blocks; ++block) { + mm512_decompress_block_avx512vbmi(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } + + /** + * @brief Decompress multiple consecutive blocks to double values using AVX-512 and AVX-512-VBMI instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * @tparam BLOCK_SIZE size of one compressed block in bytes; the requires-clause demands a multiple of 32. + * @param input_ptr pointer to the start of the compressed data. + * @param scale scaling factor used during quantization. + * @param output_ptr pointer to the output buffer where decompressed double values will be stored. + * @param blocks number of BLOCK_SIZE-byte blocks to decompress. + * @return int status code (always 0). 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm512_decompress_blocks_avx512vbmi(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f64 *block_output = output; + /* Blocks are processed back to back: per block the input advances by BLOCK_SIZE bytes and the output by (BLOCK_SIZE * 8) / BIT_WIDTH values. The per-block status is not inspected. */ + for (u32 block = 0; block < blocks; ++block) { + mm512_decompress_block_avx512vbmi(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; + } +} // namespace pernix + +#endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/src/internal/pernix/x86/avx512vbmi/compat.h b/src/internal/pernix/x86/avx512vbmi/compat.h new file mode 100644 index 0000000..df9ca4b --- /dev/null +++ b/src/internal/pernix/x86/avx512vbmi/compat.h @@ -0,0 +1,387 @@ +#ifndef PERNIX_AVX512_COMPAT_H +#define PERNIX_AVX512_COMPAT_H + +#include +#include +#include +#include + +namespace pernix::internal { /* element_maskN(e): an N-bit lane mask with the low e bits set; e >= N saturates to all ones, which also avoids shifting by the full type width. */ + static __always_inline __mmask8 element_mask8(const u32 e) { + return static_cast<__mmask8>(e >= 8 ? 0xFFu : ((1u << e) - 1u)); + } + + static __always_inline __mmask16 element_mask16(const u32 e) { + return static_cast<__mmask16>(e >= 16 ? 0xFFFFu : ((1u << e) - 1u)); + } + + static __always_inline __mmask32 element_mask32(const u32 e) { + return e >= 32 ? 0xFFFFFFFFu : (1u << e) - 1u; + } + + static __always_inline __mmask64 element_mask64(const u32 e) { + return e >= 64 ? 
0xFFFFFFFFFFFFFFFFull : (1ull << e) - 1ull; + } + /* Partial loads: return a vector whose first e lanes come from mem_addr and whose remaining lanes are zero. With PERNIX_USE_SIMDE and no native AVX-512F the helper copies e * sizeof(lane) bytes into a zeroed vector; every other build uses a zero-masking load. NOTE(review): the mask helpers saturate at their mask width, but the memcpy fallback does not clamp e; confirm callers never pass more lanes than the vector holds. */ + static __always_inline __m512i mm512_loadu_elements_epi64(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512i a = _mm512_setzero_si512(); + std::memcpy(&a, mem_addr, e * sizeof(i64)); + return a; +#else + return _mm512_maskz_loadu_epi64(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m256i mm256_loadu_elements_epi64(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256i a = _mm256_setzero_si256(); + std::memcpy(&a, mem_addr, e * sizeof(i64)); + return a; +#else + return _mm256_maskz_loadu_epi64(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m128i mm_loadu_elements_epi64(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128i a = _mm_setzero_si128(); + std::memcpy(&a, mem_addr, e * sizeof(i64)); + return a; +#else + return _mm_maskz_loadu_epi64(element_mask8(e), mem_addr); +#endif + } + /* Partial loads of 32-bit lanes (512-, 256- and 128-bit vectors). */ + static __always_inline __m512i mm512_loadu_elements_epi32(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512i a = _mm512_setzero_si512(); + std::memcpy(&a, mem_addr, e * sizeof(i32)); + return a; +#else + return _mm512_maskz_loadu_epi32(element_mask16(e), mem_addr); +#endif + } + + static __always_inline __m256i mm256_loadu_elements_epi32(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256i a = _mm256_setzero_si256(); + std::memcpy(&a, mem_addr, e * sizeof(i32)); + return a; +#else + return _mm256_maskz_loadu_epi32(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m128i mm_loadu_elements_epi32(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128i a = _mm_setzero_si128(); + std::memcpy(&a, mem_addr, e 
* sizeof(i32)); + return a; +#else + return _mm_maskz_loadu_epi32(element_mask8(e), mem_addr); +#endif + } + /* Partial loads of 16-bit lanes: the first e lanes come from mem_addr, the rest are zero. The 512-bit form has 32 lanes and therefore takes its mask from element_mask32. */ + static __always_inline __m512i mm512_loadu_elements_epi16(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512i a = _mm512_setzero_si512(); + std::memcpy(&a, mem_addr, e * sizeof(i16)); + return a; +#else + return _mm512_maskz_loadu_epi16(element_mask32(e), mem_addr); +#endif + } + + static __always_inline __m256i mm256_loadu_elements_epi16(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256i a = _mm256_setzero_si256(); + std::memcpy(&a, mem_addr, e * sizeof(i16)); + return a; +#else + return _mm256_maskz_loadu_epi16(element_mask16(e), mem_addr); +#endif + } + + static __always_inline __m128i mm_loadu_elements_epi16(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128i a = _mm_setzero_si128(); + std::memcpy(&a, mem_addr, e * sizeof(i16)); + return a; +#else + return _mm_maskz_loadu_epi16(element_mask8(e), mem_addr); +#endif + } + /* Partial loads of 8-bit lanes: e is a count of 8-bit lanes; the fallback copies e * sizeof(i8) bytes. */ + static __always_inline __m512i mm512_loadu_elements_epi8(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512i a = _mm512_setzero_si512(); + std::memcpy(&a, mem_addr, e * sizeof(i8)); + return a; +#else + return _mm512_maskz_loadu_epi8(element_mask64(e), mem_addr); +#endif + } + + static __always_inline __m256i mm256_loadu_elements_epi8(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256i a = _mm256_setzero_si256(); + std::memcpy(&a, mem_addr, e * sizeof(i8)); + return a; +#else + return _mm256_maskz_loadu_epi8(element_mask32(e), mem_addr); +#endif + } + + static __always_inline __m128i mm_loadu_elements_epi8(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128i 
a = _mm_setzero_si128(); + std::memcpy(&a, mem_addr, e * sizeof(i8)); + return a; +#else + return _mm_maskz_loadu_epi8(element_mask16(e), mem_addr); +#endif + } + /* Partial stores: write only the first e lanes of a to mem_addr. With PERNIX_USE_SIMDE and no native AVX-512F the vector is spilled to an aligned stack buffer and e * sizeof(lane) bytes are copied out; every other build uses a masked store. NOTE(review): that fallback does not clamp e to the lane count; confirm callers stay in range. */ + static __always_inline void mm512_storeu_elements_epi64(void *mem_addr, const u32 e, const __m512i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) u8 bytes[64]; + _mm512_storeu_si512(bytes, a); + std::memcpy(mem_addr, bytes, e * sizeof(i64)); +#else + _mm512_mask_storeu_epi64(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_epi64(void *mem_addr, const u32 e, const __m256i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) u8 bytes[32]; + _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(i64)); +#else + _mm256_mask_storeu_epi64(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_epi64(void *mem_addr, const u32 e, const __m128i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) u8 bytes[16]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(i64)); +#else + _mm_mask_storeu_epi64(mem_addr, element_mask8(e), a); +#endif + } + /* Partial stores of 32-bit lanes. */ + static __always_inline void mm512_storeu_elements_epi32(void *mem_addr, const u32 e, const __m512i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) u8 bytes[64]; + _mm512_storeu_si512(bytes, a); + std::memcpy(mem_addr, bytes, e * sizeof(i32)); +#else + _mm512_mask_storeu_epi32(mem_addr, element_mask16(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_epi32(void *mem_addr, const u32 e, const __m256i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) u8 bytes[32]; + _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(i32)); +#else + 
_mm256_mask_storeu_epi32(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_epi32(void *mem_addr, const u32 e, const __m128i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) u8 bytes[16]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(i32)); +#else + _mm_mask_storeu_epi32(mem_addr, element_mask8(e), a); +#endif + } + /* Partial stores of 16-bit lanes: only the first e lanes of a are written to mem_addr. The 512-bit form has 32 lanes and takes its mask from element_mask32. */ + static __always_inline void mm512_storeu_elements_epi16(void *mem_addr, const u32 e, const __m512i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) u8 bytes[64]; + _mm512_storeu_si512(bytes, a); + std::memcpy(mem_addr, bytes, e * sizeof(i16)); +#else + _mm512_mask_storeu_epi16(mem_addr, element_mask32(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_epi16(void *mem_addr, const u32 e, const __m256i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) u8 bytes[32]; + _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(i16)); +#else + _mm256_mask_storeu_epi16(mem_addr, element_mask16(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_epi16(void *mem_addr, const u32 e, const __m128i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) u8 bytes[16]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(i16)); +#else + _mm_mask_storeu_epi16(mem_addr, element_mask8(e), a); +#endif + } + /* Partial stores of 8-bit lanes: e is a count of 8-bit lanes; the fallback copies e * sizeof(i8) bytes. */ + static __always_inline void mm512_storeu_elements_epi8(void *mem_addr, const u32 e, const __m512i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) u8 bytes[64]; + _mm512_storeu_si512(bytes, a); + std::memcpy(mem_addr, bytes, e * sizeof(i8)); +#else + _mm512_mask_storeu_epi8(mem_addr, element_mask64(e), a); +#endif + } + + static __always_inline void 
mm256_storeu_elements_epi8(void *mem_addr, const u32 e, const __m256i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) u8 bytes[32]; + _mm256_storeu_si256(reinterpret_cast<__m256i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(i8)); +#else + _mm256_mask_storeu_epi8(mem_addr, element_mask32(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_epi8(void *mem_addr, const u32 e, const __m128i a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) u8 bytes[16]; + _mm_storeu_si128(reinterpret_cast<__m128i *>(bytes), a); + std::memcpy(mem_addr, bytes, e * sizeof(i8)); +#else + _mm_mask_storeu_epi8(mem_addr, element_mask16(e), a); +#endif + } + /* Floating-point partial loads: the first e lanes come from mem_addr, the rest are zero. f32 lanes first, f64 lanes after. With PERNIX_USE_SIMDE and no native AVX-512F the helper copies e * sizeof(lane) bytes into a zeroed vector; every other build uses a zero-masking load. */ + static __always_inline __m512 mm512_loadu_elements_ps(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512 a = _mm512_setzero_ps(); + std::memcpy(&a, mem_addr, e * sizeof(f32)); + return a; +#else + return _mm512_maskz_loadu_ps(element_mask16(e), mem_addr); +#endif + } + + static __always_inline __m256 mm256_loadu_elements_ps(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256 a = _mm256_setzero_ps(); + std::memcpy(&a, mem_addr, e * sizeof(f32)); + return a; +#else + return _mm256_maskz_loadu_ps(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m128 mm_loadu_elements_ps(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128 a = _mm_setzero_ps(); + std::memcpy(&a, mem_addr, e * sizeof(f32)); + return a; +#else + return _mm_maskz_loadu_ps(element_mask8(e), mem_addr); +#endif + } + /* Partial loads of f64 lanes. */ + static __always_inline __m512d mm512_loadu_elements_pd(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m512d a = _mm512_setzero_pd(); + std::memcpy(&a, mem_addr, e * sizeof(f64)); + return a; 
+#else + return _mm512_maskz_loadu_pd(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m256d mm256_loadu_elements_pd(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m256d a = _mm256_setzero_pd(); + std::memcpy(&a, mem_addr, e * sizeof(f64)); + return a; +#else + return _mm256_maskz_loadu_pd(element_mask8(e), mem_addr); +#endif + } + + static __always_inline __m128d mm_loadu_elements_pd(const u32 e, const void *mem_addr) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + __m128d a = _mm_setzero_pd(); + std::memcpy(&a, mem_addr, e * sizeof(f64)); + return a; +#else + return _mm_maskz_loadu_pd(element_mask8(e), mem_addr); +#endif + } + /* Floating-point partial stores: only the first e lanes of a are written to mem_addr. With PERNIX_USE_SIMDE and no native AVX-512F the vector is spilled into an aligned, typed stack array (values) and e * sizeof(lane) bytes are copied out; every other build uses a masked store. */ + static __always_inline void mm512_storeu_elements_ps(void *mem_addr, const u32 e, const __m512 a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) f32 values[16]; + _mm512_storeu_ps(values, a); + std::memcpy(mem_addr, values, e * sizeof(f32)); +#else + _mm512_mask_storeu_ps(mem_addr, element_mask16(e), a); +#endif + } + + static __always_inline void mm256_storeu_elements_ps(void *mem_addr, const u32 e, const __m256 a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) f32 values[8]; + _mm256_storeu_ps(values, a); + std::memcpy(mem_addr, values, e * sizeof(f32)); +#else + _mm256_mask_storeu_ps(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_ps(void *mem_addr, const u32 e, const __m128 a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) f32 values[4]; + _mm_storeu_ps(values, a); + std::memcpy(mem_addr, values, e * sizeof(f32)); +#else + _mm_mask_storeu_ps(mem_addr, element_mask8(e), a); +#endif + } + /* Partial stores of f64 lanes. */ + static __always_inline void mm512_storeu_elements_pd(void *mem_addr, const u32 e, const __m512d a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(64) 
f64 values[8]; + _mm512_storeu_pd(values, a); + std::memcpy(mem_addr, values, e * sizeof(f64)); +#else + _mm512_mask_storeu_pd(mem_addr, element_mask8(e), a); +#endif + } + /* 256-bit and 128-bit partial stores of f64 lanes: only the first e lanes of a are written to mem_addr. */ + static __always_inline void mm256_storeu_elements_pd(void *mem_addr, const u32 e, const __m256d a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(32) f64 values[4]; + _mm256_storeu_pd(values, a); + std::memcpy(mem_addr, values, e * sizeof(f64)); +#else + _mm256_mask_storeu_pd(mem_addr, element_mask8(e), a); +#endif + } + + static __always_inline void mm_storeu_elements_pd(void *mem_addr, const u32 e, const __m128d a) { +#if defined(PERNIX_USE_SIMDE) && !defined(SIMDE_X86_AVX512F_NATIVE) + alignas(16) f64 values[2]; + _mm_storeu_pd(values, a); + std::memcpy(mem_addr, values, e * sizeof(f64)); +#else + _mm_mask_storeu_pd(mem_addr, element_mask8(e), a); +#endif + } +} + +#endif //PERNIX_AVX512_COMPAT_H diff --git a/src/internal/pernix/x86/avx512vbmi/packing.h b/src/internal/pernix/x86/avx512vbmi/packing.h new file mode 100644 index 0000000..d4052eb --- /dev/null +++ b/src/internal/pernix/x86/avx512vbmi/packing.h @@ -0,0 +1,331 @@ +#ifndef PERNIX_AVX512VBMI_PACKING_H +#define PERNIX_AVX512VBMI_PACKING_H + +#include +#include + +namespace pernix::internal { + namespace m128 { + /** + * @brief Pack 8 16-bit values for bit widths 9 through 16 using VBMI. + * + * @tparam BIT_WIDTH number of low bits kept from every 16-bit lane (9 to 16). + * @param input vector of 16-bit lanes; bits above BIT_WIDTH are masked off before packing. + * @return vector holding the packed bits; for BIT_WIDTH == 16 the input is returned unchanged. 
+ */ + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i &input) { + if constexpr (BIT_WIDTH == 16) { /* lanes are already full width: nothing to pack */ + return input; + } else { + using tables = pack_tables_avx512_16; + const __m128i maskv = _mm_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i masked = _mm_and_si128(input, maskv); + /* Widths 12, 14 and 15 are assembled from two permuted copies of the masked input; every other width uses three zero-masked permutes. Permute indices, shift counts and masks come from the tables type. */ + if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m128i permuted1 = _mm_permutexvar_epi16(tables::get_permute1(), masked); + const __m128i permuted2 = _mm_permutexvar_epi16(tables::get_permute2(), masked); + + const __m128i shifted1 = _mm_sllv_epi16(permuted1, tables::get_shift1()); + const __m128i shifted2 = _mm_srlv_epi16(permuted2, tables::get_shift2()); + + return _mm_or_si128(shifted1, shifted2); + } else { + const auto [mask1, mask2, mask3] = tables::get_permute_masks(); + + const __m128i permuted1 = _mm_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); + const __m128i permuted2 = _mm_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); + const __m128i permuted3 = _mm_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); + /* two copies are shifted left and one right, per lane, then all three are OR-ed together */ + const __m128i shifted1 = _mm_sllv_epi16(permuted1, tables::get_shift1()); + const __m128i shifted2 = _mm_sllv_epi16(permuted2, tables::get_shift2()); + const __m128i shifted3 = _mm_srlv_epi16(permuted3, tables::get_shift3()); + + return _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); + } + } + } + + /** + * @brief Pack 16 8-bit values for bit widths 1 through 8 using VBMI. + * + * @tparam BIT_WIDTH number of low bits kept from every byte (1 to 8). + * @param input vector of 8-bit lanes; bits above BIT_WIDTH are masked off before packing. + * @return vector holding the packed bits; for BIT_WIDTH == 8 the input is returned unchanged. 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i &input) { + if constexpr (BIT_WIDTH == 8) { /* bytes are already full width: nothing to pack */ + return input; + } else { + const __m128i maskv = _mm_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i masked = _mm_and_si128(input, maskv); + + if constexpr (BIT_WIDTH == 1) { /* one bit per byte: the compare mask already is the packed result; it is broadcast to every 16-bit lane */ + return _mm_set1_epi16(static_cast(_mm_cmpgt_epi8_mask(masked, _mm_setzero_si128()))); + } else if constexpr (BIT_WIDTH == 2) { /* fold neighbouring bytes twice so the low byte of each 32-bit lane holds four 2-bit values, then keep only those bytes */ + const __m128i shifted = _mm_srli_epi16(masked, 6); + const __m128i combined = _mm_or_si128(masked, shifted); + + const __m128i shifted2 = _mm_srli_epi32(combined, 12); + const __m128i combined2 = _mm_or_si128(shifted2, combined); + + return _mm_cvtepi32_epi8(combined2); + } else if constexpr (BIT_WIDTH == 3) { /* byte pairs -> 6-bit values, pairs of those -> 12-bit values, then reuse the 12-bit packer */ + const __m128i even = _mm_and_si128(masked, _mm_set1_epi16(0x00FF)); + const __m128i odd = _mm_and_si128(masked, _mm_set1_epi16(0xFF00)); + + const __m128i pair6 = _mm_or_si128(even, _mm_srli_epi16(odd, 5)); + const __m128i packed12 = _mm_or_si128(pair6, _mm_srli_epi32(pair6, 10)); + + return m128::mm_pack_epi16_avx512vbmi_9to16<12>(_mm_cvtepi32_epi16(packed12)); + } else if constexpr (BIT_WIDTH == 4) { /* two nibbles per 16-bit lane are merged into its low byte, then each lane is narrowed to one byte */ + const __m128i shifted = _mm_srli_epi16(masked, 4); + const __m128i combined = _mm_or_si128(masked, shifted); + + return _mm_cvtepi16_epi8(combined); + } else { /* widths 5..7: merge each byte pair into one 2 * BIT_WIDTH-bit value and reuse the 16-bit-lane packer */ + const __m128i even = _mm_and_si128(masked, _mm_set1_epi16(0x00FF)); + const __m128i odd = _mm_and_si128(masked, _mm_set1_epi16(0xFF00)); + + const __m128i shifted = _mm_or_si128(even, _mm_srli_epi16(odd, 8 - BIT_WIDTH)); + return mm_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); + } + } + } + + /** + * @brief Pack 4 32-bit values for bit widths 17 through 24 using VBMI. + * + * @tparam BIT_WIDTH number of low bits kept from every 32-bit lane (17 to 24). + * @param input vector of 32-bit lanes; bits above BIT_WIDTH are masked off before packing. + * @return vector holding the packed bits. 
+ */ + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i &input) { + using tables = pack_tables_avx512_24; + + const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m128i masked = _mm_and_si128(input, maskv); + /* Three permuted copies of the masked input are built; the 128-bit variant permutes through _mm_permutevar_ps, and the surrounding casts only reinterpret the bits. */ + const __m128 permuted1 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute1()); + const __m128 permuted2 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute2()); + const __m128 permuted3 = _mm_permutevar_ps(_mm_castsi128_ps(masked), tables::get_permute3()); + /* two copies are shifted left and one right, per lane, then all three are OR-ed together */ + const __m128i shifted1 = _mm_sllv_epi32(_mm_castps_si128(permuted1), tables::get_shift1()); + const __m128i shifted2 = _mm_sllv_epi32(_mm_castps_si128(permuted2), tables::get_shift2()); + const __m128i shifted3 = _mm_srlv_epi32(_mm_castps_si128(permuted3), tables::get_shift3()); + + return _mm_or_si128(_mm_or_si128(shifted1, shifted2), shifted3); + } + } // namespace m128 + + namespace m256 { + /** + * @brief Pack 16 16-bit values for bit widths 9 through 16 using VBMI. + * + * @tparam BIT_WIDTH number of low bits kept from every 16-bit lane (9 to 16). + * @param input vector of 16-bit lanes; bits above BIT_WIDTH are masked off before packing. + * @return vector holding the packed bits; for BIT_WIDTH == 16 the input is returned unchanged. 
+ */ + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i &input) { + if constexpr (BIT_WIDTH == 16) { /* lanes are already full width: nothing to pack */ + return input; + } else { + using tables = pack_tables_avx512_16; + const __m256i maskv = _mm256_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i masked = _mm256_and_si256(input, maskv); + /* Widths 12, 14 and 15 are assembled from two permuted copies of the masked input; every other width uses three zero-masked permutes. Permute indices, shift counts and masks come from the tables type. */ + if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m256i permuted1 = _mm256_permutexvar_epi16(tables::get_permute1(), masked); + const __m256i permuted2 = _mm256_permutexvar_epi16(tables::get_permute2(), masked); + + const __m256i shifted1 = _mm256_sllv_epi16(permuted1, tables::get_shift1()); + const __m256i shifted2 = _mm256_srlv_epi16(permuted2, tables::get_shift2()); + + return _mm256_or_si256(shifted1, shifted2); + } else { + const auto [mask1, mask2, mask3] = tables::get_permute_masks(); + + const __m256i permuted1 = _mm256_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); + const __m256i permuted2 = _mm256_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); + const __m256i permuted3 = _mm256_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); + /* two copies are shifted left and one right, per lane, then all three are OR-ed together */ + const __m256i shifted1 = _mm256_sllv_epi16(permuted1, tables::get_shift1()); + const __m256i shifted2 = _mm256_sllv_epi16(permuted2, tables::get_shift2()); + const __m256i shifted3 = _mm256_srlv_epi16(permuted3, tables::get_shift3()); + + return _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); + } + } + } + + /** + * @brief Pack 32 8-bit values for bit widths 1 through 8 using VBMI. + * + * @tparam BIT_WIDTH number of low bits kept from every byte (1 to 8). + * @param input vector of 8-bit lanes; bits above BIT_WIDTH are masked off before packing. + * @return vector holding the packed bits; for BIT_WIDTH == 8 the input is returned unchanged. 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + const __m256i maskv = _mm256_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i masked = _mm256_and_si256(input, maskv); + + if constexpr (BIT_WIDTH == 1) { + return _mm256_set1_epi32( + static_cast(_mm256_cmpgt_epi8_mask(masked, _mm256_setzero_si256()))); + } else if constexpr (BIT_WIDTH == 2) { + const __m256i shifted = _mm256_srli_epi16(masked, 6); + const __m256i combined = _mm256_or_si256(masked, shifted); + + const __m256i shifted2 = _mm256_srli_epi32(combined, 12); + const __m256i combined2 = _mm256_or_si256(shifted2, combined); + + return _mm256_castsi128_si256(_mm256_cvtepi32_epi8(combined2)); + } else if constexpr (BIT_WIDTH == 3) { + const __m256i even = _mm256_and_si256(masked, _mm256_set1_epi16(0x00FF)); + const __m256i odd = _mm256_and_si256(masked, _mm256_set1_epi16(0xFF00)); + + const __m256i pair6 = _mm256_or_si256(even, _mm256_srli_epi16(odd, 5)); + const __m256i packed12 = _mm256_or_si256(pair6, _mm256_srli_epi32(pair6, 10)); + + return m256::mm256_pack_epi16_avx512vbmi_9to16<12>( + _mm256_castsi128_si256(_mm256_cvtepi32_epi16(packed12))); + } else if constexpr (BIT_WIDTH == 4) { + const __m256i shifted = _mm256_srli_epi16(masked, 4); + const __m256i combined = _mm256_or_si256(masked, shifted); + + return _mm256_castsi128_si256(_mm256_cvtepi16_epi8(combined)); + } else { + const __m256i even = _mm256_and_si256(masked, _mm256_set1_epi16(0x00FF)); + const __m256i odd = _mm256_and_si256(masked, _mm256_set1_epi16(0xFF00)); + + const __m256i shifted = _mm256_or_si256(even, _mm256_srli_epi16(odd, 8 - BIT_WIDTH)); + return mm256_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); + } + } + } + + /** + * @brief Pack 8 32-bit values for bit widths 17 through 24 using VBMI. 
+ */ + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i &input) { + using tables = pack_tables_avx512_24; + + const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m256i masked = _mm256_and_si256(input, maskv); + + const __m256i permuted1 = _mm256_permutexvar_epi32(tables::get_permute1(), masked); + const __m256i permuted2 = _mm256_permutexvar_epi32(tables::get_permute2(), masked); + const __m256i permuted3 = _mm256_permutexvar_epi32(tables::get_permute3(), masked); + + const __m256i shifted1 = _mm256_sllv_epi32(permuted1, tables::get_shift1()); + const __m256i shifted2 = _mm256_sllv_epi32(permuted2, tables::get_shift2()); + const __m256i shifted3 = _mm256_srlv_epi32(permuted3, tables::get_shift3()); + + return _mm256_or_si256(_mm256_or_si256(shifted1, shifted2), shifted3); + } + } // namespace m256 + + namespace m512 { + /** + * @brief Pack 32 16-bit values for bit widths 9 through 16 using VBMI. 
+ */ + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = pack_tables_avx512_16; + const __m512i maskv = _mm512_set1_epi16(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i masked = _mm512_and_si512(input, maskv); + + if constexpr (BIT_WIDTH == 12 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m512i permuted1 = _mm512_permutexvar_epi16(tables::get_permute1(), masked); + const __m512i permuted2 = _mm512_permutexvar_epi16(tables::get_permute2(), masked); + + const __m512i shifted1 = _mm512_sllv_epi16(permuted1, tables::get_shift1()); + const __m512i shifted2 = _mm512_srlv_epi16(permuted2, tables::get_shift2()); + + return _mm512_or_si512(shifted1, shifted2); + } else { + const auto [mask1, mask2, mask3] = tables::get_permute_masks(); + + const __m512i permuted1 = _mm512_maskz_permutexvar_epi16(mask1, tables::get_permute1(), masked); + const __m512i permuted2 = _mm512_maskz_permutexvar_epi16(mask2, tables::get_permute2(), masked); + const __m512i permuted3 = _mm512_maskz_permutexvar_epi16(mask3, tables::get_permute3(), masked); + + const __m512i shifted1 = _mm512_sllv_epi16(permuted1, tables::get_shift1()); + const __m512i shifted2 = _mm512_sllv_epi16(permuted2, tables::get_shift2()); + const __m512i shifted3 = _mm512_srlv_epi16(permuted3, tables::get_shift3()); + + return _mm512_or_si512(_mm512_or_si512(shifted1, shifted2), shifted3); + } + } + } + + /** + * @brief Pack 64 8-bit values for bit widths 1 through 8 using VBMI. 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + const __m512i maskv = _mm512_set1_epi8(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i masked = _mm512_and_si512(input, maskv); + + if constexpr (BIT_WIDTH == 1) { + return _mm512_set1_epi64( + static_cast(_mm512_cmpgt_epi8_mask(masked, _mm512_setzero_si512()))); + } else if constexpr (BIT_WIDTH == 2) { + const __m512i shifted = _mm512_srli_epi16(masked, 6); + const __m512i combined = _mm512_or_si512(masked, shifted); + + const __m512i shifted2 = _mm512_srli_epi32(combined, 12); + const __m512i combined2 = _mm512_or_si512(shifted2, combined); + + return _mm512_castsi128_si512(_mm512_cvtepi32_epi8(combined2)); + } else if constexpr (BIT_WIDTH == 3) { + const __m512i even = _mm512_and_si512(masked, _mm512_set1_epi16(0x00FF)); + const __m512i odd = _mm512_and_si512(masked, _mm512_set1_epi16(0xFF00)); + + const __m512i pair6 = _mm512_or_si512(even, _mm512_srli_epi16(odd, 5)); + const __m512i packed12 = _mm512_or_si512(pair6, _mm512_srli_epi32(pair6, 10)); + + return _mm512_castsi256_si512( + m256::mm256_pack_epi16_avx512vbmi_9to16<12>(_mm512_cvtepi32_epi16(packed12))); + } else if constexpr (BIT_WIDTH == 4) { + const __m512i shifted = _mm512_srli_epi16(masked, 4); + const __m512i combined = _mm512_or_si512(masked, shifted); + + return _mm512_castsi256_si512(_mm512_cvtepi16_epi8(combined)); + } else { + const __m512i even = _mm512_and_si512(masked, _mm512_set1_epi16(0x00FF)); + const __m512i odd = _mm512_and_si512(masked, _mm512_set1_epi16(0xFF00)); + + const __m512i shifted = _mm512_or_si512(even, _mm512_srli_epi16(odd, 8 - BIT_WIDTH)); + return mm512_pack_epi16_avx512vbmi_9to16<2 * BIT_WIDTH>(shifted); + } + } + } + + /** + * @brief Pack 16 32-bit values for bit widths 17 through 24 using VBMI. 
+ */ + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i &input) { + using tables = pack_tables_avx512_24; + + const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); + const __m512i masked = _mm512_and_si512(input, maskv); + + const __m512i permuted1 = _mm512_permutexvar_epi32(tables::get_permute1(), masked); + const __m512i permuted2 = _mm512_permutexvar_epi32(tables::get_permute2(), masked); + const __m512i permuted3 = _mm512_permutexvar_epi32(tables::get_permute3(), masked); + + const __m512i shifted1 = _mm512_sllv_epi32(permuted1, tables::get_shift1()); + const __m512i shifted2 = _mm512_sllv_epi32(permuted2, tables::get_shift2()); + const __m512i shifted3 = _mm512_srlv_epi32(permuted3, tables::get_shift3()); + + return _mm512_or_si512(_mm512_or_si512(shifted1, shifted2), shifted3); + } + } // namespace m512 +} // namespace pernix::internal + +#endif // PERNIX_AVX512VBMI_PACKING_H diff --git a/include/pernix/x86/avx512vbmi/tables.h b/src/internal/pernix/x86/avx512vbmi/tables.h similarity index 66% rename from include/pernix/x86/avx512vbmi/tables.h rename to src/internal/pernix/x86/avx512vbmi/tables.h index 9115625..d63a47a 100644 --- a/include/pernix/x86/avx512vbmi/tables.h +++ b/src/internal/pernix/x86/avx512vbmi/tables.h @@ -9,9 +9,8 @@ #include namespace pernix::internal { - template -[[gnu::always_inline]] static inline Vec load_table(const std::array& table) { +static __always_inline Vec load_table(const std::array& table) { static_assert(sizeof(table) >= sizeof(Vec), "table is smaller than requested SIMD vector"); if constexpr (std::is_same_v) { return _mm512_load_si512(static_cast(table.data())); @@ -22,13 +21,13 @@ template } } -template +template requires(N >= 9 && N <= 15) struct pack_tables_avx512_16 { - alignas(64) inline static constexpr std::array permute1 = [] { + alignas(64) inline static constexpr std::array permute1 = [] { // clang-format off 
if constexpr (N == 9) { - return std::array{ + return std::array{ 0, 2, 4, 6, -1, 9, 11, 13, 15, 16, 18, 20, @@ -40,7 +39,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 0, 2, -1, 5, 7, 8, 10, -1, 13, 15, 16, 18, @@ -52,7 +51,7 @@ struct pack_tables_avx512_16 { 45, 47, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 0, 2, 3, 5, 6, -1, 9, -1, 12, -1, 15, 16, @@ -64,7 +63,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 12) { - return std::array{ + return std::array{ 1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15, @@ -76,7 +75,7 @@ struct pack_tables_avx512_16 { 38, 39, 41, 42 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 0, -1, 3, 4, 5, -1, -1, 9, 10, -1, -1, 14, @@ -88,7 +87,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 14) { - return std::array{ + return std::array{ 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, @@ -100,7 +99,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 15) { - return std::array{ + return std::array{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, @@ -112,14 +111,14 @@ struct pack_tables_avx512_16 { 30, 31, -1, -1 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array permute2 = [] { + alignas(64) inline static constexpr std::array permute2 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 1, 3, 5, 7, 8, 10, 12, 14, -1, 17, 19, 21, @@ -131,7 +130,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 1, 3, 4, 6, -1, 9, 11, 12, 14, -1, 17, 19, @@ -143,7 +142,7 @@ struct pack_tables_avx512_16 { 46, -1, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 1, -1, 4, -1, 7, 8, 10, 11, 13, 14, -1, 17, @@ -155,7 +154,7 @@ 
struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 12) { - return std::array{ + return std::array{ 0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, @@ -167,7 +166,7 @@ struct pack_tables_avx512_16 { 37, 38, 40, 41 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 1, 2, -1, -1, 6, 7, 8, -1, 11, 12, 13, -1, @@ -179,7 +178,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 14) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, @@ -191,7 +190,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 15) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -206,10 +205,10 @@ struct pack_tables_avx512_16 { // clang-format on }(); - alignas(64) inline static constexpr std::array permute3 = [] { + alignas(64) inline static constexpr std::array permute3 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ -1, 1, 3, 5, 7, 8, 10, 12, 14, -1, 17, 19, @@ -221,7 +220,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ -1, 1, 3, 4, 6, -1, 9, 11, 12, 14, -1, 17, @@ -233,7 +232,7 @@ struct pack_tables_avx512_16 { 44, 46, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ -1, 1, 2, 4, 5, 7, 8, 10, 11, 13, 14, -1, @@ -245,7 +244,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ -1, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, @@ -257,14 +256,14 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift1 = [] { + alignas(64) inline static constexpr std::array shift1 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 0, 2, 4, 6, 0, 1, 3, 5, 7, 0, 2, 4, @@ 
-276,7 +275,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 0, 4, 0, 2, 6, 0, 4, 0, 2, 6, 0, 4, @@ -288,7 +287,7 @@ struct pack_tables_avx512_16 { 2, 6, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 0, 6, 1, 7, 2, 0, 3, 0, 4, 0, 5, 0, @@ -300,7 +299,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 12) { - return std::array{ + return std::array{ 12, 8, 4, 12, 8, 4, 12, 8, 4, 12, 8, 4, @@ -312,7 +311,7 @@ struct pack_tables_avx512_16 { 8, 4, 12, 8 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 0, 0, 7, 4, 1, 0, 0, 5, 2, 0, 0, 6, @@ -324,7 +323,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 14) { - return std::array{ + return std::array{ 14, 12, 10, 8, 6, 4, 2, 14, 12, 10, 8, 6, @@ -336,7 +335,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 15) { - return std::array{ + return std::array{ 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, @@ -348,14 +347,14 @@ struct pack_tables_avx512_16 { 2, 1, -1, -1 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift2 = [] { + alignas(64) inline static constexpr std::array shift2 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 9, 11, 13, 15, 8, 10, 12, 14, 8, 9, 11, 13, @@ -367,7 +366,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 10, 14, 8, 12, 0, 10, 14, 8, 12, 0, 10, 14, @@ -379,7 +378,7 @@ struct pack_tables_avx512_16 { 12, 0, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 11, 8, 12, 8, 13, 8, 14, 9, 15, 10, 8, 11, @@ -391,7 +390,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 12) { - return std::array{ + return std::array{ 0, 4, 
8, 0, 4, 8, 0, 4, 8, 0, 4, 8, @@ -403,7 +402,7 @@ struct pack_tables_avx512_16 { 4, 8, 0, 4 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 13, 10, 0, 0, 14, 11, 8, 0, 15, 12, 9, 0, @@ -415,7 +414,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 14) { - return std::array{ + return std::array{ 0, 2, 4, 6, 8, 10, 12, 0, 2, 4, 6, 8, @@ -427,7 +426,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 15) { - return std::array{ + return std::array{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -439,14 +438,14 @@ struct pack_tables_avx512_16 { 13, 14, -1, -1 }; } - return std::array{}; + return std::array{}; // clang-format on }(); - alignas(64) inline static constexpr std::array shift3 = [] { + alignas(64) inline static constexpr std::array shift3 = [] { // clang-format off if constexpr (N == 9) { - return std::array{ + return std::array{ 0, 7, 5, 3, 1, 8, 6, 4, 2, 0, 7, 5, @@ -458,7 +457,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 10) { - return std::array{ + return std::array{ 0, 6, 2, 8, 4, 0, 6, 2, 8, 4, 0, 6, @@ -470,7 +469,7 @@ struct pack_tables_avx512_16 { 8, 4, -1, -1 }; } else if constexpr (N == 11) { - return std::array{ + return std::array{ 0, 5, 10, 4, 9, 3, 8, 2, 7, 1, 6, 0, @@ -482,7 +481,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } else if constexpr (N == 13) { - return std::array{ + return std::array{ 0, 3, 6, 9, 12, 2, 5, 8, 11, 1, 4, 7, @@ -494,7 +493,7 @@ struct pack_tables_avx512_16 { -1, -1, -1, -1 }; } - return std::array{}; + return std::array{}; // clang-format on }(); @@ -529,127 +528,136 @@ struct pack_tables_avx512_16 { // clang-format on } - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_permute3() { return load_table(permute3); } + 
static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } - [[gnu::always_inline]] static inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } }; -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v || std::is_same_v)) struct pack_tables_avx512_24 { private: struct word_plan { - int32_t left_index1 = -1; - int32_t left_index2 = -1; - int32_t right_index = -1; - uint32_t left_shift1 = 32; - uint32_t left_shift2 = 32; - uint32_t right_shift = 32; + i32 left_index1 = -1; + i32 left_index2 = -1; + i32 right_index = -1; + u32 left_shift1 = 32; + u32 left_shift2 = 32; + u32 right_shift = 32; }; - static constexpr word_plan create_plan(const uint32_t idx) { + static constexpr word_plan create_plan(const u32 idx) { word_plan plan{}; - const uint32_t word_start = idx * 32u; - const uint32_t word_end = word_start + 32u; + const u32 word_start = idx * 32u; + const u32 word_end = word_start + 32u; - uint32_t left_slot = 0; - for (uint32_t input_lane = 0; input_lane < 16; ++input_lane) { - const uint32_t input_start = input_lane * BIT_WIDTH; - const uint32_t input_end = input_start + BIT_WIDTH; + u32 left_slot = 0; + for (u32 input_lane = 0; input_lane < 16; ++input_lane) { + const u32 input_start = input_lane * BIT_WIDTH; + const u32 input_end = input_start + BIT_WIDTH; - const uint32_t overlap_start = std::max(word_start, input_start); - const uint32_t overlap_end = 
std::min(word_end, input_end); + const u32 overlap_start = std::max(word_start, input_start); + const u32 overlap_end = std::min(word_end, input_end); if (overlap_start >= overlap_end) { continue; } - const auto output_bit = static_cast(overlap_start - word_start); - const auto input_bit = static_cast(overlap_start - input_start); - const int32_t delta = output_bit - input_bit; + const auto output_bit = static_cast(overlap_start - word_start); + const auto input_bit = static_cast(overlap_start - input_start); + const i32 delta = output_bit - input_bit; if (delta >= 0) { if (left_slot == 0) { - plan.left_index1 = static_cast(input_lane); - plan.left_shift1 = static_cast(delta); + plan.left_index1 = static_cast(input_lane); + plan.left_shift1 = static_cast(delta); ++left_slot; } else { - plan.left_index2 = static_cast(input_lane); - plan.left_shift2 = static_cast(delta); + plan.left_index2 = static_cast(input_lane); + plan.left_shift2 = static_cast(delta); } } else { - plan.right_index = static_cast(input_lane); - plan.right_shift = static_cast(-delta); + plan.right_index = static_cast(input_lane); + plan.right_shift = static_cast(-delta); } } return plan; } - static inline constexpr std::array word_plans = [] { + static constexpr std::array word_plans = [] { std::array plans{}; - for (uint32_t i = 0; i < 16; ++i) { + for (u32 i = 0; i < 16; ++i) { plans[i] = create_plan(i); } return plans; }(); template - [[gnu::always_inline]] static constexpr std::array make_table(Getter getter) { + static __always_inline constexpr std::array make_table(Getter getter) { std::array values{}; - for (uint32_t i = 0; i < 16; ++i) { + for (u32 i = 0; i < 16; ++i) { values[i] = getter(word_plans[i]); } return values; } - alignas(64) static inline constexpr auto permute1 = make_table([](const word_plan& p) { return p.left_index1; }); + alignas(64) static constexpr auto permute1 = make_table([](const word_plan& p) { + return p.left_index1; + }); - alignas(64) static inline constexpr auto 
permute2 = make_table([](const word_plan& p) { return p.left_index2; }); + alignas(64) static constexpr auto permute2 = make_table([](const word_plan& p) { + return p.left_index2; + }); - alignas(64) static inline constexpr auto permute3 = make_table([](const word_plan& p) { return p.right_index; }); + alignas(64) static constexpr auto permute3 = make_table([](const word_plan& p) { + return p.right_index; + }); - alignas(64) static inline constexpr auto shift1 = make_table([](const word_plan& p) { return p.left_shift1; }); + alignas(64) static constexpr auto shift1 = make_table( + [](const word_plan& p) { return p.left_shift1; }); - alignas(64) static inline constexpr auto shift2 = make_table([](const word_plan& p) { return p.left_shift2; }); + alignas(64) static constexpr auto shift2 = make_table( + [](const word_plan& p) { return p.left_shift2; }); - alignas(64) static inline constexpr auto shift3 = make_table([](const word_plan& p) { return p.right_shift; }); + alignas(64) static constexpr auto shift3 = make_table( + [](const word_plan& p) { return p.right_shift; }); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_permute3() { return load_table(permute3); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } - [[gnu::always_inline]] static inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return 
load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } }; -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && (std::is_same_v || std::is_same_v || std::is_same_v)) struct unpack_tables_avx512_8 { private: - alignas(64) inline static constexpr std::array permute1 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array permute1 = [] { + std::array table{}; std::ranges::fill(table, -1); for (size_t entry = 0; entry < 64; ++entry) { const size_t bit_start = entry * BIT_WIDTH; const size_t first_byte = bit_start / 8; - table[entry] = static_cast(first_byte); + table[entry] = static_cast(first_byte); } return table; }(); - alignas(64) inline static constexpr std::array permute2 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array permute2 = [] { + std::array table{}; std::ranges::fill(table, -1); for (size_t entry = 0; entry < 64; ++entry) { @@ -658,55 +666,55 @@ struct unpack_tables_avx512_8 { const size_t bit_offset = bit_start % 8; if (bit_offset + BIT_WIDTH > 8) { - table[entry] = static_cast(first_byte + 1); + table[entry] = static_cast(first_byte + 1); } } return table; }(); - alignas(64) inline static constexpr std::array shift1 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift1 = [] { + std::array table{}; for (size_t entry = 0; entry < 64; ++entry) { const size_t bit_start = entry * BIT_WIDTH; const size_t bit_offset = bit_start % 8; - table[entry] = static_cast(bit_offset); + table[entry] = static_cast(bit_offset); } return table; }(); - alignas(64) inline static constexpr std::array shift2 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift2 = [] { + std::array table{}; for (size_t entry = 0; entry < 64; ++entry) { const size_t bit_start = entry * BIT_WIDTH; const size_t bit_offset = bit_start % 8u; const size_t spill_bits = (bit_offset + BIT_WIDTH > 8u) ? 
(bit_offset + BIT_WIDTH - 8u) : 0u; - table[entry] = spill_bits ? static_cast(8 - bit_offset) : 0; + table[entry] = spill_bits ? static_cast(8 - bit_offset) : 0; } return table; }(); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } }; -template +template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && (std::is_same_v || std::is_same_v || std::is_same_v)) struct unpack_tables_avx512_16 { private: - alignas(64) inline static constexpr std::array permute1 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array permute1 = [] { + std::array table{}; std::ranges::fill(table, -1); for (size_t entry = 0; entry < 32; ++entry) { @@ -714,15 +722,15 @@ struct unpack_tables_avx512_16 { const size_t first_byte = bit_start / 8; const size_t base = entry * 2; - table[base] = static_cast(first_byte); - table[base + 1] = static_cast(first_byte + 1); + table[base] = static_cast(first_byte); + table[base + 1] = static_cast(first_byte + 1); } return table; }(); - alignas(64) inline static constexpr std::array permute2 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array permute2 = [] { + std::array table{}; std::ranges::fill(table, -1); for (size_t entry = 0; entry < 32; ++entry) { @@ -732,56 +740,56 @@ struct unpack_tables_avx512_16 { const size_t base = entry * 2; if (bit_offset + BIT_WIDTH > 16) { - table[base] = static_cast(first_byte 
+ 2); + table[base] = static_cast(first_byte + 2); } } return table; }(); - alignas(64) inline static constexpr std::array shift1 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift1 = [] { + std::array table{}; for (size_t entry = 0; entry < 32; ++entry) { const size_t bit_start = entry * BIT_WIDTH; const size_t bit_offset = bit_start % 8u; // Right-shift the 16-bit chunk so the value starts at bit 0. - table[entry] = static_cast(bit_offset); + table[entry] = static_cast(bit_offset); } return table; }(); - alignas(64) inline static constexpr std::array shift2 = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift2 = [] { + std::array table{}; for (size_t entry = 0; entry < 32; ++entry) { const size_t bit_start = entry * BIT_WIDTH; const size_t bit_offset = bit_start % 8u; const size_t spill_bits = (bit_offset + BIT_WIDTH > 16u) ? (bit_offset + BIT_WIDTH - 16u) : 0u; // Move spill bits from byte3 to their final bit positions before merge. - table[entry] = spill_bits ? static_cast(16u - bit_offset) : 0; + table[entry] = spill_bits ? 
static_cast(16u - bit_offset) : 0; } return table; }(); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } }; -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && (std::is_same_v || std::is_same_v || std::is_same_v)) struct unpack_tables_avx512_24 { private: - alignas(64) inline static constexpr std::array permute = [] { - std::array table{}; + alignas(64) inline static constexpr std::array permute = [] { + std::array table{}; std::ranges::fill(table, -1); for (size_t entry = 0; entry < 16; ++entry) { @@ -794,29 +802,28 @@ struct unpack_tables_avx512_24 { const size_t base = entry * 4; for (size_t byte = first_byte; byte <= last_byte; ++byte) { - table[base + (byte - first_byte)] = static_cast(byte); + table[base + (byte - first_byte)] = static_cast(byte); } } return table; }(); - alignas(64) inline static constexpr std::array shift = [] { - std::array table{}; + alignas(64) inline static constexpr std::array shift = [] { + std::array table{}; for (size_t entry = 0; entry < 16; ++entry) { const size_t bit_start = entry * BIT_WIDTH; - table[entry] = static_cast(32u - BIT_WIDTH - (bit_start % 8u)); + table[entry] = static_cast(32u - BIT_WIDTH - (bit_start % 8u)); } return table; }(); public: - [[gnu::always_inline]] static inline Vec get_permute() { return load_table(permute); } - [[gnu::always_inline]] static inline Vec get_shift() { return 
load_table(shift); } + static __always_inline Vec get_permute() { return load_table(permute); } + static __always_inline Vec get_shift() { return load_table(shift); } }; - -} // namespace pernix::internal +} // namespace pernix::internal #endif // PERNIX_AVX512VBMI_TABLES_H diff --git a/src/internal/pernix/x86/avx512vbmi/unpacking.h b/src/internal/pernix/x86/avx512vbmi/unpacking.h new file mode 100644 index 0000000..e66f9ec --- /dev/null +++ b/src/internal/pernix/x86/avx512vbmi/unpacking.h @@ -0,0 +1,500 @@ +#ifndef PERNIX_AVX512VBMI_UNPACKING_H +#define PERNIX_AVX512VBMI_UNPACKING_H + +#include +#include + +namespace pernix::internal { + namespace m128 { + constexpr __mmask16 kAlternateByteMask16 = 0xAAAAULL; + +__always_inline static __m128i _mm_srlv_epi8(const __m128i a, const __m128i count) { + const __m128i mask = _mm_set1_epi16(0x00ff); + const __m128i low_half = _mm_srlv_epi16(_mm_and_si128(mask, a), _mm_and_si128(mask, count)); + const __m128i high_half = _mm_srlv_epi16(a, _mm_srli_epi16(count, 8)); + return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); + } + +__always_inline static __m128i _mm_sllv_epi8(const __m128i a, const __m128i count) { + const __m128i mask = _mm_set1_epi16(0xff00); + const __m128i low_half = _mm_sllv_epi16(a, _mm_andnot_si128(mask, count)); + const __m128i high_half = _mm_sllv_epi16(_mm_and_si128(mask, a), _mm_srli_epi16(count, 8)); + return _mm_mask_blend_epi8(kAlternateByteMask16, low_half, high_half); + } + +__always_inline static __m128i _mm_slli_epi8(const __m128i a, const i8 imm8) { + return _mm_sllv_epi8(a, _mm_set1_epi8(imm8)); + } + +__always_inline static __m128i _mm_srli_epi8(const __m128i a, const int imm8) { + const __m128i lo_mask = _mm_set1_epi16(0x00ff); + const __m128i hi_mask = _mm_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m128i lo = _mm_srl_epi16(_mm_and_si128(a, lo_mask), shift); + const __m128i hi = _mm_and_si128(_mm_srl_epi16(a, shift), hi_mask); + + 
return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); + } + +__always_inline static __m128i _mm_srai_epi8(const __m128i a, const i8 imm8) { + const __m128i lo_mask = _mm_set1_epi16(0x00ff); + const __m128i hi_mask = _mm_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m128i hi = _mm_and_si128(_mm_sra_epi16(a, shift), hi_mask); + + const __m128i lo_as_hi = _mm_slli_epi16(_mm_and_si128(a, lo_mask), 8); + const __m128i lo = _mm_and_si128(_mm_srli_epi16(_mm_sra_epi16(lo_as_hi, shift), 8), lo_mask); + + return _mm_mask_blend_epi8(kAlternateByteMask16, lo, hi); + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m128i mm_unpack_epi8_avx512vbmi_1to8(const __m128i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + if constexpr (BIT_WIDTH == 1) { + const auto value = static_cast<__mmask16>(_mm_cvtsi128_si64(input)); + const __m128i source = _mm_movm_epi8(value); + const __m128i unpacked = _mm_abs_epi8(source); + return unpacked; + } else if constexpr (BIT_WIDTH == 2) { + __m128i values_shift0 = input; + __m128i values_shift2 = _mm_srli_epi16(values_shift0, 2); + const __m128i values_shift4 = _mm_srli_epi16(values_shift0, 4); + const __m128i values_shift6 = _mm_srli_epi16(values_shift0, 6); + + __m128i interleave_tmp = _mm_unpacklo_epi8(values_shift0, values_shift2); + values_shift0 = _mm_unpackhi_epi8(values_shift0, values_shift2); + values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); + + interleave_tmp = _mm_unpacklo_epi8(values_shift4, values_shift6); + values_shift2 = _mm_unpackhi_epi8(values_shift4, values_shift6); + values_shift2 = _mm_unpacklo_epi64(interleave_tmp, values_shift2); + + interleave_tmp = _mm_unpacklo_epi16(values_shift0, values_shift2); + values_shift0 = _mm_unpackhi_epi16(values_shift0, values_shift2); + values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); + values_shift0 = _mm_shuffle_epi32(values_shift0, 0xD8); + + values_shift0 = 
_mm_and_si128(values_shift0, _mm_set1_epi16(0x0303)); + + return values_shift0; + } else if constexpr (BIT_WIDTH == 4) { + __m128i values_shift0 = input; + const __m128i values_shift4 = _mm_srli_epi16(values_shift0, 4); + + const __m128i interleave_tmp = _mm_unpacklo_epi8(values_shift0, values_shift4); + values_shift0 = _mm_unpackhi_epi8(values_shift0, values_shift4); + values_shift0 = _mm_unpacklo_epi64(interleave_tmp, values_shift0); + values_shift0 = _mm_shuffle_epi32(values_shift0, 0xD8); + + values_shift0 = _mm_and_si128(values_shift0, _mm_set1_epi16(0x0F0F)); + + return values_shift0; + } else { + using tables = unpack_tables_avx512_8; + + const __m128i permuted1 = _mm_permutexvar_epi8(tables::get_permute1(), input); + const __m128i permuted2 = _mm_permutexvar_epi8(tables::get_permute2(), input); + + const __m128i shifted1 = _mm_srlv_epi8(permuted1, tables::get_shift1()); + const __m128i shifted2 = _mm_sllv_epi8(permuted2, tables::get_shift2()); + + const __mmask16 spill_mask = _mm_cmpneq_epi8_mask(tables::get_shift2(), _mm_setzero_si128()); + __m128i combined = _mm_or_si128(shifted1, _mm_maskz_mov_epi8(spill_mask, shifted2)); + + constexpr u32 shift = 8 - BIT_WIDTH; + combined = _mm_slli_epi8(combined, shift); + if (SIGN_VALUES) { + combined = _mm_srai_epi8(combined, shift); + } else { + combined = _mm_srli_epi8(combined, shift); + } + + return combined; + } + } + } + + template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline __m128i mm_unpack_epi16_avx512vbmi_9to16(const __m128i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = unpack_tables_avx512_16; + + const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute1(), input); + + __m128i shifted = _mm_srlv_epi16(permuted, tables::get_shift1()); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m128i permuted2 = _mm_permutexvar_epi8(tables::get_permute2(), input); + const __m128i shifted2 = 
_mm_sllv_epi16(permuted2, tables::get_shift2()); + shifted = _mm_or_si128(shifted, shifted2); + } + + constexpr u32 shift = 16 - BIT_WIDTH; + shifted = _mm_slli_epi16(shifted, shift); + if (SIGN_VALUES) { + shifted = _mm_srai_epi16(shifted, shift); + } else { + shifted = _mm_srli_epi16(shifted, shift); + } + + return shifted; + } + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m128i mm_unpack_epi32_avx512vbmi_17to24(const __m128i &input) { + using tables = unpack_tables_avx512_24; + + const __m128i permuted = _mm_permutexvar_epi8(tables::get_permute(), input); + + constexpr u32 shift = 32 - BIT_WIDTH; + __m128i shifted = _mm_sllv_epi32(permuted, tables::get_shift()); + if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { + shifted = _mm_srai_epi32(shifted, shift); + } else { + shifted = _mm_srli_epi32(shifted, shift); + } + + return shifted; + } + } // namespace m128 + + namespace m256 { + constexpr __mmask32 kAlternateByteMask32 = 0xAAAAAAAAULL; + +__always_inline static __m256i _mm256_srlv_epi8(const __m256i a, const __m256i count) { + const __m256i mask = _mm256_set1_epi16(0x00ff); + const __m256i low_half = _mm256_srlv_epi16(_mm256_and_si256(mask, a), _mm256_and_si256(mask, count)); + const __m256i high_half = _mm256_srlv_epi16(a, _mm256_srli_epi16(count, 8)); + return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); + } + +__always_inline static __m256i _mm256_sllv_epi8(const __m256i a, const __m256i count) { + const __m256i mask = _mm256_set1_epi16(0xff00); + const __m256i low_half = _mm256_sllv_epi16(a, _mm256_andnot_si256(mask, count)); + const __m256i high_half = _mm256_sllv_epi16(_mm256_and_si256(mask, a), _mm256_srli_epi16(count, 8)); + return _mm256_mask_blend_epi8(kAlternateByteMask32, low_half, high_half); + } + +__always_inline static __m256i _mm256_slli_epi8(const __m256i a, const i8 imm8) { + return _mm256_sllv_epi8(a, _mm256_set1_epi8(imm8)); + } + +__always_inline static __m256i _mm256_srli_epi8(const 
__m256i a, const i8 imm8) { + const __m256i lo_mask = _mm256_set1_epi16(0x00ff); + const __m256i hi_mask = _mm256_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m256i lo = _mm256_srl_epi16(_mm256_and_si256(a, lo_mask), shift); + const __m256i hi = _mm256_and_si256(_mm256_srl_epi16(a, shift), hi_mask); + + return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); + } + +__always_inline static __m256i _mm256_srai_epi8(const __m256i a, const i8 imm8) { + const __m256i lo_mask = _mm256_set1_epi16(0x00ff); + const __m256i hi_mask = _mm256_set1_epi16(0xff00); + const __m128i shift = _mm_cvtsi32_si128(imm8); + + const __m256i hi = _mm256_and_si256(_mm256_sra_epi16(a, shift), hi_mask); + + const __m256i lo_as_hi = _mm256_slli_epi16(_mm256_and_si256(a, lo_mask), 8); + const __m256i lo = _mm256_and_si256(_mm256_srli_epi16(_mm256_sra_epi16(lo_as_hi, shift), 8), lo_mask); + + return _mm256_mask_blend_epi8(kAlternateByteMask32, lo, hi); + } + + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline __m256i mm256_unpack_epi8_avx512vbmi_1to8(const __m256i &input) { + if constexpr (BIT_WIDTH == 8) { + return input; + } else { + if constexpr (BIT_WIDTH == 1) { + const auto value = static_cast<__mmask32>(_mm_cvtsi128_si64(_mm256_castsi256_si128(input))); + const __m256i source = _mm256_movm_epi8(value); + const __m256i unpacked = _mm256_abs_epi8(source); + return unpacked; + } else if constexpr (BIT_WIDTH == 2) { + __m256i values_shift0 = input; + __m256i values_shift2 = _mm256_srli_epi16(values_shift0, 2); + const __m256i values_shift4 = _mm256_srli_epi16(values_shift0, 4); + const __m256i values_shift6 = _mm256_srli_epi16(values_shift0, 6); + + __m256i interleave_tmp = _mm256_unpacklo_epi8(values_shift0, values_shift2); + values_shift0 = _mm256_unpackhi_epi8(values_shift0, values_shift2); + values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00000000); + + interleave_tmp = _mm256_unpacklo_epi8(values_shift4, 
values_shift6); + values_shift2 = _mm256_unpackhi_epi8(values_shift4, values_shift6); + values_shift2 = _mm256_shuffle_i32x4(interleave_tmp, values_shift2, 0b00000000); + + interleave_tmp = _mm256_unpacklo_epi16(values_shift0, values_shift2); + values_shift0 = _mm256_unpackhi_epi16(values_shift0, values_shift2); + values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00); + values_shift0 = _mm256_shuffle_i32x4(values_shift0, values_shift0, 0b00); + + values_shift0 = _mm256_and_si256(values_shift0, _mm256_set1_epi16(0x0303)); + + return values_shift0; + } else if constexpr (BIT_WIDTH == 4) { + __m256i values_shift0 = input; + const __m256i values_shift4 = _mm256_srli_epi16(values_shift0, 4); + + __m256i interleave_tmp = _mm256_unpacklo_epi8(values_shift0, values_shift4); + values_shift0 = _mm256_unpackhi_epi8(values_shift0, values_shift4); + values_shift0 = _mm256_shuffle_i32x4(interleave_tmp, values_shift0, 0b00); + values_shift0 = _mm256_shuffle_i32x4(values_shift0, values_shift0, 0b00); + + values_shift0 = _mm256_and_si256(values_shift0, _mm256_set1_epi16(0x0F0F)); + + return values_shift0; + } else { + using tables = unpack_tables_avx512_8; + + const __m256i permuted1 = _mm256_permutexvar_epi8(tables::get_permute1(), input); + const __m256i permuted2 = _mm256_permutexvar_epi8(tables::get_permute2(), input); + + const __m256i shifted1 = _mm256_srlv_epi8(permuted1, tables::get_shift1()); + const __m256i shifted2 = _mm256_sllv_epi8(permuted2, tables::get_shift2()); + + const __mmask32 spill_mask = _mm256_cmpneq_epi8_mask(tables::get_shift2(), _mm256_setzero_si256()); + __m256i combined = _mm256_or_si256(shifted1, _mm256_maskz_mov_epi8(spill_mask, shifted2)); + + constexpr u32 shift = 8 - BIT_WIDTH; + combined = _mm256_slli_epi8(combined, shift); + if (SIGN_VALUES) { + combined = _mm256_srai_epi8(combined, shift); + } else { + combined = _mm256_srli_epi8(combined, shift); + } + + return combined; + } + } + } + + template + requires(BIT_WIDTH >= 9 && 
BIT_WIDTH <= 16) +__always_inline __m256i mm256_unpack_epi16_avx512vbmi_9to16(const __m256i &input) { + if constexpr (BIT_WIDTH == 16) { + return input; + } else { + using tables = unpack_tables_avx512_16; + + const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute1(), input); + + __m256i shifted = _mm256_srlv_epi16(permuted, tables::get_shift1()); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const __m256i permuted2 = _mm256_permutexvar_epi8(tables::get_permute2(), input); + const __m256i shifted2 = _mm256_sllv_epi16(permuted2, tables::get_shift2()); + shifted = _mm256_or_si256(shifted, shifted2); + } + + constexpr u32 shift = 16 - BIT_WIDTH; + shifted = _mm256_slli_epi16(shifted, shift); + if (SIGN_VALUES) { + shifted = _mm256_srai_epi16(shifted, shift); + } else { + shifted = _mm256_srli_epi16(shifted, shift); + } + + return shifted; + } + } + + template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline __m256i mm256_unpack_epi32_avx512vbmi_17to24(const __m256i &input) { + using tables = unpack_tables_avx512_24; + + const __m256i permuted = _mm256_permutexvar_epi8(tables::get_permute(), input); + + constexpr u32 shift = 32 - BIT_WIDTH; + __m256i shifted = _mm256_sllv_epi32(permuted, tables::get_shift()); + if constexpr (SIGN_VALUES && BIT_WIDTH > 1) { + shifted = _mm256_srai_epi32(shifted, shift); + } else { + shifted = _mm256_srli_epi32(shifted, shift); + } + + return shifted; + } + } // namespace m256 + + namespace m512 { + constexpr __mmask64 kAlternateByteMask64 = 0xAAAAAAAAAAAAAAAAULL; + +__always_inline static __m512i _mm512_srlv_epi8(const __m512i a, const __m512i count) { + const __m512i mask = _mm512_set1_epi16(0x00ff); + const __m512i low_half = _mm512_srlv_epi16(_mm512_and_si512(mask, a), _mm512_and_si512(mask, count)); + const __m512i high_half = _mm512_srlv_epi16(a, _mm512_srli_epi16(count, 8)); + return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half); + } 
+
+        // Per-byte variable left shift. No 8-bit variable shift is used here; instead every
+        // 16-bit lane is shifted twice, once by the count held in its low byte and once by the
+        // count held in its high byte, and the two results are blended byte-wise.
+        __always_inline static __m512i _mm512_sllv_epi8(const __m512i a, const __m512i count) {
+            const __m512i mask = _mm512_set1_epi16(0xff00);
+            // andnot clears the high byte of each count word, leaving the low byte's count.
+            const __m512i low_half = _mm512_sllv_epi16(a, _mm512_andnot_si512(mask, count));
+            // The low data byte is cleared first so its bits cannot be shifted into the high byte.
+            const __m512i high_half = _mm512_sllv_epi16(_mm512_and_si512(mask, a), _mm512_srli_epi16(count, 8));
+            // Bytes whose mask bit is set come from high_half, the others from low_half.
+            return _mm512_mask_blend_epi8(kAlternateByteMask64, low_half, high_half);
+        }
+
+        // Per-byte left shift by one count shared by all bytes; forwards to _mm512_sllv_epi8
+        // with the count broadcast to every byte.
+        __always_inline static __m512i _mm512_slli_epi8(const __m512i a, const i8 imm8) {
+            return _mm512_sllv_epi8(a, _mm512_set1_epi8(imm8));
+        }
+
+        // Per-byte logical right shift by one count shared by all bytes.
+        __always_inline static __m512i _mm512_srli_epi8(const __m512i a, const i8 imm8) {
+            const __m512i lo_mask = _mm512_set1_epi16(0x00ff);
+            const __m512i hi_mask = _mm512_set1_epi16(0xff00);
+            const __m128i shift = _mm_cvtsi32_si128(imm8);
+
+            // Low bytes: isolate them before shifting so no high-byte bits shift down into them.
+            const __m512i lo = _mm512_srl_epi16(_mm512_and_si512(a, lo_mask), shift);
+            // High bytes: shift the whole word, then keep only the high byte of the result.
+            const __m512i hi = _mm512_and_si512(_mm512_srl_epi16(a, shift), hi_mask);
+
+            return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi);
+        }
+
+        // Per-byte arithmetic right shift by one count shared by all bytes.
+        __always_inline static __m512i _mm512_srai_epi8(const __m512i a, const i8 imm8) {
+            const __m512i lo_mask = _mm512_set1_epi16(0x00ff);
+            const __m512i hi_mask = _mm512_set1_epi16(0xff00);
+            const __m128i shift = _mm_cvtsi32_si128(imm8);
+
+            // High bytes: the 16-bit arithmetic shift already replicates the byte's sign bit.
+            const __m512i hi = _mm512_and_si512(_mm512_sra_epi16(a, shift), hi_mask);
+
+            // Low bytes: move each into the high byte position so its top bit becomes the word's
+            // sign bit, shift arithmetically, then move the result back down.
+            const __m512i lo_as_hi = _mm512_slli_epi16(_mm512_and_si512(a, lo_mask), 8);
+            const __m512i lo = _mm512_and_si512(_mm512_srli_epi16(_mm512_sra_epi16(lo_as_hi, shift), 8), lo_mask);
+
+            return _mm512_mask_blend_epi8(kAlternateByteMask64, lo, hi);
+        }
+
+        // Unpack BIT_WIDTH-bit (1..8) packed values into one byte per value.
+        // BIT_WIDTH 8 returns the input unchanged; 1, 2 and 4 use dedicated shuffle sequences;
+        // the remaining widths go through the permute/shift lookup tables.
+        template
+            requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8)
+        __always_inline __m512i mm512_unpack_epi8_avx512vbmi_1to8(const __m512i &input) {
+            if constexpr (BIT_WIDTH == 8) {
+                return input;
+            } else {
+                if constexpr (BIT_WIDTH == 1) {
+                    // The low 64 input bits are used as a byte mask: movm yields 0xFF for every
+                    // set bit, and abs turns that -1 into 1.
+                    const auto value = static_cast<__mmask64>(_mm_cvtsi128_si64(_mm512_castsi512_si128(input)));
+                    const __m512i source = _mm512_movm_epi8(value);
+                    const __m512i unpacked = _mm512_abs_epi8(source);
+                    return unpacked;
+                } else if constexpr (BIT_WIDTH == 2) {
+                    // Four copies of the input, shifted right by 0/2/4/6 bits, are interleaved so
+                    // that every output byte carries one field in its low two bits; the final AND
+                    // with 0x03 per byte discards the neighbouring fields.
+                    __m512i values_shift0 = input;
+                    __m512i values_shift2 = _mm512_srli_epi16(values_shift0, 2);
+                    const __m512i values_shift4 = _mm512_srli_epi16(values_shift0, 4);
+                    const __m512i values_shift6 = _mm512_srli_epi16(values_shift0, 6);
+
+                    __m512i interleave_tmp = _mm512_unpacklo_epi8(values_shift0, values_shift2);
+                    values_shift0 = _mm512_unpackhi_epi8(values_shift0, values_shift2);
+                    values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0b00000000);
+
+                    interleave_tmp = _mm512_unpacklo_epi8(values_shift4, values_shift6);
+                    values_shift2 = _mm512_unpackhi_epi8(values_shift4, values_shift6);
+                    values_shift2 = _mm512_shuffle_i32x4(interleave_tmp, values_shift2, 0b00000000);
+
+                    interleave_tmp = _mm512_unpacklo_epi16(values_shift0, values_shift2);
+                    values_shift0 = _mm512_unpackhi_epi16(values_shift0, values_shift2);
+                    values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0x88);
+                    values_shift0 = _mm512_shuffle_i32x4(values_shift0, values_shift0, 0xD8);
+
+                    values_shift0 = _mm512_and_si512(values_shift0, _mm512_set1_epi16(0x0303));
+
+                    return values_shift0;
+                } else if constexpr (BIT_WIDTH == 4) {
+                    // Same idea with two copies (shift 0 and 4); the AND with 0x0F per byte keeps
+                    // one nibble per output byte.
+                    __m512i values_shift0 = input;
+                    const __m512i values_shift4 = _mm512_srli_epi16(values_shift0, 4);
+
+                    __m512i interleave_tmp = _mm512_unpacklo_epi8(values_shift0, values_shift4);
+                    values_shift0 = _mm512_unpackhi_epi8(values_shift0, values_shift4);
+                    values_shift0 = _mm512_shuffle_i32x4(interleave_tmp, values_shift0, 0x44);
+                    values_shift0 = _mm512_shuffle_i32x4(values_shift0, values_shift0, 0xD8);
+
+                    values_shift0 = _mm512_and_si512(values_shift0, _mm512_set1_epi16(0x0F0F));
+
+                    return values_shift0;
+                } else {
+                    // Table-driven path. The tables are defined outside this file; presumably
+                    // permute1/shift1 fetch the byte holding the start of each value and
+                    // permute2/shift2 the byte it spills into -- confirm against the tables header.
+                    using tables = unpack_tables_avx512_8;
+
+                    const __m512i permuted1 = _mm512_permutexvar_epi8(tables::get_permute1(), input);
+                    const __m512i permuted2 = _mm512_permutexvar_epi8(tables::get_permute2(), input);
+
+                    const __m512i shifted1 = _mm512_srlv_epi8(permuted1, tables::get_shift1());
+                    const __m512i shifted2 = _mm512_sllv_epi8(permuted2, tables::get_shift2());
+
+                    // The second contribution is merged only for bytes whose shift2 entry is non-zero.
+                    const __mmask64 spill_mask = _mm512_cmpneq_epi8_mask(tables::get_shift2(), _mm512_setzero_si512());
+                    __m512i combined = _mm512_or_si512(shifted1, _mm512_maskz_mov_epi8(spill_mask, shifted2));
+
+                    // Shift left then right by (8 - BIT_WIDTH): clears the bits above the value and
+                    // either sign-extends (arithmetic) or zero-extends (logical) it.
+                    constexpr u32 shift = 8 - BIT_WIDTH;
+                    combined = _mm512_slli_epi8(combined, shift);
+                    // NOTE(review): plain `if` on a template parameter; the 17..24-bit unpacker in
+                    // this namespace uses `if constexpr` for the same decision.
+                    if (SIGN_VALUES) {
+                        combined = _mm512_srai_epi8(combined, shift);
+                    } else {
+                        combined = _mm512_srli_epi8(combined, shift);
+                    }
+
+                    return combined;
+                }
+            }
+        }
+
+        // Unpack BIT_WIDTH-bit (9..16) packed values into one 16-bit lane per value.
+        // BIT_WIDTH 16 returns the input unchanged.
+        template
+            requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16)
+        __always_inline __m512i mm512_unpack_epi16_avx512vbmi_9to16(const __m512i &input) {
+            if constexpr (BIT_WIDTH == 16) {
+                return input;
+            } else {
+                using tables = unpack_tables_avx512_16;
+
+                const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute1(), input);
+                __m512i shifted = _mm512_srlv_epi16(permuted, tables::get_shift1());
+
+                // Only these widths merge a second permuted/shifted contribution.
+                if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) {
+                    const __m512i permuted2 = _mm512_permutexvar_epi8(tables::get_permute2(), input);
+                    const __m512i shifted2 = _mm512_sllv_epi16(permuted2, tables::get_shift2());
+                    shifted = _mm512_or_si512(shifted, shifted2);
+                }
+
+                // Left/right shift pair by (16 - BIT_WIDTH) strips the upper bits and extends the value.
+                constexpr u32 shift = 16 - BIT_WIDTH;
+                shifted = _mm512_slli_epi16(shifted, shift);
+                if (SIGN_VALUES) {
+                    shifted = _mm512_srai_epi16(shifted, shift);
+                } else {
+                    shifted = _mm512_srli_epi16(shifted, shift);
+                }
+
+                return shifted;
+            }
+        }
+
+        // Unpack BIT_WIDTH-bit (17..24) packed values into one 32-bit lane per value.
+        template
+            requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24)
+        __always_inline __m512i mm512_unpack_epi32_avx512vbmi_17to24(const __m512i &input) {
+            using tables = unpack_tables_avx512_24;
+
+            const __m512i permuted = _mm512_permutexvar_epi8(tables::get_permute(), input);
+            // The per-lane left shift comes from the table; the right shift below by
+            // (32 - BIT_WIDTH) then brings the value down, arithmetic for signed output and
+            // logical otherwise.
+            __m512i shifted = _mm512_sllv_epi32(permuted, tables::get_shift());
+
+            constexpr u32 shift = 32 - BIT_WIDTH;
+            if constexpr (SIGN_VALUES) {
+                shifted = _mm512_srai_epi32(shifted, shift);
+            } else {
+                shifted = _mm512_srli_epi32(shifted, shift);
+            }
+
+            return shifted;
+        }
+    } // namespace m512
+} // namespace pernix::internal
+
+#endif // PERNIX_AVX512VBMI_UNPACKING_H
diff --git a/src/internal/pernix/x86/bmi2/bmi2_compression.h b/src/internal/pernix/x86/bmi2/bmi2_compression.h
new file mode 100644
index 0000000..e165c3a
--- /dev/null
+++ b/src/internal/pernix/x86/bmi2/bmi2_compression.h
@@ -0,0 +1,323 @@
+#ifndef PERNIX_BMI2_COMPRESSION_H
+#define PERNIX_BMI2_COMPRESSION_H
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace pernix {
+    namespace internal {
+        /**
+         * @brief Build the masks and shift constants used by the BMI2 packers.
+         *
+         * The returned tuple is (mask, pext_mask, shift1, shift2):
+         *  - mask: the low BIT_WIDTH bits set.
+         *  - pext_mask: mask replicated once per byte (widths 1..8), once per 16-bit word
+         *    (widths 9..16) or once per 32-bit word (wider).
+         *  - shift1: 4 * BIT_WIDTH.
+         *  - shift2: 64 - shift1.
+         *
+         * NOTE(review): for BIT_WIDTH > 16, shift1 exceeds 64 and shift2 wraps around in its
+         * u16, so neither is usable as a 64-bit shift count in that range.
+         *
+         * @tparam BIT_WIDTH bit width per packed value.
+         * @return std::tuple mask tuple used by the BMI2 helpers.
+         */
+        template
+            requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32)
+        static constexpr std::tuple pack_avx2_bmi2_constants() {
+            u32 mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U;
+            u64 pext_mask;
+            u16 shift1 = BIT_WIDTH * 4;
+            u16 shift2 = 64 - shift1;
+
+            // Multiplying by a 0x01-per-lane pattern copies the mask into every lane.
+            if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 8) {
+                pext_mask = 0x0101010101010101ULL * mask;
+            } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) {
+                pext_mask = 0x0001000100010001ULL * mask;
+            } else {
+                pext_mask = 0x0000000100000001ULL * mask;
+            }
+
+            return {
+                mask,
+                pext_mask,
+                shift1,
+                shift2,
+            };
+        }
+
+        /**
+         * @brief Pack four 32-bit values with BMI2 extract instructions.
+         *
+         * @tparam BIT_WIDTH bit width per packed value.
+         * @param input SIMD register containing four quantized values.
+         * @return __m128i packed bitstream in the low bytes of the result.
+         */
+        template
+            requires(BIT_WIDTH > 0 && BIT_WIDTH <= 32)
+        static inline auto mm_pack_epi32_bmi2(const __m128i &input) -> __m128i {
+            // Masks and offsets are derived here rather than taken from
+            // pack_avx2_bmi2_constants(): that helper's byte-granular mask (widths 1..8) and its
+            // 4 * BIT_WIDTH shift describe the eight-value packer. With only four values the data
+            // is never narrowed below 16-bit words, and each 64-bit lane of the wide path holds
+            // two values, so the stream offset of the second lane is 2 * BIT_WIDTH.
+            constexpr u64 value_mask = BIT_WIDTH == 32 ? 0xFFFFFFFFULL : (1ULL << BIT_WIDTH) - 1ULL;
+
+            if constexpr (BIT_WIDTH <= 16) {
+                // The saturating 32 -> 16 bit narrowing leaves the four values in the four 16-bit
+                // words of the low quadword, so the extract mask is word-granular for every width
+                // in this branch. A byte-granular mask would also extract bits from the high byte
+                // of each word and produce a stream the four-value unpacker cannot read.
+                constexpr u64 pext_mask = 0x0001000100010001ULL * value_mask;
+
+                const __m128i narrowed = _mm_packs_epi32(input, _mm_setzero_si128());
+                const u64 packed = _pext_u64(static_cast<u64>(_mm_extract_epi64(narrowed, 0)), pext_mask);
+
+                return _mm_set_epi64x(0, static_cast<i64>(packed));
+            } else {
+                // Two values per 64-bit lane; each lane contributes pair_bits bits to the stream.
+                constexpr u64 pext_mask = 0x0000000100000001ULL * value_mask;
+                constexpr u32 pair_bits = BIT_WIDTH * 2; // 34..64
+
+                const u64 low_pair = _pext_u64(static_cast<u64>(_mm_extract_epi64(input, 0)), pext_mask);
+                const u64 high_pair = _pext_u64(static_cast<u64>(_mm_extract_epi64(input, 1)), pext_mask);
+
+                u64 out0 = low_pair;
+                u64 out1 = 0;
+                if constexpr (pair_bits == 64) {
+                    // The first pair fills the low quadword exactly; shifting by 64 would be
+                    // undefined, so the second pair simply becomes the high quadword.
+                    out1 = high_pair;
+                } else {
+                    // Append the second pair directly after the first and carry the overflow
+                    // into the high quadword. Both shift counts are in 1..63 here.
+                    out0 |= high_pair << pair_bits;
+                    out1 = high_pair >> (64 - pair_bits);
+                }
+
+                return _mm_set_epi64x(static_cast<i64>(out1), static_cast<i64>(out0));
+            }
+        }
+
+        /**
+         * @brief Pack eight 32-bit values with BMI2 extract instructions.
+         *
+         * @tparam BIT_WIDTH bit width per packed value.
+         * @param input SIMD register containing eight quantized values.
+         * @return __m256i packed bitstream in the low bytes of the result.
+         */
+        template
+            requires(BIT_WIDTH > 0 && BIT_WIDTH <= 24)
+        static inline auto mm256_pack_epi32_bmi2(const __m256i &input) -> __m256i {
+            // mask is unused below; shift1/shift2 are only read in the 9..16-bit branch.
+            const auto [mask, pext_mask, shift1, shift2] = pack_avx2_bmi2_constants();
+
+            if constexpr (BIT_WIDTH > 0 && BIT_WIDTH <= 8) {
+                // Narrow 32 -> 16 -> 8 bits with signed saturation. The 64-bit permute between the
+                // two narrowing steps gathers both 128-bit halves of the first result into the
+                // low half, so all eight bytes end up in the low quadword.
+                const __m256i packed16 = _mm256_packs_epi32(input, _mm256_setzero_si256());
+                const __m256i permuted = _mm256_permute4x64_epi64(packed16, _MM_SHUFFLE(3, 1, 2, 0));
+                const __m256i packed8 = _mm256_packs_epi16(permuted, _mm256_setzero_si256());
+                const u64 value = _pext_u64(_mm256_extract_epi64(packed8, 0), pext_mask);
+
+                const __m256i result = _mm256_setr_epi64x(static_cast(value), 0, 0, 0);
+                return result;
+            } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) {
+                // After one narrowing step, quadword 0 holds values 0..3 and quadword 2 holds
+                // values 4..7 as 16-bit words; each is extracted to 4 * BIT_WIDTH bits.
+                const __m256i packed16 = _mm256_packs_epi32(input, _mm256_setzero_si256());
+                alignas(16) i64 values[2] = {};
+                values[0] = _pext_u64(_mm256_extract_epi64(packed16, 0), pext_mask);
+
+                const u64 temp_combined = _pext_u64(_mm256_extract_epi64(packed16, 2), pext_mask);
+                values[1] = temp_combined >> shift2;
+                // For BIT_WIDTH == 16 the first group fills the quadword and shift1 is 64, so the
+                // append is skipped rather than shifting by the full width.
+                if constexpr (BIT_WIDTH != 16) {
+                    values[0] |= static_cast(temp_combined << shift1);
+                }
+
+                const __m256i result = _mm256_setr_epi64x(values[0], values[1], 0, 0);
+                return result;
+            } else {
+                constexpr u32 chunk_bits = BIT_WIDTH * 2; // bits extracted per 64-bit lane
+                static_assert(chunk_bits < 64);
+
+                constexpr u64 chunk_mask = (chunk_bits == 64) ? ~u64{0} : ((u64{1} << chunk_bits) - 1);
+
+                // One chunk (two values) per 64-bit lane of the input.
+                const u64 x0 = _pext_u64(_mm256_extract_epi64(input, 0), pext_mask) & chunk_mask;
+                const u64 x1 = _pext_u64(_mm256_extract_epi64(input, 1), pext_mask) & chunk_mask;
+                const u64 x2 = _pext_u64(_mm256_extract_epi64(input, 2), pext_mask) & chunk_mask;
+                const u64 x3 = _pext_u64(_mm256_extract_epi64(input, 3), pext_mask) & chunk_mask;
+
+                u64 out0 = 0;
+                u64 out1 = 0;
+                u64 out2 = 0;
+
+                // ORs a chunk into the 192-bit output (out0..out2) at bit_offset, carrying into
+                // the next word when the chunk crosses a 64-bit boundary. The carry is only taken
+                // when off + chunk_bits > 64, which implies off > 0, so (64 - off) is below 64.
+                auto append_bits = [&](u64 value, u32 bit_offset) {
+                    const u32 word = bit_offset >> 6; // / 64
+                    const u32 off = bit_offset & 63; // % 64
+
+                    if (word == 0) {
+                        out0 |= value << off;
+                        if (off + chunk_bits > 64) {
+                            out1 |= value >> (64 - off);
+                        }
+                    } else if (word == 1) {
+                        out1 |= value << off;
+                        if (off + chunk_bits > 64) {
+                            out2 |= value >> (64 - off);
+                        }
+                    } else {
+                        out2 |= value << off;
+                    }
+                };
+
+                append_bits(x0, 0 * chunk_bits);
+                append_bits(x1, 1 * chunk_bits);
+                append_bits(x2, 2 * chunk_bits);
+                append_bits(x3, 3 * chunk_bits);
+
+                return _mm256_setr_epi64x(static_cast(out0), static_cast(out1),
+                                          static_cast(out2), 0);
+            }
+        }
+    } // namespace internal
+
+    /**
+     * @brief Compress a single block of float values using AVX2 and BMI2 instructions.
+     *
+     * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+     * @tparam BLOCK_SIZE size of the output block in bytes (must be a multiple of 32).
+     *
+     * @param input_ptr pointer to the start of the input float values.
+     * @param scale scaling factor used during quantization.
+     * @param output_ptr pointer to the output buffer where compressed bytes will be stored.
+     * @return int status code (0 for success).
+     *
+     * @note This function requires AVX2 and BMI2 support.
+     */
+    template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0)
+    int mm256_compress_block_bmi2(const void * __restrict__ input_ptr, const f32 scale,
+                                  void * __restrict__ output_ptr) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        // A block holds as many whole values as fit in BLOCK_SIZE bytes; they are processed in
+        // groups of eight, with any leftover (fewer than eight) handled by the scalar tail.
+        constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH;
+        constexpr u32 iterations_8 = elements_per_block / 8;
+        constexpr u8 remaining = elements_per_block - iterations_8 * 8;
+
+        // Zero the whole block first so bytes not covered by packed values are deterministic.
+        std::memset(output, 0, BLOCK_SIZE);
+
+        const __m256 scale_v = _mm256_set1_ps(scale);
+#pragma GCC unroll 4
+        for (u32 iter = 0; iter < iterations_8; iter++) {
+            const __m256 source = _mm256_loadu_ps(input);
+            const __m256i quantized = internal::mm256_quantize_ps_epi32(source, scale_v);
+            const __m256i packed_input = internal::mm256_clamp_signed_epi32 < BIT_WIDTH > (quantized);
+            const __m256i packed = internal::mm256_pack_epi32_bmi2(packed_input);
+            // Eight values of BIT_WIDTH bits occupy exactly BIT_WIDTH bytes.
+            std::memcpy(output, &packed, BIT_WIDTH);
+            input += 8;
+            output += BIT_WIDTH;
+        }
+
+        if constexpr (remaining) {
+            // Scalar tail: quantize and clamp one value at a time, then hand the values to the
+            // fallback packer, which writes at the current output position.
+            std::vector block_values(remaining);
+#pragma GCC unroll 8
+            for (u32 i = 0; i < remaining; i++) {
+                block_values[i] =
+                        static_cast(internal::clamp_signed_quantized < BIT_WIDTH > (
+                            internal::quantize_ps_epi32(input[i], scale)));
+            }
+
+            internal::pack_epi32_fallback < BIT_WIDTH > (block_values, output);
+        }
+
+        return 0;
+    }
+
+    /**
+     * @brief Compress a single block of double values using AVX2 and BMI2 instructions.
+     *
+     * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+     * @param input pointer to the start of the input double values.
+     * @param scale scaling factor used during quantization.
+     * @param output pointer to the output buffer where compressed bytes will be stored.
+     * @return int status code (0 for success).
+     *
+     * @note This function requires AVX2 and BMI2 support.
+     */
+    template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0)
+    int mm256_compress_block_bmi2(const void * __restrict__ input_ptr, const f64 scale,
+                                  void * __restrict__ output_ptr) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        // Whole values per block, split into groups of eight plus a scalar tail.
+        constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH;
+        constexpr u32 iterations_8 = elements_per_block / 8;
+        constexpr u8 remaining = elements_per_block - iterations_8 * 8;
+
+        // Zero the whole block first so bytes not covered by packed values are deterministic.
+        std::memset(output, 0, BLOCK_SIZE);
+
+        const __m256d scale_v = _mm256_set1_pd(scale);
+#pragma GCC unroll 4
+        for (u32 iter = 0; iter < iterations_8; iter++) {
+            // A 256-bit register holds four doubles, so eight inputs take two loads; each half
+            // is quantized to four 32-bit integers and the halves are joined into one register.
+            const __m256d source1 = _mm256_loadu_pd(input);
+            const __m256d source2 = _mm256_loadu_pd(input + 4);
+            const __m128i quantized1 = internal::mm256_quantize_pd_epi32(source1, scale_v);
+            const __m128i quantized2 = internal::mm256_quantize_pd_epi32(source2, scale_v);
+            __m256i combined = _mm256_castsi128_si256(quantized1);
+            combined = _mm256_inserti128_si256(combined, quantized2, 1);
+            const __m256i packed = internal::mm256_pack_epi32_bmi2(
+                internal::mm256_clamp_signed_epi32 < BIT_WIDTH > (combined));
+            // Eight values of BIT_WIDTH bits occupy exactly BIT_WIDTH bytes.
+            std::memcpy(output, &packed, BIT_WIDTH);
+            input += 8;
+            output += BIT_WIDTH;
+        }
+
+        if constexpr (remaining) {
+            // Scalar tail for the values that do not fill a group of eight.
+            // NOTE(review): the tail quantizes through quantize_pd_epi64 while the vector path
+            // uses mm256_quantize_pd_epi32 -- confirm both round and saturate identically.
+            std::vector block_values(remaining);
+#pragma GCC unroll 8
+            for (u32 i = 0; i < remaining; i++) {
+                block_values[i] =
+                        static_cast(internal::clamp_signed_quantized < BIT_WIDTH > (
+                            internal::quantize_pd_epi64(input[i], scale)));
+            }
+
+            internal::pack_epi32_fallback < BIT_WIDTH > (block_values, output);
+        }
+        return 0;
+    }
+
+    /**
+     * @brief Compress multiple 512-bit blocks using AVX2 and BMI2 instructions.
+     *
+     * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+     *
+     * @param input pointer to the start of the input float values.
+     * @param scale scaling factor used during quantization.
+     * @param output pointer to the output buffer where compressed bytes will be stored.
+     * @param blocks number of 512-bit blocks to compress.
+     * @return int status code (0 for success).
+     *
+     * @note This function requires AVX2 and BMI2 support.
+     */
+    template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0)
+    int mm256_compress_blocks_bmi2(const void * __restrict__ input_ptr, const f32 scale,
+                                   void * __restrict__ output_ptr, const u32 blocks) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        const f32 *block_input = input;
+        u8 *block_output = output;
+
+        // Each block consumes (BLOCK_SIZE * 8) / BIT_WIDTH input values and produces exactly
+        // BLOCK_SIZE output bytes. The per-block status is not inspected.
+        for (u32 block = 0; block < blocks; block++) {
+            mm256_compress_block_bmi2(block_input, scale, block_output);
+            block_input += (BLOCK_SIZE * 8) / BIT_WIDTH;
+            block_output += BLOCK_SIZE;
+        }
+
+        return 0;
+    }
+
+    /**
+     * @brief Compress multiple blocks of double values using AVX2 and BMI2 instructions.
+     *
+     * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24).
+     * @param input pointer to the start of the input double values.
+     * @param scale scaling factor used during quantization.
+     * @param output pointer to the output buffer where compressed bytes will be stored.
+     * @param blocks number of blocks to compress.
+     * @return int status code (0 for success).
+     *
+     * @note This function requires AVX2 and BMI2 support.
+     */
+    template
+        requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0)
+    int mm256_compress_blocks_bmi2(const void * __restrict__ input_ptr, const f64 scale,
+                                   void * __restrict__ output_ptr, const u32 blocks) {
+        const auto *input = static_cast(input_ptr);
+        auto *output = static_cast(output_ptr);
+
+        const f64 *block_input = input;
+        u8 *block_output = output;
+
+        // Each block consumes (BLOCK_SIZE * 8) / BIT_WIDTH input values and produces exactly
+        // BLOCK_SIZE output bytes. The per-block status is not inspected.
+        for (u32 block = 0; block < blocks; block++) {
+            mm256_compress_block_bmi2(block_input, scale, block_output);
+            block_input += (BLOCK_SIZE * 8) / BIT_WIDTH;
+            block_output += BLOCK_SIZE;
+        }
+
+        return 0;
+    }
+} // namespace pernix
+
+#endif // PERNIX_BMI2_COMPRESSION_H
diff --git a/src/internal/pernix/x86/bmi2/bmi2_decompression.h b/src/internal/pernix/x86/bmi2/bmi2_decompression.h
new file mode 100644
index 0000000..275a11a
--- /dev/null
+++ b/src/internal/pernix/x86/bmi2/bmi2_decompression.h
@@ -0,0 +1,367 @@
+#ifndef PERNIX_BMI2_DECOMPRESSION_H
+#define PERNIX_BMI2_DECOMPRESSION_H
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace pernix {
+    namespace internal {
+        /**
+         * @brief Sign-extend packed values after BMI2 expansion into 32-bit lanes.
+         *
+         * Each lane is shifted left so that bit (BIT_WIDTH - 1) becomes the lane's sign bit and
+         * then shifted back arithmetically. One-bit values are returned unchanged.
+         *
+         * @tparam BIT_WIDTH original encoded bit width.
+         * @param source register containing unpacked values.
+         * @return __m128i sign-extended values.
+         */
+        template
+            requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24)
+        __m128i mm_sign_extend32(__m128i source) {
+            if constexpr (BIT_WIDTH == 1) {
+                // Keep 1-bit values as 0/1 to match fallback decoding semantics.
+                return source;
+            }
+
+            constexpr u16 shift = 32 - BIT_WIDTH;
+            source = _mm_slli_epi32(source, shift);
+            return _mm_srai_epi32(source, shift);
+        }
+
+        /**
+         * @brief Sign-extend packed values after BMI2 expansion into eight 32-bit lanes.
+         *
+         * @tparam BIT_WIDTH original encoded bit width.
+         * @param source register containing unpacked values.
+         * @return __m256i sign-extended values.
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __m256i mm256_sign_extend32(__m256i source) { + if constexpr (BIT_WIDTH == 1) { + // Keep 1-bit values as 0/1 to match fallback decoding semantics. + return source; + } + + constexpr u16 shift = 32 - BIT_WIDTH; + source = _mm256_slli_epi32(source, shift); + return _mm256_srai_epi32(source, shift); + } + + /** + * @brief Unpack four values from a BMI2-packed input buffer. + * + * @tparam BIT_WIDTH bit width per packed value. + * @tparam SIGN_VALUES whether to sign-extend the unpacked values. + * @param input pointer to the packed input buffer. + * @return __m128i unpacked values. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH > 0 && BIT_WIDTH <= 24) + __m128i mm_unpack_epi32_bmi2(const u8 * __restrict__ input) { + constexpr u32 mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; + constexpr std::size_t packed_bytes = (4 * BIT_WIDTH + 7) / 8; + + __m128i result; + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + constexpr u64 pdep_mask = 0x0101010101010101ULL * mask; + + u32 temp_value = 0; + std::memcpy(&temp_value, input, packed_bytes); + + const i32 value = _pdep_u32(temp_value, static_cast(pdep_mask)); + const __m128i source = _mm_insert_epi32(_mm_setzero_si128(), value, 0); + + result = _mm_cvtepi8_epi32(source); + } else if constexpr (BIT_WIDTH == 16) { + const __m128i source = _mm_loadu_si64(input); + result = _mm_cvtepi16_epi32(source); + } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { + constexpr u64 pdep_mask = 0x0001000100010001ULL * mask; + + u64 temp_value = 0; + std::memcpy(&temp_value, input, packed_bytes); + + const i64 value = _pdep_u64(temp_value, pdep_mask); + const __m128i source = _mm_insert_epi64(_mm_setzero_si128(), value, 0); + + result = _mm_cvtepi16_epi32(source); + } else { + constexpr u64 pdep_mask = 0x0000000100000001ULL * mask; + constexpr u32 shift1 = BIT_WIDTH * 2; + constexpr u32 shift2 = 64 - shift1; + + alignas(16) u64 
temp_values[2]{}; + std::memcpy(temp_values, input, packed_bytes); + + alignas(16) i64 values[2]; + values[0] = _pdep_u64(temp_values[0], pdep_mask); + values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); + + result = _mm_set_epi64x(values[1], values[0]); + } + + if constexpr (SIGN_VALUES) { + result = internal::mm_sign_extend32(result); + } + return result; + } + + /** + * @brief Unpack eight values from a BMI2-packed input buffer. + * + * @tparam BIT_WIDTH bit width per packed value. + * @tparam SIGN_VALUES whether to sign-extend the unpacked values. + * @param input pointer to the packed input buffer. + * @return __m256i unpacked values. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) + __m256i mm256_unpack_epi32_bmi2(const u8 * __restrict__ input) { + constexpr u32 mask = BIT_WIDTH == 32 ? std::numeric_limits::max() : (1ULL << BIT_WIDTH) - 1U; + constexpr std::size_t packed_bytes = BIT_WIDTH; + + __m256i result; + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + constexpr u64 pdep_mask = 0x0101010101010101ULL * mask; + + u64 temp_value = 0; + std::memcpy(&temp_value, input, packed_bytes); + + const i64 value = _pdep_u64(temp_value, pdep_mask); + const __m128i source = _mm_insert_epi64(_mm_setzero_si128(), value, 0); + + result = _mm256_cvtepi8_epi32(source); + } else if constexpr (BIT_WIDTH == 16) { + const __m128i source = _mm_loadu_si128(reinterpret_cast(input)); + result = _mm256_cvtepi16_epi32(source); + } else if constexpr (BIT_WIDTH > 8 && BIT_WIDTH <= 16) { + constexpr u64 pdep_mask = 0x0001000100010001ULL * mask; + constexpr u64 shift1 = BIT_WIDTH * 4; + constexpr u64 shift2 = 64 - shift1; + + alignas(16) u64 temp_values[2]{}; + std::memcpy(temp_values, input, packed_bytes); + + alignas(16) i64 values[2]; + values[0] = _pdep_u64(temp_values[0], pdep_mask); + values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); + + const __m128i source = _mm_set_epi64x(values[1], 
values[0]); + result = _mm256_cvtepi16_epi32(source); + } else { + constexpr u64 pdep_mask = 0x0000000100000001ULL * mask; + constexpr u32 shift1 = BIT_WIDTH * 2; + constexpr u32 shift2 = 64 - shift1; + + alignas(16) u64 temp_values[4]{}; + std::memcpy(temp_values, input, packed_bytes); + + if constexpr ((BIT_WIDTH % 2) == 0) { + std::memcpy(temp_values + 2, input + BIT_WIDTH / 2, packed_bytes - BIT_WIDTH / 2); + } else { + constexpr u32 second_group_bit_offset = BIT_WIDTH * 4; + constexpr u32 second_group_byte_offset = second_group_bit_offset / 8; + constexpr u32 second_group_shift = second_group_bit_offset % 8; + + alignas(16) u64 raw_values[2]{}; + std::memcpy(raw_values, input + second_group_byte_offset, packed_bytes - second_group_byte_offset); + + temp_values[2] = (raw_values[0] >> second_group_shift) | ( + raw_values[1] << (64 - second_group_shift)); + temp_values[3] = raw_values[1] >> second_group_shift; + } + + alignas(16) u64 values[4]; + values[0] = _pdep_u64((temp_values[0]), pdep_mask); + values[1] = _pdep_u64((temp_values[0] >> shift1) | (temp_values[1] << shift2), pdep_mask); + values[2] = _pdep_u64((temp_values[2]), pdep_mask); + values[3] = _pdep_u64((temp_values[2] >> shift1) | (temp_values[3] << shift2), pdep_mask); + + result = _mm256_set_epi64x(static_cast(values[3]), static_cast(values[2]), + static_cast(values[1]), + static_cast(values[0])); + } + + if constexpr (SIGN_VALUES) { + result = internal::mm256_sign_extend32(result); + } + return result; + } + } // namespace internal + + /** + * @brief Decompress a single 512\-bit block using AVX2 and BMI2 instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 16). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). + * + * @param input pointer to the start of the compressed block. + * @param scale scaling factor used during quantization. 
+ * @param output pointer to the output buffer where decompressed float values will be stored. + * @return int status code (0 for success). + * + * @note This function requires AVX2 and BMI2 support. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_block_bmi2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; + + const __m256 scale_v = _mm256_set1_ps(scale); +#pragma GCC unroll 4 + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256i unpacked = internal::mm256_unpack_epi32_bmi2(input); + const __m256 dequantized = internal::mm256_dequantize_epi32(unpacked, scale_v); + _mm256_storeu_ps(output, dequantized); + input += BIT_WIDTH; + output += 8; + } + + if constexpr (remaining > 0) { + const std::vector tail_values = internal::unpack_epi32_fallback < BIT_WIDTH, SIGN_VALUES + > + (input, remaining); + for (u32 i = 0; i < remaining; i++) { + output[i] = internal::dequantize_epi32(tail_values[i], scale); + } + } + + return 0; + } + + /** + * @brief Decompress a single block to double values using AVX2 and BMI2 instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 16). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). + * + * @param input pointer to the start of the compressed block. + * @param scale scaling factor used during quantization. + * @param output pointer to the output buffer where decompressed double values will be stored. + * @return int status code (0 for success). + * + * @note This function requires AVX2 and BMI2 support. 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_block_bmi2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + constexpr u32 elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + constexpr u32 iterations_8 = elements_per_block / 8; + constexpr u8 remaining = elements_per_block - iterations_8 * 8; + const __m256d scale_v = _mm256_set1_pd(scale); +#pragma GCC unroll 4 + for (u32 iter = 0; iter < iterations_8; iter++) { + const __m256i unpacked = internal::mm256_unpack_epi32_bmi2(input); + const __m256i extend1 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(unpacked)); + const __m256i extend2 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(unpacked, 1)); + + const __m256d dequantized1 = internal::mm256_dequantize_epi64_pd(extend1, scale_v); + const __m256d dequantized2 = internal::mm256_dequantize_epi64_pd(extend2, scale_v); + + _mm256_storeu_pd(output, dequantized1); + _mm256_storeu_pd(output + 4, dequantized2); + + input += BIT_WIDTH; + output += 8; + } + + if constexpr (remaining > 0) { + const std::vector tail_values = internal::unpack_epi32_fallback < BIT_WIDTH, SIGN_VALUES + > + (input, remaining); + for (u32 i = 0; i < remaining; i++) { + output[i] = internal::dequantize_epi64(tail_values[i], scale); + } + } + + return 0; + } + + /** + * @brief Decompress multiple 512\-bit blocks using AVX2 and BMI2 instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). + * + * @param input pointer to the start of the compressed data. + * @param scale scaling factor used during quantization. + * @param output pointer to the output buffer where decompressed float values will be stored. 
+ * @param blocks number of 512-bit blocks to decompress. + * @return int status code (0 for success). + * + * @note This function requires AVX2 and BMI2 support. + */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_blocks_bmi2(const void * __restrict__ input_ptr, const f32 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f32 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_decompress_block_bmi2(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } + + /** + * @brief Decompress multiple blocks to double values using AVX2 and BMI2 instructions. + * + * @tparam BIT_WIDTH bit width per value in the packed representation (1 to 24). + * @tparam SIGN_VALUES whether the values are signed or unsigned. + * @tparam BLOCK_SIZE size of the block in bytes (must be a multiple of 32). + * + * @param input pointer to the start of the compressed data. + * @param scale scaling factor used during quantization. + * @param output pointer to the output buffer where decompressed double values will be stored. + * @param blocks number of blocks to decompress. + * @return int status code (0 for success). + * + * @note This function requires AVX2 and BMI2 support. 
+ */ + template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) + int mm256_decompress_blocks_bmi2(const void * __restrict__ input_ptr, const f64 scale, + void * __restrict__ output_ptr, + const u32 blocks) { + const auto *input = static_cast(input_ptr); + auto *output = static_cast(output_ptr); + + const u8 *block_input = input; + f64 *block_output = output; + + for (u32 block = 0; block < blocks; block++) { + mm256_decompress_block_bmi2(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } +} // namespace pernix + +#endif // PERNIX_BMI2_DECOMPRESSION_H diff --git a/src/internal/pernix/x86/utils.h b/src/internal/pernix/x86/utils.h new file mode 100644 index 0000000..42aad76 --- /dev/null +++ b/src/internal/pernix/x86/utils.h @@ -0,0 +1,14 @@ +#ifndef PERNIX_X86_UTILS_H +#define PERNIX_X86_UTILS_H + +#include + +namespace pernix::x86::internal { + static constexpr u32 tail_bytes(const u8 bit_width, const u32 remaining_elements) { + const u32 tail_bits = remaining_elements * bit_width; + const u32 tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; + } +} // namespace pernix::x86::internal + +#endif // PERNIX_X86_UTILS_H diff --git a/src/pernix.cpp b/src/pernix.cpp index 87ccf9d..a57a2e5 100644 --- a/src/pernix.cpp +++ b/src/pernix.cpp @@ -1,124 +1,173 @@ #include +#include + +namespace { + bool is_valid_block_size(u32 block_size) { + return block_size == 64 || block_size == 128 || block_size == 256 || block_size == 512 || block_size == 1024; + } +} -#ifdef __cplusplus -namespace pernix { extern "C" { -#endif +pernix_status pernix_compress_block_f32(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + float scale, void *output) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -// Use the best available 
implementation based on detected CPU features at compile time -#ifdef PERNIX_AVX2_ENABLED -#ifdef PERNIX_AVX512_VBMI_ENABLED -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return mm512_compress_block_avx512vbmi(bit_width, input, scale, output); -} + const auto kernel = pernix::internal::select_compress_block_f32(static_cast(backend), bit_width, + block_size); -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return mm512_compress_block_f64_avx512vbmi(bit_width, input, scale, output); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return mm512_compress_blocks_avx512vbmi(bit_width, input, scale, output, blocks); + return static_cast(kernel.func(input, scale, output)); } -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return mm512_compress_blocks_f64_avx512vbmi(bit_width, input, scale, output, blocks); -} +pernix_status pernix_compress_blocks_f32(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + float scale, void *output, u32 blocks) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return mm512_decompress_block_avx512vbmi(bit_width, input, scale, output); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - 
return mm512_decompress_block_f64_avx512vbmi(bit_width, input, scale, output); -} + const auto kernel = pernix::internal::select_compress_blocks_f32(static_cast(backend), bit_width, + block_size); -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - return mm512_decompress_blocks_avx512vbmi(bit_width, input, scale, output, blocks); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return mm512_decompress_blocks_f64_avx512vbmi(bit_width, input, scale, output, blocks); -} -#else -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return mm256_compress_block_avx2(bit_width, input, scale, output); + return static_cast(kernel.func(input, scale, output, blocks)); } -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { - return mm256_compress_block_f64_avx2(bit_width, input, scale, output); -} +pernix_status pernix_decompress_block_f32(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + float scale, void *output, bool sign_values) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return mm256_compress_blocks_avx2(bit_width, input, scale, output, blocks); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - 
const uint32_t blocks) { - return mm256_compress_blocks_f64_avx2(bit_width, input, scale, output, blocks); -} + const auto kernel = pernix::internal::select_decompress_block_f32(static_cast(backend), bit_width, + block_size, + sign_values); -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return mm256_decompress_block_avx2(bit_width, input, scale, output); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return mm256_decompress_block_f64_avx2(bit_width, input, scale, output); + return static_cast(kernel.func(input, scale, output)); } -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - return mm256_decompress_blocks_avx2(bit_width, input, scale, output, blocks); -} +pernix_status pernix_decompress_blocks_f32(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + float scale, void *output, u32 blocks, bool sign_values) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return mm256_decompress_blocks_f64_avx2(bit_width, input, scale, output, blocks); -} -#endif -#else -int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { - return compress_block_fallback(bit_width, input, scale, output); -} + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } -int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ 
output) { - return compress_block_fallback_f64(bit_width, input, scale, output); -} + const auto kernel = pernix::internal::select_decompress_blocks_f32(static_cast(backend), bit_width, + block_size, + sign_values); -int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return compress_blocks_fallback(bit_width, input, scale, output, blocks); -} + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } -int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, - const uint32_t blocks) { - return compress_blocks_fallback_f64(bit_width, input, scale, output, blocks); + return static_cast(kernel.func(input, scale, output, blocks)); } -int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - return decompress_block_fallback(bit_width, input, scale, output); -} +pernix_status pernix_compress_block_f64(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + double scale, void *output) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } + + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } + + const auto kernel = pernix::internal::select_compress_block_f64(static_cast(backend), bit_width, + block_size); -int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - return decompress_block_fallback_f64(bit_width, input, scale, output); + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } + + return static_cast(kernel.func(input, scale, output)); } -int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - 
return decompress_blocks_fallback(bit_width, input, scale, output, blocks); +pernix_status pernix_compress_blocks_f64(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + double scale, void *output, u32 blocks) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } + + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } + + const auto kernel = pernix::internal::select_compress_blocks_f64(static_cast(backend), bit_width, + block_size); + + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } + + return static_cast(kernel.func(input, scale, output, blocks)); } -int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { - return decompress_blocks_fallback_f64(bit_width, input, scale, output, blocks); +pernix_status pernix_decompress_block_f64(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + double scale, void *output, bool sign_values) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } + + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } + + const auto kernel = pernix::internal::select_decompress_block_f64(static_cast(backend), bit_width, + block_size, + sign_values); + + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } + + return static_cast(kernel.func(input, scale, output)); } -#endif -#ifdef __cplusplus +pernix_status pernix_decompress_blocks_f64(pernix_backend backend, u8 bit_width, u32 block_size, const void *input, + double scale, void *output, u32 blocks, bool sign_values) { + if (input == nullptr || output == nullptr) { + return PERNIX_STATUS_INVALID_ARGUMENT; + } + + if (!is_valid_block_size(block_size)) { + return PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE; + } + + const auto kernel = 
pernix::internal::select_decompress_blocks_f64(static_cast(backend), bit_width, + block_size, + sign_values); + + if (!kernel) { + return PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH; + } + + return static_cast(kernel.func(input, scale, output, blocks)); +} } -} // namespace pernix -#endif // __cplusplus \ No newline at end of file diff --git a/src/x86/avx2/avx2_compression.cpp b/src/x86/avx2/avx2_compression.cpp new file mode 100644 index 0000000..548989b --- /dev/null +++ b/src/x86/avx2/avx2_compression.cpp @@ -0,0 +1,193 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_COMPRESS_BLOCK_32(N, BS) \ +case N: return Kernel("avx2", &mm256_compress_block_avx2) + +#define PERNIX_CASE_COMPRESS_BLOCKS_32(N, BS) \ +case N: return Kernel("avx2", &mm256_compress_blocks_avx2) + +#define PERNIX_CASE_COMPRESS_BLOCK_64(N, BS) \ +case N: return Kernel("avx2", &mm256_compress_block_avx2) + +#define PERNIX_CASE_COMPRESS_BLOCKS_64(N, BS) \ +case N: return Kernel("avx2", &mm256_compress_blocks_avx2) + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(21, BS); \ 
+ PERNIX_CASE_COMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(24, BS); \ + default: return {"avx2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(24, BS); \ + default: return {"avx2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(13, 
BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(24, BS); \ + default: return {"avx2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(24, BS); \ + default: return {"avx2", nullptr}; \ + } + + Kernel select_avx2_compress_block_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + 
PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"avx2", nullptr}; + } + } + + Kernel select_avx2_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"avx2", nullptr}; + } + } + + Kernel select_avx2_compress_block_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"avx2", nullptr}; + } + } + + Kernel select_avx2_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"avx2", nullptr}; + } + } + +#undef PERNIX_CASE_COMPRESS_BLOCK_32 +#undef PERNIX_CASE_COMPRESS_BLOCKS_32 +#undef PERNIX_CASE_COMPRESS_BLOCK_64 +#undef PERNIX_CASE_COMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64 +} diff --git a/src/x86/avx2/avx2_decompression.cpp b/src/x86/avx2/avx2_decompression.cpp new file mode 100644 index 0000000..cf8a53c --- /dev/null +++ b/src/x86/avx2/avx2_decompression.cpp @@ -0,0 +1,205 @@ +#include +#include + +namespace pernix::internal { +#define 
PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx2", &mm256_decompress_block_avx2); \ + return Kernel("avx2", &mm256_decompress_block_avx2) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx2", &mm256_decompress_blocks_avx2); \ + return Kernel("avx2", &mm256_decompress_blocks_avx2) + + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx2", &mm256_decompress_block_avx2); \ + return Kernel("avx2", &mm256_decompress_block_avx2) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx2", &mm256_decompress_blocks_avx2); \ + return Kernel("avx2", &mm256_decompress_blocks_avx2) +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"avx2", nullptr}; \ + } \ + break + +#define 
PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"avx2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"avx2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"avx2", nullptr}; \ + } \ + break + + Kernel select_avx2_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + 
PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"avx2", nullptr}; + } + } + + Kernel select_avx2_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"avx2", nullptr}; + } + } + + Kernel select_avx2_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"avx2", nullptr}; + } + } + + Kernel select_avx2_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"avx2", nullptr}; + } + } + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/x86/avx2/compression.cpp b/src/x86/avx2/compression.cpp deleted file mode 100644 index 231fc75..0000000 --- a/src/x86/avx2/compression.cpp +++ /dev/null @@ -1,154 +0,0 
@@ -#include -#include - -#ifdef PERNIX_AVX2_ENABLED - -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_COMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm256_compress_block_avx2(input, scale, output); - -#define PERNIX_COMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm256_compress_blocks_avx2(input, scale, output, blocks); - -int mm256_compress_block_avx2(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_compress_block_f64_avx2(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - 
PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_compress_blocks_avx2(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm256_compress_blocks_f64_avx2(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - 
PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED diff --git a/src/x86/avx2/decompression.cpp b/src/x86/avx2/decompression.cpp deleted file mode 100644 index 47c473a..0000000 --- a/src/x86/avx2/decompression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#ifdef PERNIX_AVX2_ENABLED -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm256_decompress_block_avx2(input, scale, output); - -#define PERNIX_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm256_decompress_blocks_avx2(input, scale, output, blocks); - -int mm256_decompress_block_avx2(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - 
PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_block_f64_avx2(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_blocks_avx2(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - 
PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_blocks_f64_avx2(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED \ No newline at end of file diff --git a/src/x86/avx512vbmi/avx512vbmi_compression.cpp b/src/x86/avx512vbmi/avx512vbmi_compression.cpp new file mode 100644 index 0000000..8eb25c5 --- /dev/null +++ b/src/x86/avx512vbmi/avx512vbmi_compression.cpp @@ -0,0 +1,193 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_COMPRESS_BLOCK_32(N, BS) \ +case N: return 
Kernel("avx512vbmi", &mm512_compress_block_avx512vbmi) + +#define PERNIX_CASE_COMPRESS_BLOCKS_32(N, BS) \ +case N: return Kernel("avx512vbmi", &mm512_compress_blocks_avx512vbmi) + +#define PERNIX_CASE_COMPRESS_BLOCK_64(N, BS) \ +case N: return Kernel("avx512vbmi", &mm512_compress_block_avx512vbmi) + +#define PERNIX_CASE_COMPRESS_BLOCKS_64(N, BS) \ +case N: return Kernel("avx512vbmi", &mm512_compress_blocks_avx512vbmi) + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(7, BS); \ + 
PERNIX_CASE_COMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } + +#define 
PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } + + Kernel select_avx512vbmi_compress_block_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"avx512vbmi", nullptr}; + } + } + + Kernel select_avx512vbmi_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"avx512vbmi", nullptr}; 
+ } + } + + Kernel select_avx512vbmi_compress_block_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"avx512vbmi", nullptr}; + } + } + + Kernel select_avx512vbmi_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"avx512vbmi", nullptr}; + } + } + +#undef PERNIX_CASE_COMPRESS_BLOCK_32 +#undef PERNIX_CASE_COMPRESS_BLOCKS_32 +#undef PERNIX_CASE_COMPRESS_BLOCK_64 +#undef PERNIX_CASE_COMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64 +} diff --git a/src/x86/avx512vbmi/avx512vbmi_decompression.cpp b/src/x86/avx512vbmi/avx512vbmi_decompression.cpp new file mode 100644 index 0000000..ac2cbcd --- /dev/null +++ b/src/x86/avx512vbmi/avx512vbmi_decompression.cpp @@ -0,0 +1,205 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx512vbmi", &mm512_decompress_block_avx512vbmi); \ + return Kernel("avx512vbmi", &mm512_decompress_block_avx512vbmi) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx512vbmi", &mm512_decompress_blocks_avx512vbmi); \ + return Kernel("avx512vbmi", &mm512_decompress_blocks_avx512vbmi) + + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return 
Kernel("avx512vbmi", &mm512_decompress_block_avx512vbmi); \ + return Kernel("avx512vbmi", &mm512_decompress_block_avx512vbmi) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("avx512vbmi", &mm512_decompress_blocks_avx512vbmi); \ + return Kernel("avx512vbmi", &mm512_decompress_blocks_avx512vbmi) +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); 
\ + default: return {"avx512vbmi", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"avx512vbmi", nullptr}; \ + } \ + break + + Kernel select_avx512vbmi_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"avx512vbmi", nullptr}; + } + } + + Kernel select_avx512vbmi_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + 
PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"avx512vbmi", nullptr}; + } + } + + Kernel select_avx512vbmi_decompress_block_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"avx512vbmi", nullptr}; + } + } + + Kernel select_avx512vbmi_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"avx512vbmi", nullptr}; + } + } + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/x86/avx512vbmi/compression.cpp b/src/x86/avx512vbmi/compression.cpp deleted file mode 100644 index bdc6635..0000000 --- a/src/x86/avx512vbmi/compression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#if defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_AVX512_VBMI_ENABLED) -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_COMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm512_compress_block_avx512vbmi(input, scale, output); - -#define PERNIX_COMPRESS_BLOCKS_CASE(N) \ - case 
N: \ - return mm512_compress_blocks_avx512vbmi(input, scale, output, blocks); - -int mm512_compress_block_avx512vbmi(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm512_compress_block_f64_avx512vbmi(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - 
PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm512_compress_blocks_avx512vbmi(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm512_compress_blocks_f64_avx512vbmi(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - 
PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_AVX512_VBMI_ENABLED) \ No newline at end of file diff --git a/src/x86/avx512vbmi/decompression.cpp b/src/x86/avx512vbmi/decompression.cpp deleted file mode 100644 index 1273f6a..0000000 --- a/src/x86/avx512vbmi/decompression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#if defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_AVX512_VBMI_ENABLED) -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm512_decompress_block_avx512vbmi(input, scale, output); - -#define PERNIX_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm512_decompress_blocks_avx512vbmi(input, scale, output, blocks); - -int mm512_decompress_block_avx512vbmi(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - 
PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm512_decompress_block_f64_avx512vbmi(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm512_decompress_blocks_avx512vbmi(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - 
PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm512_decompress_blocks_f64_avx512vbmi(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED && PERNIX_AVX512_VBMI_ENABLED \ No newline at end of file diff --git a/src/x86/bmi2/bmi2_compression.cpp b/src/x86/bmi2/bmi2_compression.cpp new file mode 100644 index 0000000..e978ba2 --- /dev/null +++ b/src/x86/bmi2/bmi2_compression.cpp @@ -0,0 +1,193 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_COMPRESS_BLOCK_32(N, BS) \ +case N: return Kernel("bmi2", &mm256_compress_block_bmi2) + +#define 
PERNIX_CASE_COMPRESS_BLOCKS_32(N, BS) \ +case N: return Kernel("bmi2", &mm256_compress_blocks_bmi2) + +#define PERNIX_CASE_COMPRESS_BLOCK_64(N, BS) \ +case N: return Kernel("bmi2", &mm256_compress_block_bmi2) + +#define PERNIX_CASE_COMPRESS_BLOCKS_64(N, BS) \ +case N: return Kernel("bmi2", &mm256_compress_blocks_bmi2) + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_32(24, BS); \ + default: return {"bmi2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(9, BS); \ + 
PERNIX_CASE_COMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_32(24, BS); \ + default: return {"bmi2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCK_64(24, BS); \ + default: return {"bmi2", nullptr}; \ + } + +#define PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_COMPRESS_BLOCKS_64(1, BS); 
\ + PERNIX_CASE_COMPRESS_BLOCKS_64(2, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_COMPRESS_BLOCKS_64(24, BS); \ + default: return {"bmi2", nullptr}; \ + } + + Kernel select_bmi2_compress_block_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32(1024); + default: + return {"bmi2", nullptr}; + } + } + + Kernel select_bmi2_compress_blocks_f32(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32(1024); + default: + return {"bmi2", nullptr}; + } + } + + Kernel select_bmi2_compress_block_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + 
PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64(1024); + default: + return {"bmi2", nullptr}; + } + } + + Kernel select_bmi2_compress_blocks_f64(const u8 bit_width, const u32 block_size) { + switch (block_size) { + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64(1024); + default: + return {"bmi2", nullptr}; + } + } + +#undef PERNIX_CASE_COMPRESS_BLOCK_32 +#undef PERNIX_CASE_COMPRESS_BLOCKS_32 +#undef PERNIX_CASE_COMPRESS_BLOCK_64 +#undef PERNIX_CASE_COMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_COMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_COMPRESS_BLOCKS_SWITCH_64 +} diff --git a/src/x86/bmi2/bmi2_decompression.cpp b/src/x86/bmi2/bmi2_decompression.cpp new file mode 100644 index 0000000..84deda8 --- /dev/null +++ b/src/x86/bmi2/bmi2_decompression.cpp @@ -0,0 +1,205 @@ +#include +#include + +namespace pernix::internal { +#define PERNIX_CASE_DECOMPRESS_BLOCK_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("bmi2", &mm256_decompress_block_bmi2); \ + return Kernel("bmi2", &mm256_decompress_block_bmi2) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_32(N, BS) \ +case N: \ + if (sign_values) return Kernel("bmi2", &mm256_decompress_blocks_bmi2); \ + return Kernel("bmi2", &mm256_decompress_blocks_bmi2) + + +#define PERNIX_CASE_DECOMPRESS_BLOCK_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("bmi2", &mm256_decompress_block_bmi2); \ + return Kernel("bmi2", &mm256_decompress_block_bmi2) + +#define PERNIX_CASE_DECOMPRESS_BLOCKS_64(N, BS) \ +case N: \ + if (sign_values) return Kernel("bmi2", 
&mm256_decompress_blocks_bmi2); \ + return Kernel("bmi2", &mm256_decompress_blocks_bmi2) +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_32(24, BS); \ + default: return {"bmi2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(13, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCKS_32(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_32(24, BS); \ + default: return {"bmi2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(2, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCK_64(24, BS); \ + default: return {"bmi2", nullptr}; \ + } \ + break + +#define PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(BS) \ + case BS: \ + switch (bit_width) { \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(1, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(2, BS); \ + 
PERNIX_CASE_DECOMPRESS_BLOCKS_64(3, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(4, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(5, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(6, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(7, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(8, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(9, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(10, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(11, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(12, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(13, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(14, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(15, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(16, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(17, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(18, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(19, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(20, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(21, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(22, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(23, BS); \ + PERNIX_CASE_DECOMPRESS_BLOCKS_64(24, BS); \ + default: return {"bmi2", nullptr}; \ + } \ + break + + Kernel select_bmi2_decompress_block_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32(1024); + default: return {"bmi2", nullptr}; + } + } + + Kernel select_bmi2_decompress_blocks_f32(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32(1024); + default: return {"bmi2", nullptr}; + } + } + + Kernel select_bmi2_decompress_block_f64(const u8 bit_width, const u32 
block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64(1024); + default: return {"bmi2", nullptr}; + } + } + + Kernel select_bmi2_decompress_blocks_f64(const u8 bit_width, const u32 block_size, + bool sign_values) { + switch (block_size) { + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(64); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(128); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(256); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(512); + PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64(1024); + default: return {"bmi2", nullptr}; + } + } + +#undef PERNIX_CASE_DECOMPRESS_BLOCK_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_32 +#undef PERNIX_CASE_DECOMPRESS_BLOCK_64 +#undef PERNIX_CASE_DECOMPRESS_BLOCKS_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_32 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_64 +#undef PERNIX_BLOCK_SIZE_DECOMPRESS_SWITCH_BLOCKS_64 +} diff --git a/src/x86/bmi2/compression.cpp b/src/x86/bmi2/compression.cpp deleted file mode 100644 index 79c43d8..0000000 --- a/src/x86/bmi2/compression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#if defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_BMI2_ENABLED) -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_COMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm256_compress_block_bmi2(input, scale, output); - -#define PERNIX_COMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm256_compress_blocks_bmi2(input, scale, output, blocks); - -int mm256_compress_block_bmi2(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - 
PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_compress_block_f64_bmi2(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCK_CASE(1) - PERNIX_COMPRESS_BLOCK_CASE(2) - PERNIX_COMPRESS_BLOCK_CASE(3) - PERNIX_COMPRESS_BLOCK_CASE(4) - PERNIX_COMPRESS_BLOCK_CASE(5) - PERNIX_COMPRESS_BLOCK_CASE(6) - PERNIX_COMPRESS_BLOCK_CASE(7) - PERNIX_COMPRESS_BLOCK_CASE(8) - PERNIX_COMPRESS_BLOCK_CASE(9) - PERNIX_COMPRESS_BLOCK_CASE(10) - PERNIX_COMPRESS_BLOCK_CASE(11) - PERNIX_COMPRESS_BLOCK_CASE(12) - PERNIX_COMPRESS_BLOCK_CASE(13) - PERNIX_COMPRESS_BLOCK_CASE(14) - PERNIX_COMPRESS_BLOCK_CASE(15) - PERNIX_COMPRESS_BLOCK_CASE(16) - PERNIX_COMPRESS_BLOCK_CASE(17) - PERNIX_COMPRESS_BLOCK_CASE(18) - PERNIX_COMPRESS_BLOCK_CASE(19) - PERNIX_COMPRESS_BLOCK_CASE(20) - PERNIX_COMPRESS_BLOCK_CASE(21) - PERNIX_COMPRESS_BLOCK_CASE(22) - PERNIX_COMPRESS_BLOCK_CASE(23) - PERNIX_COMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_compress_blocks_bmi2(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - 
PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm256_compress_blocks_f64_bmi2(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - PERNIX_COMPRESS_BLOCKS_CASE(1) - PERNIX_COMPRESS_BLOCKS_CASE(2) - PERNIX_COMPRESS_BLOCKS_CASE(3) - PERNIX_COMPRESS_BLOCKS_CASE(4) - PERNIX_COMPRESS_BLOCKS_CASE(5) - PERNIX_COMPRESS_BLOCKS_CASE(6) - PERNIX_COMPRESS_BLOCKS_CASE(7) - PERNIX_COMPRESS_BLOCKS_CASE(8) - PERNIX_COMPRESS_BLOCKS_CASE(9) - PERNIX_COMPRESS_BLOCKS_CASE(10) - PERNIX_COMPRESS_BLOCKS_CASE(11) - PERNIX_COMPRESS_BLOCKS_CASE(12) - PERNIX_COMPRESS_BLOCKS_CASE(13) - PERNIX_COMPRESS_BLOCKS_CASE(14) - PERNIX_COMPRESS_BLOCKS_CASE(15) - PERNIX_COMPRESS_BLOCKS_CASE(16) - PERNIX_COMPRESS_BLOCKS_CASE(17) - PERNIX_COMPRESS_BLOCKS_CASE(18) - PERNIX_COMPRESS_BLOCKS_CASE(19) - PERNIX_COMPRESS_BLOCKS_CASE(20) - PERNIX_COMPRESS_BLOCKS_CASE(21) - PERNIX_COMPRESS_BLOCKS_CASE(22) - PERNIX_COMPRESS_BLOCKS_CASE(23) - PERNIX_COMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED && 
PERNIX_BMI2_ENABLED diff --git a/src/x86/bmi2/decompression.cpp b/src/x86/bmi2/decompression.cpp deleted file mode 100644 index c9cbad7..0000000 --- a/src/x86/bmi2/decompression.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include -#include - -#if defined(PERNIX_AVX2_ENABLED) && defined(PERNIX_BMI2_ENABLED) -#ifdef __cplusplus -namespace pernix { -extern "C" { -#endif - -#define PERNIX_DECOMPRESS_BLOCK_CASE(N) \ - case N: \ - return mm256_decompress_block_bmi2(input, scale, output); - -#define PERNIX_DECOMPRESS_BLOCKS_CASE(N) \ - case N: \ - return mm256_decompress_blocks_bmi2(input, scale, output, blocks); - -int mm256_decompress_block_bmi2(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_block_f64_bmi2(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCK_CASE(1) - PERNIX_DECOMPRESS_BLOCK_CASE(2) - PERNIX_DECOMPRESS_BLOCK_CASE(3) - PERNIX_DECOMPRESS_BLOCK_CASE(4) - PERNIX_DECOMPRESS_BLOCK_CASE(5) - 
PERNIX_DECOMPRESS_BLOCK_CASE(6) - PERNIX_DECOMPRESS_BLOCK_CASE(7) - PERNIX_DECOMPRESS_BLOCK_CASE(8) - PERNIX_DECOMPRESS_BLOCK_CASE(9) - PERNIX_DECOMPRESS_BLOCK_CASE(10) - PERNIX_DECOMPRESS_BLOCK_CASE(11) - PERNIX_DECOMPRESS_BLOCK_CASE(12) - PERNIX_DECOMPRESS_BLOCK_CASE(13) - PERNIX_DECOMPRESS_BLOCK_CASE(14) - PERNIX_DECOMPRESS_BLOCK_CASE(15) - PERNIX_DECOMPRESS_BLOCK_CASE(16) - PERNIX_DECOMPRESS_BLOCK_CASE(17) - PERNIX_DECOMPRESS_BLOCK_CASE(18) - PERNIX_DECOMPRESS_BLOCK_CASE(19) - PERNIX_DECOMPRESS_BLOCK_CASE(20) - PERNIX_DECOMPRESS_BLOCK_CASE(21) - PERNIX_DECOMPRESS_BLOCK_CASE(22) - PERNIX_DECOMPRESS_BLOCK_CASE(23) - PERNIX_DECOMPRESS_BLOCK_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_blocks_bmi2(const uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { - switch (bit_width) { - PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -int mm256_decompress_blocks_f64_bmi2(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output, const uint32_t blocks) { - switch (bit_width) { - 
PERNIX_DECOMPRESS_BLOCKS_CASE(1) - PERNIX_DECOMPRESS_BLOCKS_CASE(2) - PERNIX_DECOMPRESS_BLOCKS_CASE(3) - PERNIX_DECOMPRESS_BLOCKS_CASE(4) - PERNIX_DECOMPRESS_BLOCKS_CASE(5) - PERNIX_DECOMPRESS_BLOCKS_CASE(6) - PERNIX_DECOMPRESS_BLOCKS_CASE(7) - PERNIX_DECOMPRESS_BLOCKS_CASE(8) - PERNIX_DECOMPRESS_BLOCKS_CASE(9) - PERNIX_DECOMPRESS_BLOCKS_CASE(10) - PERNIX_DECOMPRESS_BLOCKS_CASE(11) - PERNIX_DECOMPRESS_BLOCKS_CASE(12) - PERNIX_DECOMPRESS_BLOCKS_CASE(13) - PERNIX_DECOMPRESS_BLOCKS_CASE(14) - PERNIX_DECOMPRESS_BLOCKS_CASE(15) - PERNIX_DECOMPRESS_BLOCKS_CASE(16) - PERNIX_DECOMPRESS_BLOCKS_CASE(17) - PERNIX_DECOMPRESS_BLOCKS_CASE(18) - PERNIX_DECOMPRESS_BLOCKS_CASE(19) - PERNIX_DECOMPRESS_BLOCKS_CASE(20) - PERNIX_DECOMPRESS_BLOCKS_CASE(21) - PERNIX_DECOMPRESS_BLOCKS_CASE(22) - PERNIX_DECOMPRESS_BLOCKS_CASE(23) - PERNIX_DECOMPRESS_BLOCKS_CASE(24) - default: - return -1; - } -} - -#undef PERNIX_COMPRESS_BLOCK_CASE -#undef PERNIX_COMPRESS_BLOCKS_CASE - -#ifdef __cplusplus -} -} // namespace pernix -#endif // __cplusplus -#endif // PERNIX_AVX2_ENABLED && PERNIX_BMI2_ENABLED \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 370ec3b..62bb90e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,18 +1,7 @@ find_package(PkgConfig) pkg_search_module(GTEST REQUIRED gtest) -include(CheckCXXCompilerFlag) -file(GLOB_RECURSE - SOURCE_FILES - CONFIGURE_DEPENDS - *.cpp -) - -file(GLOB_RECURSE - HEADER_FILES - CONFIGURE_DEPENDS - ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h -) +file(GLOB PERNIX_TEST_SOURCE_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) set(PERNIX_TEST_BLOCK_SIZES 64 128 256 512 CACHE STRING "Block sizes to build as separate test executables") set(PERNIX_TEST_TARGETS) @@ -22,7 +11,7 @@ foreach (BLOCK_SIZE IN LISTS PERNIX_TEST_BLOCK_SIZES) list(APPEND PERNIX_TEST_TARGETS ${TEST_TARGET}) add_executable(${TEST_TARGET}) - target_sources(${TEST_TARGET} PRIVATE ${SOURCE_FILES} ${HEADER_FILES}) + 
target_sources(${TEST_TARGET} PRIVATE ${PERNIX_TEST_SOURCE_FILES}) target_link_libraries(${TEST_TARGET} PRIVATE pernix ${GTEST_LDFLAGS}) target_compile_features(${TEST_TARGET} PRIVATE cxx_std_20) target_compile_options(${TEST_TARGET} PRIVATE ${GTEST_CFLAGS}) @@ -32,11 +21,6 @@ endforeach () add_custom_target(pernix_tests DEPENDS ${PERNIX_TEST_TARGETS}) -# check_cxx_compiler_flag(-O3 HAS_OPTIMIZE3_FLAG) -# if (HAS_OPTIMIZE3_FLAG) -# target_compile_options(pernix_tests PRIVATE -O3) -# endif () - include(GoogleTest) foreach (BLOCK_SIZE IN LISTS PERNIX_TEST_BLOCK_SIZES) gtest_discover_tests(pernix_tests_${BLOCK_SIZE} @@ -50,13 +34,11 @@ endforeach () if (PERNIX_ENABLE_CODE_COVERAGE) message(STATUS "Code coverage enabled for tests") - # set compile and link options for code coverage foreach (TEST_TARGET IN LISTS PERNIX_TEST_TARGETS) target_compile_options(${TEST_TARGET} PRIVATE -g -O0 --coverage) target_link_options(${TEST_TARGET} PRIVATE --coverage) endforeach () - # find lcov and genhtml find_program(LCOV_PROGRAM lcov) find_program(GENHTML_PROGRAM genhtml) if (NOT LCOV_PROGRAM OR NOT GENHTML_PROGRAM) diff --git a/tests/compression/avx2_compression_tests.cpp b/tests/compression/avx2_compression_tests.cpp deleted file mode 100644 index 1c2892b..0000000 --- a/tests/compression/avx2_compression_tests.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include <../../include/pernix/pernix.h> -#include - -#ifdef PERNIX_AVX2_ENABLED - -TYPED_TEST(CompressionTest, AVX2CompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm256_compress_block_avx2( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - 
pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -TYPED_TEST(CompressionTest64, AVX2CompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm256_compress_block_avx2( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -#endif // PERNIX_AVX2_ENABLED diff --git a/tests/compression/avx512vbmi_compression_tests.cpp b/tests/compression/avx512vbmi_compression_tests.cpp deleted file mode 100644 index a6cb71d..0000000 --- a/tests/compression/avx512vbmi_compression_tests.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include -#include - -#ifdef PERNIX_AVX512_VBMI_ENABLED - -TYPED_TEST(CompressionTest, AVX512VBMICompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm512_compress_block_avx512vbmi( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - 
expectDecompressedBlockNearSource(*this, restored, block); - } -} - -TYPED_TEST(CompressionTest64, AVX512VBMICompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm512_compress_block_avx512vbmi( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -#endif // PERNIX_AVX512_VBMI_ENABLED diff --git a/tests/compression/bmi2_compression_tests.cpp b/tests/compression/bmi2_compression_tests.cpp deleted file mode 100644 index 85d3cac..0000000 --- a/tests/compression/bmi2_compression_tests.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include <../../include/pernix/pernix.h> -#include - -#ifdef PERNIX_BMI2_ENABLED - -TYPED_TEST(CompressionTest, BMI2CompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm256_compress_block_bmi2( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -TYPED_TEST(CompressionTest64, 
BMI2CompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::mm256_compress_block_bmi2( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - std::vector restored(this->testSet.elementsPerBlock); - pernix::decompress_block_fallback( - compressedData[block].data(), this->testSet.getScales()[block], restored.data()); - - expectDecompressedBlockNearSource(*this, restored, block); - } -} - -#endif // PERNIX_BMI2_ENABLED diff --git a/tests/compression/fallback_compression_tests.cpp b/tests/compression/fallback_compression_tests.cpp deleted file mode 100644 index 9b50109..0000000 --- a/tests/compression/fallback_compression_tests.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include <../../include/pernix/pernix.h> -#include - -TYPED_TEST(CompressionTest, FallbackCompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::compress_block_fallback( - this->testSet.getDecompressedData()[block].data(), 1 / this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectCompressedBlockEqualsReference(*this, compressedData[block], block); - } -} - -TYPED_TEST(CompressionTest64, FallbackCompressBlock) { - std::vector> compressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - compressedData[block].resize(TestFixture::BlockSize); - - pernix::compress_block_fallback( - this->testSet.getDecompressedData()[block].data(), 1 
/ this->testSet.getScales()[block], - reinterpret_cast(compressedData[block].data())); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectCompressedBlockEqualsReference(*this, compressedData[block], block); - } -} diff --git a/tests/decompression/avx2_decompression_tests.cpp b/tests/decompression/avx2_decompression_tests.cpp deleted file mode 100644 index e0f039f..0000000 --- a/tests/decompression/avx2_decompression_tests.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include <../../include/pernix/pernix.h> -#include - -#ifdef PERNIX_AVX2_ENABLED - -TYPED_TEST(DecompressionTest, AVX2DecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm256_decompress_block_avx2( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, AVX2DecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm256_decompress_block_avx2( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -#endif // PERNIX_AVX2_ENABLED diff --git a/tests/decompression/avx512vbmi_decompression_tests.cpp b/tests/decompression/avx512vbmi_decompression_tests.cpp deleted file mode 100644 index 446443a..0000000 --- 
a/tests/decompression/avx512vbmi_decompression_tests.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include <../../include/pernix/pernix.h> -#include - -#ifdef PERNIX_AVX512_VBMI_ENABLED - -TYPED_TEST(DecompressionTest, AVX512VBMIDecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm512_decompress_block_avx512vbmi( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, AVX512VBMIDecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm512_decompress_block_avx512vbmi( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -#endif // PERNIX_AVX512_VBMI_ENABLED diff --git a/tests/decompression/bmi2_decompression_tests.cpp b/tests/decompression/bmi2_decompression_tests.cpp deleted file mode 100644 index 11a8efb..0000000 --- a/tests/decompression/bmi2_decompression_tests.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include <../../include/pernix/pernix.h> -#include - -#ifdef PERNIX_BMI2_ENABLED - -TYPED_TEST(DecompressionTest, BMI2DecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - 
decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm256_decompress_block_bmi2( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, BMI2DecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::mm256_decompress_block_bmi2( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -#endif // PERNIX_BMI2_ENABLED diff --git a/tests/decompression/fallback_decompression_tests.cpp b/tests/decompression/fallback_decompression_tests.cpp deleted file mode 100644 index 3c5dc1f..0000000 --- a/tests/decompression/fallback_decompression_tests.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include <../../include/pernix/pernix.h> -#include - -TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { - std::vector> decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::decompress_block_fallback( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} - -TYPED_TEST(DecompressionTest64, FallbackDecompressBlock) { - std::vector> 
decompressedData(this->testSet.numberOfBlocks); - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - decompressedData[block].resize(this->testSet.elementsPerBlock); - - pernix::decompress_block_fallback( - this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); - } - - for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { - expectDecompressedBlockNearSource(*this, decompressedData[block], block); - } -} diff --git a/tests/fallback_edge_tests.cpp b/tests/fallback_edge_tests.cpp deleted file mode 100644 index 2bdee5a..0000000 --- a/tests/fallback_edge_tests.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include -#include - -#include - -#include -#include -#include -#include - -TEST(FallbackDecompressionEdgeTest, SignExtensionIsWellDefinedForNegativeValues) { - const std::array input{0x08}; - std::array output{}; - - ASSERT_EQ(pernix::decompress_block_fallback<4>(input.data(), 1.0F, output.data()), 0); - - EXPECT_EQ(output[0], -8.0F); -} - -TEST(FallbackCompressionEdgeTest, ClearsUnusedPaddingBytes) { - std::array input{}; - std::array output{}; - output.fill(0xA5); - - ASSERT_EQ(pernix::compress_block_fallback<24>(input.data(), 1.0F, output.data()), 0); - - EXPECT_EQ(output[63], 0); -} - -TEST(FallbackCompressionEdgeTest, ClampsNonFiniteAndOutOfRangeBeforeNarrowing) { - std::array input{}; - input[0] = std::numeric_limits::infinity(); - input[1] = -std::numeric_limits::infinity(); - input[2] = std::numeric_limits::quiet_NaN(); - std::array compressed{}; - std::array restored{}; - - ASSERT_EQ(pernix::compress_block_fallback<4>(input.data(), 1.0F, compressed.data()), 0); - ASSERT_EQ(pernix::decompress_block_fallback<4>(compressed.data(), 1.0F, restored.data()), 0); - - EXPECT_EQ(restored[0], 7.0F); - EXPECT_EQ(restored[1], -8.0F); - EXPECT_EQ(restored[2], 0.0F); -} diff --git a/tests/fallback_tests.cpp b/tests/fallback_tests.cpp new file mode 100644 index 0000000..abb502b 
--- /dev/null +++ b/tests/fallback_tests.cpp @@ -0,0 +1,316 @@ +#include + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Fallback compress: verify byte-exact match against the reference +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, FallbackCompressBlock) { + std::vector > compressed(this->testSet.numberOfBlocks); + + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { + compressed[b].resize(TestFixture::BlockSize); + + const auto status = pernix_compress_block_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + this->testSet.getDecompressedData()[b].data(), 1.0f / this->testSet.getScales()[b], + compressed[b].data()); + ASSERT_EQ(status, PERNIX_STATUS_OK); + } + + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { + expectCompressedBlockEqualsReference(*this, compressed[b], b); + } +} + +TYPED_TEST(CompressionTest64, FallbackCompressBlock) { + std::vector > compressed(this->testSet.numberOfBlocks); + + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { + compressed[b].resize(TestFixture::BlockSize); + + const auto status = pernix_compress_block_f64( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + this->testSet.getDecompressedData()[b].data(), 1.0 / this->testSet.getScales()[b], + compressed[b].data()); + ASSERT_EQ(status, PERNIX_STATUS_OK); + } + + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { + expectCompressedBlockEqualsReference(*this, compressed[b], b); + } +} + +// --------------------------------------------------------------------------- +// Fallback decompress: verify near-source match +// --------------------------------------------------------------------------- + +TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { + std::vector > decompressed(this->testSet.numberOfBlocks); + + for (u32 b = 0; b < 
this->testSet.numberOfBlocks; b++) { + decompressed[b].resize(this->testSet.elementsPerBlock); + + const auto status = pernix_decompress_block_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + this->testSet.getCompressedData()[b].data(), this->testSet.getScales()[b], + decompressed[b].data(), true); + ASSERT_EQ(status, PERNIX_STATUS_OK); + } + + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { + expectDecompressedBlockNearSource(*this, decompressed[b], b); + } +} + +TYPED_TEST(DecompressionTest64, FallbackDecompressBlock) { + std::vector > decompressed(this->testSet.numberOfBlocks); + + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { + decompressed[b].resize(this->testSet.elementsPerBlock); + + const auto status = pernix_decompress_block_f64( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + this->testSet.getCompressedData()[b].data(), this->testSet.getScales()[b], + decompressed[b].data(), true); + ASSERT_EQ(status, PERNIX_STATUS_OK); + } + + for (u32 b = 0; b < this->testSet.numberOfBlocks; b++) { + expectDecompressedBlockNearSource(*this, decompressed[b], b); + } +} + +// --------------------------------------------------------------------------- +// Multi-block roundtrip via compress_blocks / decompress_blocks (fallback) +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, FallbackCompressBlocksRoundtrip) { + const u32 nb = this->testSet.numberOfBlocks; + const u32 epb = this->testSet.elementsPerBlock; + const u32 total = nb * epb; + + std::vector flat(total); + for (u32 b = 0; b < nb; b++) { + std::copy_n(this->testSet.getDecompressedData()[b].data(), epb, + flat.data() + b * epb); + } + + // Compute a single scale that covers all blocks + float max_abs = 0.0f; + for (u32 i = 0; i < total; i++) { + max_abs = std::max(max_abs, std::abs(flat[i])); + } + const float q = static_cast(decltype(this->testSet)::quantization_levels); + 
const float scale = (max_abs > 0.0f && q > 0.0f) ? (max_abs / q) : std::numeric_limits::epsilon(); + const float scale_inv = 1.0f / scale; + + std::vector compressed(nb * TestFixture::BlockSize); + auto status = pernix_compress_blocks_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + flat.data(), scale_inv, compressed.data(), nb); + ASSERT_EQ(status, PERNIX_STATUS_OK); + + std::vector restored(total); + status = pernix_decompress_blocks_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + compressed.data(), scale, restored.data(), nb, true); + ASSERT_EQ(status, PERNIX_STATUS_OK); + + const float tol = (std::abs(scale) * 0.5f) + (std::numeric_limits::epsilon() * 16.0f); + for (u32 i = 0; i < total; i++) { + EXPECT_NEAR(restored[i], flat[i], tol); + } +} + +TYPED_TEST(CompressionTest64, FallbackCompressBlocksRoundtrip) { + const u32 nb = this->testSet.numberOfBlocks; + const u32 epb = this->testSet.elementsPerBlock; + const u32 total = nb * epb; + + std::vector flat(total); + for (u32 b = 0; b < nb; b++) { + std::copy_n(this->testSet.getDecompressedData()[b].data(), epb, + flat.data() + b * epb); + } + + // Compute a single scale that covers all blocks + double max_abs = 0.0; + for (u32 i = 0; i < total; i++) { + max_abs = std::max(max_abs, std::abs(flat[i])); + } + const double q = static_cast(decltype(this->testSet)::quantization_levels); + const double scale = (max_abs > 0.0 && q > 0.0) ? 
(max_abs / q) : std::numeric_limits::epsilon(); + const double scale_inv = 1.0 / scale; + + std::vector compressed(nb * TestFixture::BlockSize); + auto status = pernix_compress_blocks_f64( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + flat.data(), scale_inv, compressed.data(), nb); + ASSERT_EQ(status, PERNIX_STATUS_OK); + + std::vector restored(total); + status = pernix_decompress_blocks_f64( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + compressed.data(), scale, restored.data(), nb, true); + ASSERT_EQ(status, PERNIX_STATUS_OK); + + const double tol = (std::abs(scale) * 0.5) + (std::numeric_limits::epsilon() * 16.0); + for (u32 i = 0; i < total; i++) { + EXPECT_NEAR(restored[i], flat[i], tol); + } +} + +// --------------------------------------------------------------------------- +// blocks API with a single block should match the block API exactly +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, SingleBlockCompressBlocksMatchesBlock) { + const auto &src = this->testSet.getDecompressedData()[0]; + const float scale_inv = 1.0f / this->testSet.getScales()[0]; + + std::vector blockOut(TestFixture::BlockSize); + std::vector blocksOut(TestFixture::BlockSize); + + auto s1 = pernix_compress_block_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + src.data(), scale_inv, blockOut.data()); + ASSERT_EQ(s1, PERNIX_STATUS_OK); + + auto s2 = pernix_compress_blocks_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + src.data(), scale_inv, blocksOut.data(), 1); + ASSERT_EQ(s2, PERNIX_STATUS_OK); + + for (u32 i = 0; i < TestFixture::BlockSize; i++) { + EXPECT_EQ(blockOut[i], blocksOut[i]) << "byte " << i; + } +} + +TYPED_TEST(DecompressionTest, SingleBlockDecompressBlocksMatchesBlock) { + const auto &compressed = this->testSet.getCompressedData()[0]; + const float scale = this->testSet.getScales()[0]; + 
const u32 epb = this->testSet.elementsPerBlock; + + std::vector blockOut(epb); + std::vector blocksOut(epb); + + auto s1 = pernix_decompress_block_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + compressed.data(), scale, blockOut.data(), true); + ASSERT_EQ(s1, PERNIX_STATUS_OK); + + auto s2 = pernix_decompress_blocks_f32( + PERNIX_BACKEND_FALLBACK, TestFixture::BitWidth, TestFixture::BlockSize, + compressed.data(), scale, blocksOut.data(), 1, true); + ASSERT_EQ(s2, PERNIX_STATUS_OK); + + for (u32 i = 0; i < epb; i++) { + EXPECT_FLOAT_EQ(blockOut[i], blocksOut[i]) << "element " << i; + } +} + +// --------------------------------------------------------------------------- +// Edge-case behavioural tests (fallback, block_size=64) +// --------------------------------------------------------------------------- + +TEST(FallbackEdgeTest, SignExtensionIsWellDefinedForNegativeValues) { + constexpr u32 BS = 64; + const std::array input{0x08}; + + pernix_status st; + std::array output{}; + + st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, 4, BS, input.data(), 1.0f, output.data(), true); + ASSERT_EQ(st, PERNIX_STATUS_OK); + EXPECT_EQ(output[0], -8.0f); +} + +TEST(FallbackEdgeTest, ClearsUnusedPaddingBytes) { + constexpr u32 BS = 64; + constexpr u32 BW = 24; + constexpr u32 EPB = (BS * 8) / BW; + + std::array input{}; + std::array output{}; + output.fill(0xA5); + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BW, BS, input.data(), 1.0f, output.data()); + ASSERT_EQ(st, PERNIX_STATUS_OK); + EXPECT_EQ(output[BS - 1], 0); +} + +TEST(FallbackEdgeTest, ClampsNonFiniteAndOutOfRangeBeforeNarrowing) { + constexpr u32 BS = 64; + constexpr u32 BW = 4; + constexpr u32 EPB = (BS * 8) / BW; + + std::array input{}; + input[0] = std::numeric_limits::infinity(); + input[1] = -std::numeric_limits::infinity(); + input[2] = std::numeric_limits::quiet_NaN(); + + std::array compressed{}; + std::array restored{}; + + auto st = 
pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BW, BS, input.data(), 1.0f, compressed.data()); + ASSERT_EQ(st, PERNIX_STATUS_OK); + + st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, BW, BS, compressed.data(), 1.0f, restored.data(), true); + ASSERT_EQ(st, PERNIX_STATUS_OK); + + EXPECT_EQ(restored[0], 7.0f); + EXPECT_EQ(restored[1], -8.0f); + EXPECT_EQ(restored[2], 0.0f); +} + +// --------------------------------------------------------------------------- +// Error-code contract tests +// --------------------------------------------------------------------------- + +TEST(ErrorCodeTest, UnsupportedBlockSizeReturnsError) { + constexpr u32 BS = 32; + f32 src[32] = {}; + u8 dst[32] = {}; + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 8, BS, src, 1.0f, dst); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); + + st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, 8, BS, dst, 1.0f, src, true); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); + + st = pernix_compress_blocks_f32(PERNIX_BACKEND_FALLBACK, 8, BS, src, 1.0f, dst, 1); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); + + st = pernix_decompress_blocks_f32(PERNIX_BACKEND_FALLBACK, 8, BS, dst, 1.0f, src, 1, true); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BLOCK_SIZE); +} + +TEST(ErrorCodeTest, UnsupportedBitWidthReturnsError) { + constexpr u32 BS = 64; + f32 src[256] = {}; + u8 dst[64] = {}; + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 0, BS, src, 1.0f, dst); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH); + + st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 25, BS, src, 1.0f, dst); + EXPECT_EQ(st, PERNIX_STATUS_UNSUPPORTED_BIT_WIDTH); +} + +TEST(ErrorCodeTest, NullPointerReturnsError) { + f32 src[64] = {}; + u8 dst[64] = {}; + + auto st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 8, 64, nullptr, 1.0f, dst); + EXPECT_EQ(st, PERNIX_STATUS_INVALID_ARGUMENT); + + st = pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, 8, 64, 
src, 1.0f, nullptr); + EXPECT_EQ(st, PERNIX_STATUS_INVALID_ARGUMENT); +} diff --git a/tests/include/testset.h b/tests/include/testset.h index 6535957..b73fc71 100644 --- a/tests/include/testset.h +++ b/tests/include/testset.h @@ -1,7 +1,7 @@ #ifndef PERNIX_TESTSET_H #define PERNIX_TESTSET_H -#include <../../include/pernix/pernix.h> +#include #include #include @@ -21,22 +21,14 @@ static_assert(PERNIX_TEST_BLOCK_SIZE % 32 == 0, "PERNIX_TEST_BLOCK_SIZE must be dividable by 32 bytes"); -/** - * A test set for compression and decompression tests. - * It generates random float data, compresses it, and verifies the decompression using the fallback implementation. - * - * @tparam BIT_WIDTH The bit width used for compression (1 to 24). - * @tparam SIGN_VALUES Indicates whether the values are signed or unsigned. - */ -template +template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && std::is_floating_point_v class TestSet { - // using ValueType = std::conditional_t; - using ValueType = uint8_t; - using SeedType = std::mt19937::result_type; + using ValueType = u8; + using SeedType = std::mt19937::result_type; alignas(64) std::vector > compressedData; - alignas(64) std::vector > decompressedData; + alignas(64) std::vector > sourceData; alignas(64) std::vector scalesData; SeedType seed; @@ -44,192 +36,202 @@ class TestSet { std::uniform_real_distribution dis{}; public: - static constexpr uint32_t elementsPerBlock = (BLOCK_SIZE * 8) / BIT_WIDTH; - static constexpr SeedType defaultSeed = 0x5eed1234u; + static constexpr u32 elementsPerBlock = (BLOCK_SIZE * 8) / BIT_WIDTH; + static constexpr SeedType defaultSeed = 0x5eed1234u; static constexpr T quantization_levels = - SIGN_VALUES ? static_cast(BIT_WIDTH == 1 ? 1u : ((1u << (BIT_WIDTH - 1u)) - 1u)) : static_cast((1u << BIT_WIDTH) - 1u); + SIGN_VALUES + ? static_cast(BIT_WIDTH == 1 ? 
1u : ((1u << (BIT_WIDTH - 1u)) - 1u)) + : static_cast((1u << BIT_WIDTH) - 1u); - uint32_t numberOfBlocks; + u32 numberOfBlocks; - [[nodiscard]] constexpr uint32_t totalElements() const { return numberOfBlocks * elementsPerBlock; } + [[nodiscard]] constexpr u32 totalElements() const { return numberOfBlocks * elementsPerBlock; } - [[nodiscard]] T blockTolerance(const uint32_t block) const { - // Half-step quantization bound + tiny FP slack for rounding edge cases. - return (std::abs(scalesData[block]) * static_cast(0.5)) + (std::numeric_limits::epsilon() * static_cast(16)); + [[nodiscard]] T blockTolerance(const u32 block) const { + return (std::abs(scalesData[block]) * static_cast(0.5)) + ( + std::numeric_limits::epsilon() * static_cast(16)); } - explicit TestSet(const uint32_t number_of_blocks, const SeedType initial_seed = testSeed()) + explicit TestSet(const u32 number_of_blocks, const SeedType initial_seed = testSeed()) : seed(initial_seed), gen(seed), numberOfBlocks(number_of_blocks) { compressedData.resize(numberOfBlocks); - decompressedData.resize(number_of_blocks); + sourceData.resize(number_of_blocks); scalesData.resize(numberOfBlocks); generateData(); } - [[nodiscard]] const std::vector& getScales() const { return scalesData; } + [[nodiscard]] const std::vector &getScales() const { return scalesData; } - [[nodiscard]] const std::vector >& getCompressedData() const { return compressedData; } + [[nodiscard]] const std::vector > &getCompressedData() const { return compressedData; } - [[nodiscard]] const std::vector >& getDecompressedData() const { return decompressedData; } + [[nodiscard]] const std::vector > &getDecompressedData() const { return sourceData; } [[nodiscard]] SeedType getSeed() const { return seed; } [[nodiscard]] static SeedType testSeed() { - const char* env_seed = std::getenv("PERNIX_TEST_SEED"); + const char *env_seed = std::getenv("PERNIX_TEST_SEED"); if (env_seed == nullptr || *env_seed == '\0') { return defaultSeed; } - char* end = 
nullptr; + char *end = nullptr; const unsigned long value = std::strtoul(env_seed, &end, 0); return (end != env_seed && *end == '\0') ? static_cast(value) : defaultSeed; } private: - // Generate deterministic source data and its fallback-compressed reference. void generateData() { - for (uint32_t i = 0; i < numberOfBlocks; i++) { + for (u32 i = 0; i < numberOfBlocks; i++) { compressedData[i].resize(BLOCK_SIZE); - decompressedData[i].resize(elementsPerBlock); + sourceData[i].resize(elementsPerBlock); - for (uint32_t j = 0; j < elementsPerBlock; j++) { - decompressedData[i][j] = dis(gen); + for (u32 j = 0; j < elementsPerBlock; j++) { + sourceData[i][j] = dis(gen); } - const T b_max = *std::ranges::max_element(decompressedData[i]); - const T b_min = *std::ranges::min_element(decompressedData[i]); + const T b_max = *std::ranges::max_element(sourceData[i]); + const T b_min = *std::ranges::min_element(sourceData[i]); const T b_abs = std::max(std::abs(b_max), std::abs(b_min)); scalesData[i] = (b_abs > static_cast(0) && quantization_levels > static_cast(0)) ? 
(b_abs / quantization_levels) : std::numeric_limits::epsilon(); - // Compress the data using the fallback implementation - pernix::compress_block_fallback(decompressedData[i].data(), 1 / scalesData[i], - reinterpret_cast(compressedData[i].data())); + if constexpr (std::is_same_v) { + pernix_compress_block_f32(PERNIX_BACKEND_FALLBACK, BIT_WIDTH, BLOCK_SIZE, + sourceData[i].data(), 1.0f / scalesData[i], compressedData[i].data()); + } else { + pernix_compress_block_f64(PERNIX_BACKEND_FALLBACK, BIT_WIDTH, BLOCK_SIZE, + sourceData[i].data(), 1.0 / scalesData[i], compressedData[i].data()); + } } } }; -template +template struct BitWidthBlockSize { - static constexpr uint8_t bit_width = BIT_WIDTH; - static constexpr uint32_t block_size = BLOCK_SIZE; + static constexpr u8 bit_width = BIT_WIDTH; + static constexpr u32 block_size = BLOCK_SIZE; }; using testing::Types; using BitWidthBlockSizeTypes = Types, BitWidthBlockSize<2, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<3, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<4, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<5, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<6, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<7, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<8, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<9, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<10, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<11, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<12, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<13, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<14, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<15, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<16, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<17, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<18, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<19, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<20, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<21, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<22, PERNIX_TEST_BLOCK_SIZE>, - BitWidthBlockSize<23, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<24, 
PERNIX_TEST_BLOCK_SIZE> >; + BitWidthBlockSize<3, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<4, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<5, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<6, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<7, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<8, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<9, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<10, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<11, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<12, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<13, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<14, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<15, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<16, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<17, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<18, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<19, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<20, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<21, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<22, PERNIX_TEST_BLOCK_SIZE>, + BitWidthBlockSize<23, PERNIX_TEST_BLOCK_SIZE>, BitWidthBlockSize<24, PERNIX_TEST_BLOCK_SIZE> >; struct BitWidthBlockSizeName { - template + template static std::string GetName(int) { std::ostringstream name; - name << "BitWidth" << static_cast(TestConfigT::bit_width) << "BlockSize" << TestConfigT::block_size; + name << "BitWidth" << static_cast(TestConfigT::bit_width) << "BlockSize" << TestConfigT::block_size; return name.str(); } }; -template +template class CompressionTest : public ::testing::Test { public: - static constexpr uint8_t BitWidth = TestConfigT::bit_width; - static constexpr uint32_t BlockSize = TestConfigT::block_size; + static constexpr u8 BitWidth = TestConfigT::bit_width; + static constexpr u32 BlockSize = TestConfigT::block_size; - TestSet testSet; + TestSet testSet; CompressionTest() : testSet(1u << 10) { } }; -template +template class DecompressionTest : public ::testing::Test { public: - static constexpr uint8_t BitWidth = TestConfigT::bit_width; - static constexpr uint32_t 
BlockSize = TestConfigT::block_size; + static constexpr u8 BitWidth = TestConfigT::bit_width; + static constexpr u32 BlockSize = TestConfigT::block_size; - TestSet testSet; + TestSet testSet; DecompressionTest() : testSet(1u << 10) { } }; -template +template class CompressionTest64 : public ::testing::Test { public: - static constexpr uint8_t BitWidth = TestConfigT::bit_width; - static constexpr uint32_t BlockSize = TestConfigT::block_size; + static constexpr u8 BitWidth = TestConfigT::bit_width; + static constexpr u32 BlockSize = TestConfigT::block_size; - TestSet testSet; + TestSet testSet; CompressionTest64() : testSet(1u << 10) { } }; -template +template class DecompressionTest64 : public ::testing::Test { public: - static constexpr uint8_t BitWidth = TestConfigT::bit_width; - static constexpr uint32_t BlockSize = TestConfigT::block_size; + static constexpr u8 BitWidth = TestConfigT::bit_width; + static constexpr u32 BlockSize = TestConfigT::block_size; - TestSet testSet; + TestSet testSet; DecompressionTest64() : testSet(1u << 10) { } }; -template -[[nodiscard]] std::string testContext(const FixtureT& fixture, const uint32_t block) { +template +[[nodiscard]] std::string testContext(const FixtureT &fixture, const u32 block) { std::ostringstream message; - message << "bit_width=" << static_cast(FixtureT::BitWidth) << ", block_size=" << FixtureT::BlockSize - << ", block=" << block << ", scale=" << fixture.testSet.getScales()[block] - << ", tolerance=" << fixture.testSet.blockTolerance(block) << ", seed=" << fixture.testSet.getSeed(); + message << "bit_width=" << static_cast(FixtureT::BitWidth) << ", block_size=" << FixtureT::BlockSize + << ", block=" << block << ", scale=" << fixture.testSet.getScales()[block] + << ", tolerance=" << fixture.testSet.blockTolerance(block) << ", seed=" << fixture.testSet.getSeed(); return message.str(); } -template -void expectCompressedBlockEqualsReference(const FixtureT& fixture, const std::vector& actual, const uint32_t block) { 
+template +void expectCompressedBlockEqualsReference(const FixtureT &fixture, const std::vector &actual, + const u32 block) { SCOPED_TRACE(testContext(fixture, block)); - const auto& expected = fixture.testSet.getCompressedData()[block]; + const auto &expected = fixture.testSet.getCompressedData()[block]; ASSERT_EQ(actual.size(), expected.size()) << "Compressed block byte count differs from reference"; - for (uint32_t byte = 0; byte < actual.size(); byte++) { + for (u32 byte = 0; byte < actual.size(); byte++) { ASSERT_EQ(actual[byte], expected[byte]) - << "Compressed byte mismatch at byte=" << byte << ", actual=" << static_cast(actual[byte]) - << ", expected=" << static_cast(expected[byte]); + << "Compressed byte mismatch at byte=" << byte << ", actual=" << static_cast(actual[byte]) + << ", expected=" << static_cast(expected[byte]); } } -template -void expectDecompressedBlockNearSource(const FixtureT& fixture, const std::vector& actual, const uint32_t block) { +template +void expectDecompressedBlockNearSource(const FixtureT &fixture, const std::vector &actual, const u32 block) { SCOPED_TRACE(testContext(fixture, block)); - const auto& expected = fixture.testSet.getDecompressedData()[block]; + const auto &expected = fixture.testSet.getDecompressedData()[block]; ASSERT_EQ(actual.size(), expected.size()) << "Decompressed block element count differs from source"; - for (uint32_t element = 0; element < actual.size(); element++) { + for (u32 element = 0; element < actual.size(); element++) { ASSERT_NEAR(actual[element], expected[element], fixture.testSet.blockTolerance(block)) - << "Decompressed element mismatch at element=" << element << ", actual=" << actual[element] - << ", expected=" << expected[element] << ", absolute_error=" << std::abs(actual[element] - expected[element]); + << "Decompressed element mismatch at element=" << element << ", actual=" << actual[element] + << ", expected=" << expected[element] << ", absolute_error=" << std::abs( + actual[element] - 
expected[element]); } } TYPED_TEST_SUITE(CompressionTest, BitWidthBlockSizeTypes, BitWidthBlockSizeName); + TYPED_TEST_SUITE(DecompressionTest, BitWidthBlockSizeTypes, BitWidthBlockSizeName); + TYPED_TEST_SUITE(CompressionTest64, BitWidthBlockSizeTypes, BitWidthBlockSizeName); + TYPED_TEST_SUITE(DecompressionTest64, BitWidthBlockSizeTypes, BitWidthBlockSizeName); #endif // PERNIX_TESTSET_H diff --git a/tests/simd_tests.cpp b/tests/simd_tests.cpp new file mode 100644 index 0000000..019aea3 --- /dev/null +++ b/tests/simd_tests.cpp @@ -0,0 +1,188 @@ +#include + +#include +#include + +// --------------------------------------------------------------------------- +// SIMD compress: compress via backend, decompress via fallback, compare source +// --------------------------------------------------------------------------- + +template +void testBackendCompressBlock(FixtureT &fixture, pernix_backend backend) { + using T = std::remove_cvref_t; + + { + std::vector probe(FixtureT::BlockSize); + pernix_status st; + if constexpr (std::is_same_v) { + st = pernix_compress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getDecompressedData()[0].data(), + 1.0f / fixture.testSet.getScales()[0], probe.data()); + } else { + st = pernix_compress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getDecompressedData()[0].data(), + 1.0 / fixture.testSet.getScales()[0], probe.data()); + } + if (st != PERNIX_STATUS_OK) { + GTEST_SKIP(); + return; + } + } + + for (u32 b = 0; b < fixture.testSet.numberOfBlocks; b++) { + std::vector compressed(FixtureT::BlockSize); + + pernix_status st; + if constexpr (std::is_same_v) { + st = pernix_compress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getDecompressedData()[b].data(), + 1.0f / fixture.testSet.getScales()[b], compressed.data()); + } else { + st = pernix_compress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, + 
fixture.testSet.getDecompressedData()[b].data(), + 1.0 / fixture.testSet.getScales()[b], compressed.data()); + } + ASSERT_EQ(st, PERNIX_STATUS_OK); + + std::vector restored(fixture.testSet.elementsPerBlock); + if constexpr (std::is_same_v) { + st = pernix_decompress_block_f32(PERNIX_BACKEND_FALLBACK, FixtureT::BitWidth, FixtureT::BlockSize, + compressed.data(), fixture.testSet.getScales()[b], restored.data(), true); + } else { + st = pernix_decompress_block_f64(PERNIX_BACKEND_FALLBACK, FixtureT::BitWidth, FixtureT::BlockSize, + compressed.data(), fixture.testSet.getScales()[b], restored.data(), true); + } + ASSERT_EQ(st, PERNIX_STATUS_OK); + + expectDecompressedBlockNearSource(fixture, restored, b); + } +} + +// --------------------------------------------------------------------------- +// SIMD decompress: decompress fallback-compressed data via backend, compare source +// --------------------------------------------------------------------------- + +template +void testBackendDecompressBlock(FixtureT &fixture, pernix_backend backend) { + using T = std::remove_cvref_t; + + { + std::vector probe(fixture.testSet.elementsPerBlock); + pernix_status st; + if constexpr (std::is_same_v) { + st = pernix_decompress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getCompressedData()[0].data(), + fixture.testSet.getScales()[0], probe.data(), true); + } else { + st = pernix_decompress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getCompressedData()[0].data(), + fixture.testSet.getScales()[0], probe.data(), true); + } + if (st != PERNIX_STATUS_OK) { + GTEST_SKIP(); + return; + } + } + + for (u32 b = 0; b < fixture.testSet.numberOfBlocks; b++) { + std::vector decompressed(fixture.testSet.elementsPerBlock); + + pernix_status st; + if constexpr (std::is_same_v) { + st = pernix_decompress_block_f32(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getCompressedData()[b].data(), + 
fixture.testSet.getScales()[b], decompressed.data(), true); + } else { + st = pernix_decompress_block_f64(backend, FixtureT::BitWidth, FixtureT::BlockSize, + fixture.testSet.getCompressedData()[b].data(), + fixture.testSet.getScales()[b], decompressed.data(), true); + } + ASSERT_EQ(st, PERNIX_STATUS_OK); + + expectDecompressedBlockNearSource(fixture, decompressed, b); + } +} + +// --------------------------------------------------------------------------- +// x86: AVX2 +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, AVX2CompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_AVX2); +} + +TYPED_TEST(DecompressionTest, AVX2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_AVX2); +} + +TYPED_TEST(CompressionTest64, AVX2CompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_AVX2); +} + +TYPED_TEST(DecompressionTest64, AVX2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_AVX2); +} + +// --------------------------------------------------------------------------- +// x86: BMI2 +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, BMI2CompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_BMI2); +} + +TYPED_TEST(DecompressionTest, BMI2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_BMI2); +} + +TYPED_TEST(CompressionTest64, BMI2CompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_BMI2); +} + +TYPED_TEST(DecompressionTest64, BMI2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_BMI2); +} + +// --------------------------------------------------------------------------- +// x86: AVX512-VBMI +// --------------------------------------------------------------------------- + +TYPED_TEST(CompressionTest, AVX512VBMICompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_AVX512_VBMI); +} + 
+TYPED_TEST(DecompressionTest, AVX512VBMIDecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_AVX512_VBMI); +} + +TYPED_TEST(CompressionTest64, AVX512VBMICompressBlock) { + testBackendCompressBlock(*this, PERNIX_BACKEND_X86_AVX512_VBMI); +} + +TYPED_TEST(DecompressionTest64, AVX512VBMIDecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_X86_AVX512_VBMI); +} + +// --------------------------------------------------------------------------- +// ARM64: NEON (decompress only — no compress implementation) +// --------------------------------------------------------------------------- + +TYPED_TEST(DecompressionTest, NeonDecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_ARM64_NEON); +} + +TYPED_TEST(DecompressionTest64, NeonDecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_ARM64_NEON); +} + +// --------------------------------------------------------------------------- +// ARM64: SVE2 (decompress only — no compress implementation) +// --------------------------------------------------------------------------- + +TYPED_TEST(DecompressionTest, SVE2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_ARM64_SVE); +} + +TYPED_TEST(DecompressionTest64, SVE2DecompressBlock) { + testBackendDecompressBlock(*this, PERNIX_BACKEND_ARM64_SVE); +}