diff --git a/.clangd b/.clangd new file mode 100644 index 0000000..e69de29 diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index ae538ec..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "external/simde"] - path = external/simde - url = https://github.com/simd-everywhere/simde diff --git a/CMakeLists.txt b/CMakeLists.txt index edf6c78..1cbead6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,13 +13,67 @@ option(PERNIX_DISABLE_AVX2 "Disable AVX2 optimizations" off) option(PERNIX_DISABLE_AVX512 "Disable AVX512 optimizations" off) option(PERNIX_USE_SIMDE "Use SIMDe library for portable SIMD support" off) +set(PERNIX_SIMDE_PROVIDER "AUTO" CACHE STRING "SIMDe provider when PERNIX_USE_SIMDE is enabled (AUTO, PACKAGE, FETCH)") +set_property(CACHE PERNIX_SIMDE_PROVIDER PROPERTY STRINGS AUTO PACKAGE FETCH) +set(PERNIX_ARCH_BACKEND "AUTO" CACHE STRING "Pernix architecture backend (AUTO, FALLBACK, X86, ARM64_NEON, ARM64_SVE, ARM64_SVE2)") +set_property(CACHE PERNIX_ARCH_BACKEND PROPERTY STRINGS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) option(PERNIX_ENABLE_FORTRAN_BINDINGS "Build Fortran bindings for pernix" off) +string(TOUPPER "${PERNIX_ARCH_BACKEND}" PERNIX_ARCH_BACKEND) +set(PERNIX_VALID_ARCH_BACKENDS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) +if (NOT PERNIX_ARCH_BACKEND IN_LIST PERNIX_VALID_ARCH_BACKENDS) + message(FATAL_ERROR "Unsupported PERNIX_ARCH_BACKEND='${PERNIX_ARCH_BACKEND}'. Expected one of: ${PERNIX_VALID_ARCH_BACKENDS}") +endif () + +set(PERNIX_SELECTED_ARCH_BACKEND "${PERNIX_ARCH_BACKEND}") +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "AUTO") + if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") + set(PERNIX_SELECTED_ARCH_BACKEND "X86") + elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(PERNIX_SELECTED_ARCH_BACKEND "ARM64_NEON") + else () + set(PERNIX_SELECTED_ARCH_BACKEND "FALLBACK") + endif () +endif () +message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") + +string(TOUPPER "${PERNIX_SIMDE_PROVIDER}" PERNIX_SIMDE_PROVIDER) +set(PERNIX_VALID_SIMDE_PROVIDERS AUTO PACKAGE FETCH) +if (NOT PERNIX_SIMDE_PROVIDER IN_LIST PERNIX_VALID_SIMDE_PROVIDERS) + message(FATAL_ERROR "Unsupported PERNIX_SIMDE_PROVIDER='${PERNIX_SIMDE_PROVIDER}'. Expected one of: ${PERNIX_VALID_SIMDE_PROVIDERS}") +endif () + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") +set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL OFF) if (PERNIX_USE_SIMDE) - add_subdirectory(external/simde EXCLUDE_FROM_ALL) + if (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "PACKAGE") + find_package(simde CONFIG QUIET) + endif () + + if (NOT TARGET simde::simde AND (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "FETCH")) + include(FetchContent) + set(SIMDE_TEST_CMAKE_PACKAGING OFF CACHE BOOL "Test SIMDe CMake packaging" FORCE) + FetchContent_Declare( + simde + GIT_REPOSITORY https://github.com/simd-everywhere/simde.git + GIT_TAG f3e8262173b7089db9a9d57a9ecef8dd07ad9c97 + GIT_PROGRESS TRUE + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(simde) + set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL ON) + endif () + + if (NOT TARGET simde::simde AND DEFINED simde_SOURCE_DIR AND EXISTS "${simde_SOURCE_DIR}/simde") + add_library(simde::simde INTERFACE IMPORTED GLOBAL) + target_include_directories(simde::simde INTERFACE "${simde_SOURCE_DIR}") + endif () + + if (NOT TARGET simde::simde) + message(FATAL_ERROR "PERNIX_USE_SIMDE is enabled, but simde::simde was not found. Set PERNIX_SIMDE_PROVIDER=FETCH or install SIMDe's CMake package.") + endif () endif () include(CTest) @@ -40,28 +94,42 @@ else () endif () message(STATUS "Pernix version: ${VERSION}, normalized to ${NORMALIZED_VERSION}") -set(BENCHMARK_CXX_STANDARD 20) - -set(CMAKE_CXX_STANDARD ${BENCHMARK_CXX_STANDARD}) -set(CMAKE_CXX_STANDARD_REQUIRED YES) -set(CMAKE_CXX_EXTENSIONS OFF) - -include(AddCXXCompilerFlag) if (MSVC) message(FATAL_ERROR "MSVC compiler is not supported") else () - add_cxx_compiler_flag(-Wall) - add_cxx_compiler_flag(-Wextra) - add_cxx_compiler_flag(-Wshadow) - add_cxx_compiler_flag(-Wfloat-equal) - add_cxx_compiler_flag(-Wold-style-cast) - add_cxx_compiler_flag(-Wconversion) - add_cxx_compiler_flag(-fstrict-aliasing) - add_cxx_compiler_flag(-Wno-ignored-attributes) + include(CheckCXXCompilerFlag) + set(PERNIX_PRIVATE_COMPILE_OPTIONS) + foreach (PERNIX_CXX_FLAG + -Wall + -Wextra + -Wshadow + -Wfloat-equal + -Wold-style-cast + -Wconversion + -fstrict-aliasing + -Wno-ignored-attributes + ) + string(MAKE_C_IDENTIFIER "PERNIX_HAS_CXX_FLAG_${PERNIX_CXX_FLAG}" PERNIX_CXX_FLAG_VARIABLE) + check_cxx_compiler_flag("${PERNIX_CXX_FLAG}" "${PERNIX_CXX_FLAG_VARIABLE}") + if (${PERNIX_CXX_FLAG_VARIABLE}) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "${PERNIX_CXX_FLAG}") + else () + message(STATUS "Compiler flag not supported: ${PERNIX_CXX_FLAG}") + endif () + endforeach () if (PERNIX_ENABLE_LTO) - add_cxx_compiler_flag(-flto=auto) - add_cxx_compiler_flag(-Wno-lto-type-mismatch) + include(CheckIPOSupported) + check_ipo_supported(RESULT PERNIX_IPO_SUPPORTED OUTPUT PERNIX_IPO_ERROR) + if (NOT PERNIX_IPO_SUPPORTED) + message(FATAL_ERROR "PERNIX_ENABLE_LTO is enabled, but IPO/LTO is not supported: ${PERNIX_IPO_ERROR}") + endif () + + check_cxx_compiler_flag("-Wno-lto-type-mismatch" PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + if (PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "-Wno-lto-type-mismatch") + endif () + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") find_program(GCC_AR gcc-ar) if (GCC_AR) @@ -84,8 +152,6 @@ else () endif () endif () -include_directories(${PROJECT_SOURCE_DIR}/include) - add_subdirectory(src) if (PERNIX_ENABLE_FORTRAN_BINDINGS) @@ -97,4 +163,4 @@ endif () if (PERNIX_ENABLE_TESTS) enable_testing() add_subdirectory(tests) -endif () \ No newline at end of file +endif () diff --git a/external/simde b/external/simde deleted file mode 160000 index 1747b24..0000000 --- a/external/simde +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1747b2482589fe894d49989159421da08c2a8bcd diff --git a/include/pernix/arm64/neon/common.h b/include/pernix/arm64/neon/common.h new file mode 100644 index 0000000..8e517fa --- /dev/null +++ b/include/pernix/arm64/neon/common.h @@ -0,0 +1,200 @@ +#ifndef PERNIX_ARM64_NEON_COMMON_H +#define PERNIX_ARM64_NEON_COMMON_H + +#include + +#include + +namespace pernix::arm64::neon::internal { +struct float64x2x8_t { + float64x2_t val[8]; +}; + +static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { + const uint32_t tail_bits = remaining_elements * bit_width; + const uint32_t tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; +} + +__always_inline int32x4x4_t neon_convert_int8x16_int32x4x4(const int8x16_t& input) { + const int16x8_t s16_lo = vmovl_s8(vget_low_s8(input)); + const int16x8_t s16_hi = vmovl_s8(vget_high_s8(input)); + + return {{ + vmovl_s16(vget_low_s16(s16_lo)), + vmovl_s16(vget_high_s16(s16_lo)), + vmovl_s16(vget_low_s16(s16_hi)), + vmovl_s16(vget_high_s16(s16_hi)), + }}; +} + +__always_inline int32x4x2_t neon_convert_int16x8_int32x4x2(const int16x8_t& input) { + return {{ + vmovl_s16(vget_low_s16(input)), + vmovl_s16(vget_high_s16(input)), + }}; +} + +__always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t& input, const float32x4_t& scale) { + return {{ + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[2]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[3]), scale), + }}; +} + +__always_inline float32x4x2_t neon_dequantize_epi32(const int32x4x2_t& input, const float32x4_t& scale) { + return {{ + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + }}; +} + +__always_inline float32x4_t neon_dequantize_epi32(const int32x4_t& input, const float32x4_t& scale) { + return vmulq_f32(vcvtq_f32_s32(input), scale); +} + +__always_inline float64x2_t neon_dequantize_epi32_f64(const int32x2_t& input, const float64x2_t& scale) { + return vmulq_f64(vcvtq_f64_s64(vmovl_s32(input)), scale); +} + +__always_inline float64x2x2_t neon_dequantize_epi32_f64(const int32x4_t& input, const float64x2_t& scale) { + return {{ + neon_dequantize_epi32_f64(vget_low_s32(input), scale), + neon_dequantize_epi32_f64(vget_high_s32(input), scale), + }}; +} + +__always_inline float64x2x4_t neon_dequantize_epi32_f64(const int32x4x2_t& input, const float64x2_t& scale) { + const float64x2x2_t dequantized_low = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized_high = neon_dequantize_epi32_f64(input.val[1], scale); + + return {{ + dequantized_low.val[0], + dequantized_low.val[1], + dequantized_high.val[0], + dequantized_high.val[1], + }}; +} + +__always_inline float64x2x8_t neon_dequantize_epi32_f64(const int32x4x4_t& input, const float64x2_t& scale) { + const float64x2x2_t dequantized0 = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized1 = neon_dequantize_epi32_f64(input.val[1], scale); + const float64x2x2_t dequantized2 = neon_dequantize_epi32_f64(input.val[2], scale); + const float64x2x2_t dequantized3 = neon_dequantize_epi32_f64(input.val[3], scale); + + return {{ + dequantized0.val[0], + dequantized0.val[1], + dequantized1.val[0], + dequantized1.val[1], + dequantized2.val[0], + dequantized2.val[1], + dequantized3.val[0], + dequantized3.val[1], + }}; +} + +__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_bytes_count) { + uint8_t buffer[16] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u8(buffer); +} + +__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t* input, const uint32_t tail_bytes_count) { + uint16_t buffer[8] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u16(buffer); +} + +__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t* input, const uint32_t tail_bytes_count) { + uint32_t buffer[4] = {0}; + std::memcpy(buffer, input, tail_bytes_count); + return vld1q_u32(buffer); +} + +__always_inline float32x4_t neon_load_tail_elements_f32(const uint8_t* input, const uint32_t tail_elements) { + float32_t buffer[4] = {0.0f}; + std::memcpy(buffer, input, tail_elements * sizeof(float32_t)); + return vld1q_f32(buffer); +} + +__always_inline float64x2_t neon_load_tail_elements_f64(const uint8_t* input, const uint32_t tail_elements) { + float64_t buffer[2] = {0.0}; + std::memcpy(buffer, input, tail_elements * sizeof(float64_t)); + return vld1q_f64(buffer); +} + +__always_inline void neon_store_tail_elements_int8(uint8_t* output, const uint8x16x4_t& data, const uint32_t tail_elements) { + uint8_t buffer[16 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u8(buffer + i * 16, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint8_t)); +} + +__always_inline void neon_store_tail_elements_int16(uint16_t* output, const uint16x8x4_t& data, const uint32_t tail_elements) { + uint16_t buffer[8 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u16(buffer + i * 8, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint16_t)); +} + +__always_inline void neon_store_tail_elements_int32(uint32_t* output, const uint32x4x4_t& data, const uint32_t tail_elements) { + uint32_t buffer[4 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint32_t)); +} + +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4x4_t& data, const uint32_t tail_elements) { + float32_t buffer[16 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4x2_t& data, const uint32_t tail_elements) { + float32_t buffer[8 * 2]; + for (uint32_t i = 0; i < 2; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4_t& data, const uint32_t tail_elements) { + float32_t buffer[4]; + vst1q_f32(buffer, data); + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x4_t& data, const uint32_t tail_elements) { + float64_t buffer[2 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} + +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x2_t& data, const uint32_t tail_elements) { + float64_t buffer[2 * 2]; + for (uint32_t i = 0; i < 2; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} + +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x8_t& data, const uint32_t tail_elements) { + float64_t buffer[2 * 8]; + for (uint32_t i = 0; i < 8; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_COMMON_H diff --git a/include/pernix/arm64/neon/compression.h b/include/pernix/arm64/neon/compression.h new file mode 100644 index 0000000..6e49348 --- /dev/null +++ b/include/pernix/arm64/neon/compression.h @@ -0,0 +1,141 @@ +#ifndef PERNIX_ARM64_NEON_COMPRESSION_H +#define PERNIX_ARM64_NEON_COMPRESSION_H + +#include +#include + +#include +#include + +namespace pernix::arm64::neon { +namespace internal { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int neon_compress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + + +int neon_compress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + +int neon_compress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); + +int neon_compress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix::arm64::neon + +#endif // PERNIX_ARM64_NEON_COMPRESSION_H diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h new file mode 100644 index 0000000..583948f --- /dev/null +++ b/include/pernix/arm64/neon/decompression.h @@ -0,0 +1,353 @@ +#ifndef PERNIX_ARM64_NEON_DECOMPRESSION_H +#define PERNIX_ARM64_NEON_DECOMPRESSION_H + +#include +#include +#include + +#include +#include + +namespace pernix::arm64::neon { +namespace internal { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_16 = elements_per_block / 16; + constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16; + + const float32x4_t scale_v = vdupq_n_f32(scale); + + for (uint32_t i = 0; i < iterations_16; ++i) { + const uint8x16_t source = vld1q_u8(input); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); + + const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); + const float32x4x4_t dequantized = neon_dequantize_epi32(converted, scale_v); + + for (uint32_t j = 0; j < 4; ++j) { + vst1q_f32(output, dequantized.val[j]); + output += 4; + } + + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); + + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); + const float32x4x4_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + + const float32x4_t scale_v = vdupq_n_f32(scale); + + for (uint32_t i = 0; i < iterations_8; ++i) { + const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16(source); + + const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); + const float32x4x2_t dequantized = neon_dequantize_epi32(converted, scale_v); + + for (uint32_t j = 0; j < 2; ++j) { + vst1q_f32(output, dequantized.val[j]); + output += 4; + } + + input += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16(tail_source); + + const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); + const float32x4x2_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + const float32x4_t scale_v = vdupq_n_f32(scale); + + for (uint32_t i = 0; i < iterations_4; ++i) { + const uint32_t group_bit_start = i * 4u * BIT_WIDTH; + const uint8_t* group_input = input + group_bit_start / 8u; + const uint32x4_t source = vld1q_u32(reinterpret_cast(group_input)); + + int32x4_t unpacked; + if constexpr (BIT_WIDTH % 2 == 0) { + unpacked = b128::neon_unpack_epi32_17to24(source); + } else { + if (i % 2 == 0) { + unpacked = b128::neon_unpack_epi32_17to24(source); + } else { + unpacked = b128::neon_unpack_epi32_17to24(source); + } + } + + const float32x4_t dequantized = neon_dequantize_epi32(unpacked, scale_v); + + vst1q_f32(output, dequantized); + + output += 4; + } + + if constexpr (remaining_elements > 0) { + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; + + constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + + int32x4_t tail_unpacked; + if constexpr (tail_bit_offset == 0) { + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); + } else { + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); + } + + const float32x4_t tail_dequantized = neon_dequantize_epi32(tail_unpacked, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_16 = elements_per_block / 16; + constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16; + + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_16; ++i) { + const uint8x16_t source = vld1q_u8(input); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); + + const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); + const float64x2x8_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); + + for (uint32_t j = 0; j < 8; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); + + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); + const float64x2x8_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_8; ++i) { + const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16(source); + + const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); + const float64x2x4_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); + + for (uint32_t j = 0; j < 4; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + + input += BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16(tail_source); + + const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); + const float64x2x4_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_4; ++i) { + const uint32_t group_bit_start = i * 4u * BIT_WIDTH; + const uint8_t* group_input = input + group_bit_start / 8u; + const uint32x4_t source = vld1q_u32(reinterpret_cast(group_input)); + + int32x4_t unpacked; + if constexpr (BIT_WIDTH % 2 == 0) { + unpacked = b128::neon_unpack_epi32_17to24(source); + } else { + if (i % 2 == 0) { + unpacked = b128::neon_unpack_epi32_17to24(source); + } else { + unpacked = b128::neon_unpack_epi32_17to24(source); + } + } + + const float64x2x2_t dequantized = neon_dequantize_epi32_f64(unpacked, scale_v); + + for (uint32_t j = 0; j < 2; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + } + + if constexpr (remaining_elements > 0) { + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; + + constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + + int32x4_t tail_unpacked; + if constexpr (tail_bit_offset == 0) { + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); + } else { + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); + } + + const float64x2x2_t tail_dequantized = neon_dequantize_epi32_f64(tail_unpacked, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); + } + + return 0; +} +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_decompress_block_17to24(input, scale, output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_decompress_block_17to24(input, scale, output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int neon_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + +int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); + +int neon_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); + +int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix::arm64::neon + +#endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/neon/packing.h b/include/pernix/arm64/neon/packing.h new file mode 100644 index 0000000..538b5a8 --- /dev/null +++ b/include/pernix/arm64/neon/packing.h @@ -0,0 +1,9 @@ +#ifndef PERNIX_ARM64_NEON_PACKING_H +#define PERNIX_ARM64_NEON_PACKING_H + +#include + +namespace pernix::arm64::neon::internal { +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_PACKING_H diff --git a/include/pernix/arm64/neon/tables.h b/include/pernix/arm64/neon/tables.h new file mode 100644 index 0000000..d085551 --- /dev/null +++ b/include/pernix/arm64/neon/tables.h @@ -0,0 +1,212 @@ +#ifndef PERNIX_ARM64_NEON_TABLES_H +#define PERNIX_ARM64_NEON_TABLES_H + +#include +#include +#include +#include + +namespace pernix::arm64::neon::internal { +namespace detail { +inline constexpr std::size_t neon_vector_width = 128; +inline constexpr uint8_t inactive_lane = 0xff; + +template +constexpr bool table_indices_are_valid(const std::array& table) { + return std::ranges::all_of(table, [](const uint8_t index) { + return index == inactive_lane || index < Elements; + }); +} + +template +constexpr std::array make_primary_permute() { + static_assert(LANE_BITS % 8 == 0); + + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t base = entry * lane_bytes; + + for (std::size_t lane_byte = 0; lane_byte < lane_bytes; ++lane_byte) { + table[base + lane_byte] = static_cast(first_byte + lane_byte); + } + } + + return table; +} + +template +constexpr std::array make_spill_permute() { + static_assert(LANE_BITS % 8 == 0); + + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t bit_offset = bit_start % 8; + const std::size_t base = entry * lane_bytes; + + if (bit_offset + BIT_WIDTH > LANE_BITS) { + table[base] = static_cast(first_byte + lane_bytes); + } + } + + return table; +} + +template +constexpr std::array make_shift_right() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + + table[entry] = -static_cast(bit_offset); + } + + return table; +} + +template +constexpr std::array make_shift_left_for_spill() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + const bool spills = bit_offset + BIT_WIDTH > LANE_BITS; + + table[entry] = spills ? static_cast(LANE_BITS - bit_offset) : 0; + } + + return table; +} + +template +constexpr std::array make_contiguous_permute_32() { + static_assert(ELEMENTS % 4 == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / 4; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + const std::size_t bit_end = bit_start + BIT_WIDTH - 1; + const std::size_t first_byte = bit_start / 8; + const std::size_t last_byte = bit_end / 8; + const std::size_t base = entry * 4; + + for (std::size_t byte = first_byte; byte <= last_byte; ++byte) { + table[base + (byte - first_byte)] = static_cast(byte); + } + } + + return table; +} + +template +constexpr std::array make_shift_right_32() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + + table[entry] = -static_cast(bit_start % 8u); + } + + return table; +} +} // namespace detail + +template +struct table_unpacking; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && VECTOR_WIDTH == detail::neon_vector_width) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 8; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 16); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && VECTOR_WIDTH == detail::neon_vector_width) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 16; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = + detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 8); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && VECTOR_WIDTH == detail::neon_vector_width && START_BIT_OFFSET < 8) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 32; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute = + detail::make_contiguous_permute_32(); + alignas(64) static constexpr std::array shift = + detail::make_shift_right_32(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 4); + static_assert(detail::table_indices_are_valid(permute)); +}; + +template +struct table_packing; +} // namespace pernix::arm64::internal + +#endif // PERNIX_ARM64_NEON_TABLES_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h new file mode 100644 index 0000000..6ac0e20 --- /dev/null +++ b/include/pernix/arm64/neon/unpacking.h @@ -0,0 +1,100 @@ +#ifndef PERNIX_ARM64_NEON_UNPACKING_H +#define PERNIX_ARM64_NEON_UNPACKING_H + +#include +#include + +using namespace pernix::arm64::neon::internal; + +namespace pernix::arm64::neon::internal::b128 { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t& input) { + if constexpr (BIT_WIDTH == 8) { + return vreinterpretq_s8_u8(input); + } else if constexpr (BIT_WIDTH == 1) { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + const uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + return vreinterpretq_s8_u8(vandq_u8(shifted, vdupq_n_u8(1))); + } else { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + + uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute2.data())); + + shifted = vorrq_u8(shifted, vshlq_u8(permuted2_u8, vld1q_s8(tables::shift2.data()))); + } + + constexpr int shift = 8 - BIT_WIDTH; + shifted = vshlq_n_u8(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s8(vreinterpretq_s8_u8(shifted), vdupq_n_s8(-shift)); + } else { + return vreinterpretq_s8_u8(vshlq_u8(shifted, vdupq_n_s8(-shift))); + } + } +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline int16x8_t neon_unpack_epi16_9to16(const uint16x8_t& input) { + if constexpr (BIT_WIDTH == 16) { + return vreinterpretq_s16_u16(input); + } else { + using tables = table_unpacking; + + const uint8x16_t input_u8 = vreinterpretq_u8_u16(input); + + const uint8x16_t permuted1_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute1.data())); + + uint16x8_t shifted = vshlq_u16(vreinterpretq_u16_u8(permuted1_u8), vld1q_s16(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute2.data())); + + const uint16x8_t shifted2 = vshlq_u16(vreinterpretq_u16_u8(permuted2_u8), vld1q_s16(tables::shift2.data())); + + shifted = vorrq_u16(shifted, shifted2); + } + + constexpr int shift = 16 - BIT_WIDTH; + shifted = vshlq_n_u16(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s16(vreinterpretq_s16_u16(shifted), vdupq_n_s16(-shift)); + } else { + return vreinterpretq_s16_u16(vshlq_u16(shifted, vdupq_n_s16(-shift))); + } + } +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline int32x4_t neon_unpack_epi32_17to24(const uint32x4_t& input) { + using tables = table_unpacking; + + const uint8x16_t input_8 = vreinterpretq_u8_u32(input); + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input_8, vld1q_u8(tables::permute.data())); + + const uint32x4_t value = vshlq_u32(vreinterpretq_u32_u8(permuted_u8), vld1q_s32(tables::shift.data())); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return vshrq_n_s32(vreinterpretq_s32_u32(vshlq_n_u32(value, sign_shift)), sign_shift); + } else { + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); + } +} +} // namespace pernix::arm64::neon::internal::b128 + +#endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve/compression.h b/include/pernix/arm64/sve/compression.h new file mode 100644 index 0000000..cf83ce0 --- /dev/null +++ b/include/pernix/arm64/sve/compression.h @@ -0,0 +1,141 @@ +#ifndef PERNIX_ARM64_SVE_COMPRESSION_H +#define PERNIX_ARM64_SVE_COMPRESSION_H + +#include +#include + +#include +#include + +namespace pernix::arm64::sve { +namespace internal { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_compress_block_17to24(input, scale, output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_compress_block_17to24(input, scale, output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve_compress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + + +int sve_compress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + +int sve_compress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); + +int sve_compress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix::arm64::sve + +#endif // PERNIX_ARM64_SVE_COMPRESSION_H diff --git a/include/pernix/arm64/sve/decompression.h b/include/pernix/arm64/sve/decompression.h new file mode 100644 index 0000000..052a3e4 --- /dev/null +++ b/include/pernix/arm64/sve/decompression.h @@ -0,0 +1,141 @@ +#ifndef PERNIX_ARM64_SVE_DECOMPRESSION_H +#define PERNIX_ARM64_SVE_DECOMPRESSION_H + +#include +#include + +#include +#include + +namespace pernix::arm64::sve { +namespace internal { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_decompress_block_17to24(input, scale, output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_decompress_block_17to24(input, scale, output); + } + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + + +int sve_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + +int sve_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); + +int sve_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix::arm64::sve + +#endif // PERNIX_ARM64_SVE_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve/packing.h b/include/pernix/arm64/sve/packing.h new file mode 100644 index 0000000..ab57b4f --- /dev/null +++ b/include/pernix/arm64/sve/packing.h @@ -0,0 +1,9 @@ +#ifndef PERNIX_ARM64_SVE_PACKING_H +#define PERNIX_ARM64_SVE_PACKING_H + +#include + +namespace pernix::arm64::sve::internal { +} // namespace pernix::arm64::sve::internal + +#endif // PERNIX_ARM64_SVE_PACKING_H diff --git a/include/pernix/arm64/sve/unpacking.h b/include/pernix/arm64/sve/unpacking.h new file mode 100644 index 0000000..2565ab7 --- /dev/null +++ b/include/pernix/arm64/sve/unpacking.h @@ -0,0 +1,9 @@ +#ifndef PERNIX_ARM64_SVE_UNPACKING_H +#define PERNIX_ARM64_SVE_UNPACKING_H + +#include + +namespace pernix::arm64::sve::internal { +} // namespace pernix::arm64::sve::internal + +#endif // PERNIX_ARM64_SVE_UNPACKING_H diff --git a/include/pernix/arm64/sve2/compression.h b/include/pernix/arm64/sve2/compression.h new file mode 100644 index 0000000..4e4627d --- /dev/null +++ b/include/pernix/arm64/sve2/compression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE2_COMPRESSION_H +#define PERNIX_ARM64_SVE2_COMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve2_compression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_block(const float_t*, float_t, uint8_t*) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_block(const double_t*, double_t, uint8_t*) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve2_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); +int sve2_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); +int sve2_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, + uint32_t blocks); +int sve2_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE2_COMPRESSION_H diff --git a/include/pernix/arm64/sve2/decompression.h b/include/pernix/arm64/sve2/decompression.h new file mode 100644 index 0000000..2128ff2 --- /dev/null +++ b/include/pernix/arm64/sve2/decompression.h @@ -0,0 +1,437 @@ +#ifndef PERNIX_ARM64_SVE2_DECOMPRESSION_H +#define PERNIX_ARM64_SVE2_DECOMPRESSION_H + +#include +#include +#include + +#include +#include +#include +#include + +namespace pernix::arm64::sve2 { +namespace internal { +template +[[nodiscard]] __always_inline constexpr uint32_t packed_bytes(const uint32_t elements) { + return (elements * BIT_WIDTH + 7) / 8; +} + +[[nodiscard]] __always_inline svuint8_t sve2_load_packed_bytes(const uint8_t* __restrict__ input, const uint32_t bytes) { + const svbool_t pg = svwhilelt_b8(uint64_t{0}, static_cast(bytes)); + return svld1_u8(pg, input); +} + +template +__always_inline void sve2_store_dequantized_i8_f32(svint8_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + alignas(64) std::vector temp(svcntb()); + + svst1_s8(svptrue_b8(), temp.data(), values); + + uint32_t offset = 0; + while (offset < count) { + const svbool_t pg = svwhilelt_b32(static_cast(offset), static_cast(count)); + + svfloat32_t dequantized; + + if constexpr (SIGN_VALUES) { + const svint32_t widened = svld1sb_s32(pg, temp.data() + offset); + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); + } else { + const svuint32_t widened = svld1ub_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + } + + svst1_f32(pg, output + offset, dequantized); + + offset += static_cast(svcntw()); + } +} + +template +__always_inline void sve2_store_dequantized_i8_f64(svint8_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcntb()); + + svst1_s8(svptrue_b8(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template +__always_inline void sve2_store_dequantized_i16_f32(svint16_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + alignas(64) std::vector temp(svcnth()); + + svst1_s16(svptrue_b16(), temp.data(), values); + + uint32_t offset = 0; + while (offset < count) { + const svbool_t pg = svwhilelt_b32(static_cast(offset), static_cast(count)); + + svfloat32_t dequantized; + if constexpr (SIGN_VALUES) { + const svint32_t widened = svld1sh_s32(pg, temp.data() + offset); + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); + } else { + const svuint32_t widened = svld1uh_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + } + + svst1_f32(pg, output + offset, dequantized); + + offset += static_cast(svcntw()); + } +} + +template +__always_inline void sve2_store_dequantized_i16_f64(svint16_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcnth()); + + svst1_s16(svptrue_b16(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template +__always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + const svbool_t pg = svwhilelt_b32(uint64_t{0}, static_cast(count)); + + svfloat32_t dequantized; + if constexpr (SIGN_VALUES) { + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, values), scale_v); + } else { + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, svreinterpret_u32_s32(values)), scale_v); + } + + svst1_f32(pg, output, dequantized); +} + +template +__always_inline void sve2_store_dequantized_i32_f64(svint32_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcntw()); + + svst1_s32(svptrue_b32(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntb()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint8_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint8_t spill_shift = svdup_n_u8(0); + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint8_t unpacked = sve2_unpack_epi8_1to8(source, permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i8_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcnth()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint16_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint16_t spill_shift = svdup_n_u16(0); + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint16_t unpacked = + sve2_unpack_epi16_9to16(svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i16_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntw()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint8_t* chunk_input = input + input_bit_offset / 8; + const uint32_t bit_offset = input_bit_offset % 8; + const uint32_t bytes = (bit_offset + count * BIT_WIDTH + 7u) / 8u; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + svint32_t unpacked; + if (bit_offset == 0) { + unpacked = sve2_unpack_epi32_17to24(source); + } else { + unpacked = sve2_unpack_epi32_17to24(source); + } + + sve2_store_dequantized_i32_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntb()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint8_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint8_t spill_shift = svdup_n_u8(0); + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint8_t unpacked = sve2_unpack_epi8_1to8(source, permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i8_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcnth()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint16_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint16_t spill_shift = svdup_n_u16(0); + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint16_t unpacked = + sve2_unpack_epi16_9to16(svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i16_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntw()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint8_t* chunk_input = input + input_bit_offset / 8; + const uint32_t bit_offset = input_bit_offset % 8; + const uint32_t bytes = (bit_offset + count * BIT_WIDTH + 7u) / 8u; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + svint32_t unpacked; + if (bit_offset == 0) { + unpacked = sve2_unpack_epi32_17to24(source); + } else { + unpacked = sve2_unpack_epi32_17to24(source); + } + + sve2_store_dequantized_i32_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve2_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve2_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve2_decompress_block_17to24(input, scale, output); + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve2_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve2_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve2_decompress_block_17to24(input, scale, output); + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve2_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + +int sve2_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); + +int sve2_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); + +int sve2_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix::arm64::sve2 + +#endif // PERNIX_ARM64_SVE2_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve2/packing.h b/include/pernix/arm64/sve2/packing.h new file mode 100644 index 0000000..789b4d7 --- /dev/null +++ b/include/pernix/arm64/sve2/packing.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE2_PACKING_H +#define PERNIX_ARM64_SVE2_PACKING_H + +#include + +namespace pernix::arm64::sve2::internal { +template +inline constexpr bool packing_unimplemented_v = false; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_PACKING_H diff --git a/include/pernix/arm64/sve2/tables.h b/include/pernix/arm64/sve2/tables.h new file mode 100644 index 0000000..897fa9b --- /dev/null +++ b/include/pernix/arm64/sve2/tables.h @@ -0,0 +1,118 @@ +#ifndef PERNIX_ARM64_SVE2_TABLES_H +#define PERNIX_ARM64_SVE2_TABLES_H + +#include + +#include + +namespace pernix::arm64::sve2::internal { +template +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svbool_t pg_b8() { return svptrue_b8(); } + + static svbool_t pg_b16() { return svptrue_b16(); } + + static svbool_t pg_b32() { return svptrue_b32(); } +}; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + return svlsr_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 3); + } + + static svuint8_t spill_permute() { + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 1); + } + + static svuint8_t shift() { + const svbool_t pg = svptrue_b8(); + return svand_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 7); + } + + static svuint8_t spill_shift() { + const svbool_t pg = svptrue_b8(); + return svsub_u8_x(pg, svdup_n_u8(8), shift()); + } +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 1); + const svuint8_t byte = svand_n_u8_x(pg, lane, 1); + + svuint8_t first; + if constexpr (BIT_WIDTH == 16) { + first = svlsl_n_u8_x(pg, elem, 1); + } else { + constexpr uint8_t extra_bits = BIT_WIDTH - 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low = svlsr_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), 3); + first = svadd_u8_x(pg, elem, svadd_u8_x(pg, high, low)); + } + + return svadd_u8_x(pg, first, byte); + } + + static svuint8_t spill_permute() { + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 2); + } + + static svuint16_t shift() { + const svbool_t pg = svptrue_b16(); + return svand_n_u16_x(pg, svmul_n_u16_x(pg, svindex_u16(0, 1), BIT_WIDTH), 7); + } + + static svuint16_t spill_shift() { + const svbool_t pg = svptrue_b16(); + const svuint16_t bit_shift = shift(); + const svuint16_t spill = svsub_u16_x(pg, svdup_n_u16(16), bit_shift); + return svsel_u16(svcmpgt_n_u16(pg, bit_shift, 16u - BIT_WIDTH), spill, svdup_n_u16(16)); + } +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && START_BIT_OFFSET < 8) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 2); + const svuint8_t byte = svand_n_u8_x(pg, lane, 3); + + svuint8_t first = svmul_n_u8_x(pg, elem, BIT_WIDTH / 8u); + if constexpr (BIT_WIDTH % 8u != 0) { + constexpr uint8_t extra_bits = BIT_WIDTH % 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low_bits = + svadd_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), START_BIT_OFFSET); + first = svadd_u8_x(pg, first, svadd_u8_x(pg, high, svlsr_n_u8_x(pg, low_bits, 3))); + } + + return svadd_u8_x(pg, first, byte); + } + + static svuint32_t shift() { + const svbool_t pg = svptrue_b32(); + return svand_n_u32_x(pg, svadd_n_u32_x(pg, svmul_n_u32_x(pg, svindex_u32(0, 1), BIT_WIDTH), START_BIT_OFFSET), 7); + } +}; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_TABLES_H diff --git a/include/pernix/arm64/sve2/unpacking.h b/include/pernix/arm64/sve2/unpacking.h new file mode 100644 index 0000000..326901f --- /dev/null +++ b/include/pernix/arm64/sve2/unpacking.h @@ -0,0 +1,92 @@ +#ifndef PERNIX_ARM64_SVE2_UNPACKING_H +#define PERNIX_ARM64_SVE2_UNPACKING_H + +#include + +#include "tables.h" + +namespace pernix::arm64::sve2::internal { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline svint8_t sve2_unpack_epi8_1to8(const svuint8_t input, const svuint8_t permute, const svuint8_t shift, + const svuint8_t spill_permute, const svuint8_t spill_shift) { + if constexpr (BIT_WIDTH == 8) { + return svreinterpret_s8(input); + } else { + const svbool_t pg = svptrue_b8(); + + const svuint8_t permuted = svtbl_u8(input, permute); + svuint8_t unpacked = svlsr_u8_x(pg, permuted, shift); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const svuint8_t spill_permuted_values = svtbl_u8(input, spill_permute); + const svuint8_t spill_shifted = svlsl_u8_x(pg, spill_permuted_values, spill_shift); + unpacked = svorr_u8_x(pg, unpacked, spill_shifted); + } + + if constexpr (BIT_WIDTH == 1) { + unpacked = svand_n_u8_x(pg, unpacked, 1); + return svreinterpret_s8(unpacked); + } else { + constexpr int sign_shift = 8 - BIT_WIDTH; + + unpacked = svlsl_n_u8_x(pg, unpacked, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s8_x(pg, svreinterpret_s8_u8(unpacked), sign_shift); + } else { + return svreinterpret_s8_u8(svlsr_n_u8_x(pg, unpacked, sign_shift)); + } + } + } +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline svint16_t sve2_unpack_epi16_9to16(const svuint16_t input, const svuint8_t permute, const svuint16_t shift, + const svuint8_t spill_permute, const svuint16_t spill_shift) { + if constexpr (BIT_WIDTH == 16) { + return svreinterpret_s16(input); + } else { + const svbool_t pg = svptrue_b16(); + + const svuint8_t permuted = svtbl_u8(svreinterpret_u8_u16(input), permute); + svuint16_t shifted = svlsr_u16_x(pg, svreinterpret_u16_u8(permuted), shift); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const svuint8_t spill_permuted_values = svtbl_u8(svreinterpret_u8_u16(input), spill_permute); + const svuint16_t spill_shifted = svlsl_u16_x(pg, svreinterpret_u16_u8(spill_permuted_values), spill_shift); + shifted = svorr_u16_x(pg, shifted, spill_shifted); + } + + constexpr int sign_shift = 16 - BIT_WIDTH; + shifted = svlsl_n_u16_x(pg, shifted, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s16_x(pg, svreinterpret_s16_u16(shifted), sign_shift); + } else { + return svreinterpret_s16_u16(svlsr_n_u16_x(pg, shifted, sign_shift)); + } + } +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline svint32_t sve2_unpack_epi32_17to24(const svuint8_t input) { + using table = table_unpacking; + + const svbool_t pg = svptrue_b32(); + const svuint8_t permuted = svtbl_u8(input, table::permute()); + const svuint32_t unpacked = svlsr_u32_x(pg, svreinterpret_u32_u8(permuted), table::shift()); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return svasr_n_s32_x(pg, svreinterpret_s32_u32(svlsl_n_u32_x(pg, unpacked, sign_shift)), sign_shift); + } else { + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + return svreinterpret_s32_u32(svand_n_u32_x(pg, unpacked, mask)); + } +} +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/include/pernix/detection.h b/include/pernix/detection.h index edecb6c..fa9cd44 100644 --- a/include/pernix/detection.h +++ b/include/pernix/detection.h @@ -10,6 +10,19 @@ #define PERNIX_MACHINE_ID_V4 3 #define PERNIX_MACHINE_ID_V4_VBMI 4 +#if defined(PERNIX_BACKEND_ARM64_NEON) +#define PERNIX_ARM64_NEON_ENABLED +#endif + +#if defined(PERNIX_BACKEND_ARM64_SVE) +#define PERNIX_ARM64_SVE_ENABLED +#endif + +#if defined(PERNIX_BACKEND_ARM64_SVE2) +#define PERNIX_ARM64_SVE2_ENABLED +#endif + +#if defined(PERNIX_BACKEND_X86) // Map the compiler's enabled ISA set to the highest supported Pernix target level. #if (__SSE3__ && __SSE4_1__ && __SSE4_2__) #if (__AVX__ && __AVX2__ && __FMA__ && __BMI__ && __BMI2__) @@ -32,6 +45,10 @@ #define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC #endif +#else +#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC +#endif + // Feature-selection macros consumed by the public headers. #if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V2) #define PERNIX_SSE_ENABLED @@ -47,7 +64,7 @@ #define PERNIX_AVX512_VBMI_ENABLED #endif -#ifdef PERNIX_USE_SIMDE +#if defined(PERNIX_USE_SIMDE) && defined(PERNIX_BACKEND_X86) #define PERNIX_SSE_ENABLED #define PERNIX_AVX2_ENABLED #define PERNIX_BMI2_ENABLED diff --git a/include/pernix/pernix.h b/include/pernix/pernix.h index 55133bb..4998a60 100644 --- a/include/pernix/pernix.h +++ b/include/pernix/pernix.h @@ -5,7 +5,7 @@ // Include architecture-specific headers based on detected capabilities // AVX2 -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #include #include @@ -21,7 +21,22 @@ #include #endif // PERNIX_AVX512_VBMI_ENABLED -#endif // PERNIX_AVX2_ENABLED +#endif // PERNIX_BACKEND_X86 && PERNIX_AVX2_ENABLED + +#ifdef PERNIX_BACKEND_ARM64_NEON +#include +#include +#endif + +#ifdef PERNIX_BACKEND_ARM64_SVE +#include +#include +#endif + +#ifdef PERNIX_BACKEND_ARM64_SVE2 +#include +#include +#endif // Fallback (non-SIMD) implementations #include @@ -167,7 +182,7 @@ template int decompress_blocks(const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, uint32_t blocks); // Use the best available implementation based on detected CPU features at compile time. -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #ifdef PERNIX_AVX512_VBMI_ENABLED template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) @@ -265,6 +280,150 @@ int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, d return mm256_decompress_blocks_avx2(input, scale, output, blocks); } #endif +#elif defined(PERNIX_BACKEND_ARM64_NEON) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return neon_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return neon_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return neon_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return neon_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return neon_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return neon_decompress_blocks(input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return sve_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return sve_decompress_blocks(input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE2) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve2_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve2_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve2_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve2_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return sve2_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return sve2_decompress_blocks(input, scale, output, blocks); +} #else template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) @@ -420,4 +579,4 @@ int decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, } // namespace pernix #endif -#endif // PERNIX_H \ No newline at end of file +#endif // PERNIX_H diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index c95c1ee..509eb13 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -7,10 +7,19 @@ #if defined(PERNIX_USE_SIMDE) #define SIMDE_ENABLE_NATIVE_ALIASES #undef SIMDE_X86_AVX512FP16_NATIVE +#if defined(__clang__) +#define SIMDE_X86_AVX512BF16_NATIVE +#endif // #define SIMDE_NO_NATIVE +#if defined(PERNIX_BACKEND_X86) #include -#include #include +#include +#elif defined(PERNIX_BACKEND_ARM64_NEON) +#include +#elif defined(PERNIX_BACKEND_ARM64_SVE) || defined(PERNIX_BACKEND_ARM64_SVE2) +#include +#endif // #ifndef __mmask8 // typedef uint8_t __mmask8; @@ -27,6 +36,13 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #include +#elif defined(__aarch64__) || defined(__arm64ec__) +#ifdef __ARM_FEATURE_SVE +#include +#endif +#ifdef __ARM_NEON +#include +#endif #endif #ifndef __always_inline @@ -39,25 +55,4 @@ #endif #endif -template - requires(std::is_integral_v && sizeof(T) <= 8) -static constexpr T tail_mask(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - if (tail_bytes == 0u) { - return static_cast(0); - } - if (tail_bytes >= 64u) { - return static_cast(~uint64_t{0}); - } - const uint64_t mask = (uint64_t{1} << tail_bytes) - 1u; - return static_cast(mask); -} - -static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - return tail_bytes; -} - #endif // PERNIX_SIMD_COMPAT_H diff --git a/include/pernix/x86/avx512vbmi/compression.h b/include/pernix/x86/avx512vbmi/compression.h index bc5a375..9621e12 100644 --- a/include/pernix/x86/avx512vbmi/compression.h +++ b/include/pernix/x86/avx512vbmi/compression.h @@ -1,13 +1,16 @@ #ifndef PERNIX_AVX512VBMI_COMPRESSION_H #define PERNIX_AVX512VBMI_COMPRESSION_H +#include #include -#include #include -#include +#include +#include #include +using namespace pernix::x86::internal; + namespace pernix { namespace internal { template @@ -180,8 +183,8 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -250,8 +253,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -288,7 +291,7 @@ template if constexpr (remaining_elements > 0) { const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); const __m256i packed_input = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); } @@ -298,8 +301,8 @@ template template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_64 = elements_per_block / 64; @@ -413,8 +416,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -495,8 +498,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; diff --git a/include/pernix/x86/avx512vbmi/decompression.h b/include/pernix/x86/avx512vbmi/decompression.h index 61280c9..08abc35 100644 --- a/include/pernix/x86/avx512vbmi/decompression.h +++ b/include/pernix/x86/avx512vbmi/decompression.h @@ -1,32 +1,35 @@ #ifndef PERNIX_AVX512VBMI_DECOMPRESSION_H #define PERNIX_AVX512VBMI_DECOMPRESSION_H +#include #include -#include #include -#include +#include +#include #include +using namespace pernix::x86::internal; + namespace pernix { namespace internal { /** * @brief Dequantize sixteen integer values to floats. */ -[[gnu::always_inline]] inline __m512 mm512_dequantize_epi32(const __m512i& input, const __m512& scale) { +__always_inline __m512 mm512_dequantize_epi32(const __m512i& input, const __m512& scale) { const __m512 converted = _mm512_cvtepi32_ps(input); return _mm512_mul_ps(converted, scale); } -[[gnu::always_inline]] inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m512d& scale) { +__always_inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m512d& scale) { const __m512d converted = _mm512_cvtepi64_pd(input); return _mm512_mul_pd(converted, scale); } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const uint32_t iterations_64 = elements_per_block / 64; @@ -109,8 +112,8 @@ template template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const uint32_t iterations_64 = elements_per_block / 64; @@ -227,8 +230,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -300,8 +303,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -384,8 +387,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -436,8 +439,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -603,8 +606,8 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const } } // namespace pernix -#ifdef __cplusplus namespace pernix { +#ifdef __cplusplus extern "C" { #endif /** @@ -666,7 +669,7 @@ int mm512_decompress_blocks_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __r #ifdef __cplusplus } -} // namespace pernix #endif +} // namespace pernix #endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/packing.h b/include/pernix/x86/avx512vbmi/packing.h index c9f9db9..ba3b132 100644 --- a/include/pernix/x86/avx512vbmi/packing.h +++ b/include/pernix/x86/avx512vbmi/packing.h @@ -11,7 +11,7 @@ namespace m128 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i& input) { +__always_inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -48,7 +48,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i& input) { +__always_inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -93,7 +93,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i& input) { +__always_inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i& input) { using tables = pack_tables_avx512_24; const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); @@ -117,7 +117,7 @@ namespace m256 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i& input) { +__always_inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -154,7 +154,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i& input) { +__always_inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -199,7 +199,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i& input) { +__always_inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i& input) { using tables = pack_tables_avx512_24; const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); @@ -223,7 +223,7 @@ namespace m512 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i& input) { +__always_inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -260,7 +260,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i& input) { +__always_inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -305,7 +305,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i& input) { +__always_inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i& input) { using tables = pack_tables_avx512_24; const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); diff --git a/include/pernix/x86/avx512vbmi/tables.h b/include/pernix/x86/avx512vbmi/tables.h index 9115625..4d66727 100644 --- a/include/pernix/x86/avx512vbmi/tables.h +++ b/include/pernix/x86/avx512vbmi/tables.h @@ -9,9 +9,8 @@ #include namespace pernix::internal { - template -[[gnu::always_inline]] static inline Vec load_table(const std::array& table) { +static __always_inline Vec load_table(const std::array& table) { static_assert(sizeof(table) >= sizeof(Vec), "table is smaller than requested SIMD vector"); if constexpr (std::is_same_v) { return _mm512_load_si512(static_cast(table.data())); @@ -529,13 +528,13 @@ struct pack_tables_avx512_16 { // clang-format on } - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_permute3() { return load_table(permute3); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } - [[gnu::always_inline]] static inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } }; template @@ -591,7 +590,7 @@ struct pack_tables_avx512_24 { return plan; } - static inline constexpr std::array word_plans = [] { + static constexpr std::array word_plans = [] { std::array plans{}; for (uint32_t i = 0; i < 16; ++i) { plans[i] = create_plan(i); @@ -600,7 +599,7 @@ struct pack_tables_avx512_24 { }(); template - [[gnu::always_inline]] static constexpr std::array make_table(Getter getter) { + static __always_inline constexpr std::array make_table(Getter getter) { std::array values{}; for (uint32_t i = 0; i < 16; ++i) { values[i] = getter(word_plans[i]); @@ -608,26 +607,26 @@ struct pack_tables_avx512_24 { return values; } - alignas(64) static inline constexpr auto permute1 = make_table([](const word_plan& p) { return p.left_index1; }); + alignas(64) static constexpr auto permute1 = make_table([](const word_plan& p) { return p.left_index1; }); - alignas(64) static inline constexpr auto permute2 = make_table([](const word_plan& p) { return p.left_index2; }); + alignas(64) static constexpr auto permute2 = make_table([](const word_plan& p) { return p.left_index2; }); - alignas(64) static inline constexpr auto permute3 = make_table([](const word_plan& p) { return p.right_index; }); + alignas(64) static constexpr auto permute3 = make_table([](const word_plan& p) { return p.right_index; }); - alignas(64) static inline constexpr auto shift1 = make_table([](const word_plan& p) { return p.left_shift1; }); + alignas(64) static constexpr auto shift1 = make_table([](const word_plan& p) { return p.left_shift1; }); - alignas(64) static inline constexpr auto shift2 = make_table([](const word_plan& p) { return p.left_shift2; }); + alignas(64) static constexpr auto shift2 = make_table([](const word_plan& p) { return p.left_shift2; }); - alignas(64) static inline constexpr auto shift3 = make_table([](const word_plan& p) { return p.right_shift; }); + alignas(64) static constexpr auto shift3 = make_table([](const word_plan& p) { return p.right_shift; }); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_permute3() { return load_table(permute3); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } - [[gnu::always_inline]] static inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } }; template @@ -693,11 +692,11 @@ struct unpack_tables_avx512_8 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } }; template @@ -768,11 +767,11 @@ struct unpack_tables_avx512_16 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } }; template @@ -813,10 +812,9 @@ struct unpack_tables_avx512_24 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute() { return load_table(permute); } - [[gnu::always_inline]] static inline Vec get_shift() { return load_table(shift); } + static __always_inline Vec get_permute() { return load_table(permute); } + static __always_inline Vec get_shift() { return load_table(shift); } }; - -} // namespace pernix::internal +} // namespace pernix::internal #endif // PERNIX_AVX512VBMI_TABLES_H diff --git a/include/pernix/x86/utils.h b/include/pernix/x86/utils.h new file mode 100644 index 0000000..185e0a2 --- /dev/null +++ b/include/pernix/x86/utils.h @@ -0,0 +1,16 @@ +#ifndef PERNIX_X86_UTILS_H +#define PERNIX_X86_UTILS_H + +#include + +namespace pernix::x86::internal { + +static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { + const uint32_t tail_bits = remaining_elements * bit_width; + const uint32_t tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; +} + +} // namespace pernix::x86::internal + +#endif // PERNIX_X86_UTILS_H diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b7e8f23..57d2f0a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,7 +1,9 @@ include(GNUInstallDirs) +include(CMakePackageConfigHelpers) file(GLOB_RECURSE PERNIX_COMMON_SOURCES + CONFIGURE_DEPENDS ./fallback/*.cpp ./pernix.cpp ${PROJECT_SOURCE_DIR}/include/pernix/*.h @@ -9,34 +11,73 @@ file(GLOB_RECURSE set(PERNIX_SOURCES ${PERNIX_COMMON_SOURCES}) -set(PERNIX_TARGET_IS_X86 OFF) -if (PERNIX_USE_SIMDE OR CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") - set(PERNIX_TARGET_IS_X86 ON) -endif () - -if (PERNIX_TARGET_IS_X86) +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") file(GLOB_RECURSE PERNIX_X86_SOURCES + CONFIGURE_DEPENDS ./x86/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/x86/*.h ) list(APPEND PERNIX_SOURCES ${PERNIX_X86_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") + file(GLOB_RECURSE + PERNIX_ARM64_NEON_SOURCES + CONFIGURE_DEPENDS + ./arm64/neon/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/neon/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_NEON_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") + file(GLOB_RECURSE + PERNIX_ARM64_SVE_SOURCES + CONFIGURE_DEPENDS + ./arm64/sve/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") + file(GLOB_RECURSE + PERNIX_ARM64_SVE2_SOURCES + CONFIGURE_DEPENDS + ./arm64/sve2/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve2/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE2_SOURCES}) endif () add_library(pernix SHARED ${PERNIX_SOURCES}) +add_library(pernix::pernix ALIAS pernix) set_target_properties(pernix PROPERTIES OUTPUT_NAME "pernix" VERSION ${NORMALIZED_VERSION} ) +target_compile_features(pernix PUBLIC cxx_std_20) +target_compile_options(pernix PRIVATE ${PERNIX_PRIVATE_COMPILE_OPTIONS}) target_include_directories(pernix PUBLIC $ + $ ) +if (PERNIX_ENABLE_LTO) + set_target_properties(pernix PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE) +endif () if (PERNIX_USE_SIMDE) target_link_libraries(pernix PUBLIC simde::simde) target_compile_definitions(pernix PUBLIC PERNIX_USE_SIMDE=1) endif () +target_compile_definitions(pernix PUBLIC "PERNIX_BACKEND_${PERNIX_SELECTED_ARCH_BACKEND}=1") + +if (PERNIX_DISABLE_BMI2) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_BMI2=1) +endif () +if (PERNIX_DISABLE_AVX2) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_AVX2=1) +endif () +if (PERNIX_DISABLE_AVX512) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_AVX512=1) +endif () + set_target_properties(pernix PROPERTIES LINKER_LANGUAGE CXX) configure_file( @@ -45,23 +86,50 @@ configure_file( if (PERNIX_ENABLE_INSTALL) install(TARGETS pernix + EXPORT pernixTargets LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) - install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/pernix ${PROJECT_BINARY_DIR}/include/pernix + install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/pernix" DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} FILES_MATCHING PATTERN "*.*h" ) + if (PERNIX_USE_SIMDE AND PERNIX_BUNDLE_SIMDE_FOR_INSTALL) + install(DIRECTORY "${simde_SOURCE_DIR}/simde" + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + FILES_MATCHING PATTERN "*.h" + ) + endif () + + configure_package_config_file( + "${PROJECT_SOURCE_DIR}/cmake/pernixConfig.cmake.in" + "${PROJECT_BINARY_DIR}/pernixConfig.cmake" + INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) + write_basic_package_version_file( + "${PROJECT_BINARY_DIR}/pernixConfigVersion.cmake" + VERSION ${NORMALIZED_VERSION} + COMPATIBILITY SameMajorVersion + ) + install( + FILES + "${PROJECT_BINARY_DIR}/pernixConfig.cmake" + "${PROJECT_BINARY_DIR}/pernixConfigVersion.cmake" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) + install( + EXPORT pernixTargets + FILE pernixTargets.cmake + NAMESPACE pernix:: + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) install( FILES ${PROJECT_BINARY_DIR}/pernix.pc DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig" ) - - add_custom_target(uninstall COMMAND xargs rm -vf < ${PROJECT_BINARY_DIR}/install_manifest.txt) endif () if (PERNIX_ENABLE_DOXYGEN) @@ -87,7 +155,7 @@ if (PERNIX_ENABLE_DOXYGEN) WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} COMMENT "Generating documentation with Doxygen." ) - if (BENCHMARK_ENABLE_INSTALL AND BENCHMARK_INSTALL_DOCS) + if (PERNIX_ENABLE_INSTALL AND PERNIX_INSTALL_DOCS) install(DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/html/" DESTINATION ${CMAKE_INSTALL_DOCDIR}) endif () diff --git a/src/arm64/neon/compression.cpp b/src/arm64/neon/compression.cpp new file mode 100644 index 0000000..5968f79 --- /dev/null +++ b/src/arm64/neon/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int neon_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int neon_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int neon_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int neon_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/neon/decompression.cpp b/src/arm64/neon/decompression.cpp new file mode 100644 index 0000000..a89f763 --- /dev/null +++ b/src/arm64/neon/decompression.cpp @@ -0,0 +1,143 @@ +#include + +namespace pernix { +extern "C" { +#define PERNIX_NEON_DECOMPRESS_BLOCK_CASE(N) \ + case N: \ + return arm64::neon::neon_decompress_block(input, scale, output); + +#define PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(N) \ + case N: \ + return arm64::neon::neon_decompress_blocks(input, scale, output, blocks); + +int neon_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(24) + default: + return -1; + } +} + +int neon_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(24) + default: + return -1; + } +} + +int neon_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(24) + default: + return -1; + } +} + +int neon_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output, const uint32_t blocks) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(24) + default: + return -1; + } +} + +#undef PERNIX_NEON_DECOMPRESS_BLOCK_CASE +#undef PERNIX_NEON_DECOMPRESS_BLOCKS_CASE +} +} // namespace pernix diff --git a/src/arm64/sve/compression.cpp b/src/arm64/sve/compression.cpp new file mode 100644 index 0000000..e973183 --- /dev/null +++ b/src/arm64/sve/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int sve_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int sve_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int sve_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve/decompression.cpp b/src/arm64/sve/decompression.cpp new file mode 100644 index 0000000..c6d84be --- /dev/null +++ b/src/arm64/sve/decompression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { + return -1; +} + +int sve_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { + return -1; +} + +int sve_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { + return -1; +} + +int sve_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve2/compression.cpp b/src/arm64/sve2/compression.cpp new file mode 100644 index 0000000..0a55f16 --- /dev/null +++ b/src/arm64/sve2/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve2_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int sve2_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int sve2_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int sve2_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp new file mode 100644 index 0000000..8429e97 --- /dev/null +++ b/src/arm64/sve2/decompression.cpp @@ -0,0 +1,143 @@ +#include + +namespace pernix { +extern "C" { +#define PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(N) \ + case N: \ + return arm64::sve2::sve2_decompress_block(input, scale, output); + +#define PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(N) \ + case N: \ + return arm64::sve2::sve2_decompress_blocks(input, scale, output, blocks); + +int sve2_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(24) + default: + return -1; + } +} + +int sve2_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(24) + default: + return -1; + } +} + +int sve2_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(24) + default: + return -1; + } +} + +int sve2_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output, const uint32_t blocks) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(24) + default: + return -1; + } +} + +#undef PERNIX_SVE2_DECOMPRESS_BLOCK_CASE +#undef PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE +} +} // namespace pernix diff --git a/src/pernix.cpp b/src/pernix.cpp index 87ccf9d..94d9d14 100644 --- a/src/pernix.cpp +++ b/src/pernix.cpp @@ -6,7 +6,7 @@ extern "C" { #endif // Use the best available implementation based on detected CPU features at compile time -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #ifdef PERNIX_AVX512_VBMI_ENABLED int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { return mm512_compress_block_avx512vbmi(bit_width, input, scale, output); @@ -80,6 +80,114 @@ int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ i return mm256_decompress_blocks_f64_avx2(bit_width, input, scale, output, blocks); } #endif +#elif defined(PERNIX_BACKEND_ARM64_NEON) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return neon_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return neon_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return neon_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return neon_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return neon_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return neon_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return neon_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return sve_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return sve_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE2) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve2_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve2_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve2_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve2_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return sve2_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return sve2_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} #else int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { return compress_block_fallback(bit_width, input, scale, output); @@ -121,4 +229,4 @@ int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ i #ifdef __cplusplus } } // namespace pernix -#endif // __cplusplus \ No newline at end of file +#endif // __cplusplus diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 370ec3b..b5ffe57 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,12 +2,50 @@ find_package(PkgConfig) pkg_search_module(GTEST REQUIRED gtest) include(CheckCXXCompilerFlag) +file(GLOB + PERNIX_ROOT_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp +) + file(GLOB_RECURSE - SOURCE_FILES + PERNIX_FALLBACK_TEST_SOURCES CONFIGURE_DEPENDS - *.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fallback/*.cpp ) +set(SOURCE_FILES ${PERNIX_ROOT_TEST_SOURCES} ${PERNIX_FALLBACK_TEST_SOURCES}) + +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") + file(GLOB_RECURSE + PERNIX_X86_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/x86/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_X86_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") + file(GLOB_RECURSE + PERNIX_ARM64_NEON_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/neon/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_NEON_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") + file(GLOB_RECURSE + PERNIX_ARM64_SVE_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/sve/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_SVE_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") + file(GLOB_RECURSE + PERNIX_ARM64_SVE2_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/sve2/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_SVE2_TEST_SOURCES}) +endif () + file(GLOB_RECURSE HEADER_FILES CONFIGURE_DEPENDS diff --git a/tests/arm64/neon/decompression_tests.cpp b/tests/arm64/neon/decompression_tests.cpp new file mode 100644 index 0000000..69229be --- /dev/null +++ b/tests/arm64/neon/decompression_tests.cpp @@ -0,0 +1,38 @@ +#include +#include + +#ifdef PERNIX_BACKEND_ARM64_NEON + +using namespace pernix::arm64::neon; + +TYPED_TEST(DecompressionTest, NeonDecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + neon_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +TYPED_TEST(DecompressionTest64, NeonDecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + neon_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +#endif \ No newline at end of file diff --git a/tests/arm64/sve/.gitkeep b/tests/arm64/sve/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/arm64/sve2/decompression_tests.cpp b/tests/arm64/sve2/decompression_tests.cpp new file mode 100644 index 0000000..82cb11f --- /dev/null +++ b/tests/arm64/sve2/decompression_tests.cpp @@ -0,0 +1,38 @@ +#include +#include + +#ifdef PERNIX_BACKEND_ARM64_SVE2 + +using namespace pernix::arm64::sve2; + +TYPED_TEST(DecompressionTest, SVE2DecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + sve2_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +TYPED_TEST(DecompressionTest64, SVE2DecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + sve2_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +#endif \ No newline at end of file diff --git a/tests/compression/fallback_compression_tests.cpp b/tests/fallback/compression_tests.cpp similarity index 97% rename from tests/compression/fallback_compression_tests.cpp rename to tests/fallback/compression_tests.cpp index 9b50109..78249d7 100644 --- a/tests/compression/fallback_compression_tests.cpp +++ b/tests/fallback/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include TYPED_TEST(CompressionTest, FallbackCompressBlock) { diff --git a/tests/decompression/fallback_decompression_tests.cpp b/tests/fallback/decompression_tests.cpp similarity index 97% rename from tests/decompression/fallback_decompression_tests.cpp rename to tests/fallback/decompression_tests.cpp index 3c5dc1f..08c26c4 100644 --- a/tests/decompression/fallback_decompression_tests.cpp +++ b/tests/fallback/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { diff --git a/tests/fallback_edge_tests.cpp b/tests/fallback/edge_tests.cpp similarity index 100% rename from tests/fallback_edge_tests.cpp rename to tests/fallback/edge_tests.cpp diff --git a/tests/include/testset.h b/tests/include/testset.h index 6535957..5ef4fa1 100644 --- a/tests/include/testset.h +++ b/tests/include/testset.h @@ -1,7 +1,7 @@ #ifndef PERNIX_TESTSET_H #define PERNIX_TESTSET_H -#include <../../include/pernix/pernix.h> +#include #include #include diff --git a/tests/compression/avx2_compression_tests.cpp b/tests/x86/avx2/compression_tests.cpp similarity index 97% rename from tests/compression/avx2_compression_tests.cpp rename to tests/x86/avx2/compression_tests.cpp index 1c2892b..bd7f683 100644 --- a/tests/compression/avx2_compression_tests.cpp +++ b/tests/x86/avx2/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX2_ENABLED diff --git a/tests/decompression/avx2_decompression_tests.cpp b/tests/x86/avx2/decompression_tests.cpp similarity index 97% rename from tests/decompression/avx2_decompression_tests.cpp rename to tests/x86/avx2/decompression_tests.cpp index e0f039f..a6fc2c5 100644 --- a/tests/decompression/avx2_decompression_tests.cpp +++ b/tests/x86/avx2/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX2_ENABLED diff --git a/tests/compression/avx512vbmi_compression_tests.cpp b/tests/x86/avx512vbmi/compression_tests.cpp similarity index 100% rename from tests/compression/avx512vbmi_compression_tests.cpp rename to tests/x86/avx512vbmi/compression_tests.cpp diff --git a/tests/decompression/avx512vbmi_decompression_tests.cpp b/tests/x86/avx512vbmi/decompression_tests.cpp similarity index 97% rename from tests/decompression/avx512vbmi_decompression_tests.cpp rename to tests/x86/avx512vbmi/decompression_tests.cpp index 446443a..f44dd8d 100644 --- a/tests/decompression/avx512vbmi_decompression_tests.cpp +++ b/tests/x86/avx512vbmi/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX512_VBMI_ENABLED diff --git a/tests/compression/bmi2_compression_tests.cpp b/tests/x86/bmi2/compression_tests.cpp similarity index 97% rename from tests/compression/bmi2_compression_tests.cpp rename to tests/x86/bmi2/compression_tests.cpp index 85d3cac..b7fc2fd 100644 --- a/tests/compression/bmi2_compression_tests.cpp +++ b/tests/x86/bmi2/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_BMI2_ENABLED diff --git a/tests/decompression/bmi2_decompression_tests.cpp b/tests/x86/bmi2/decompression_tests.cpp similarity index 97% rename from tests/decompression/bmi2_decompression_tests.cpp rename to tests/x86/bmi2/decompression_tests.cpp index 11a8efb..dd7efc1 100644 --- a/tests/decompression/bmi2_decompression_tests.cpp +++ b/tests/x86/bmi2/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_BMI2_ENABLED