From bcb95b01124422ad818c70a7e3901b3e284f84ec Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 14:06:33 +0200 Subject: [PATCH 01/12] Add basic ARM64 architecture support for NEON, SVE, and SVE2. Just a skeleton, no implementations yet. --- CMakeLists.txt | 10 +- external/simde | 2 +- include/pernix/arm64/neon/compression.h | 59 +++++++ include/pernix/arm64/neon/decompression.h | 59 +++++++ include/pernix/arm64/neon/packing.h | 11 ++ include/pernix/arm64/neon/unpacking.h | 20 +++ include/pernix/arm64/sve/compression.h | 59 +++++++ include/pernix/arm64/sve/decompression.h | 59 +++++++ include/pernix/arm64/sve/packing.h | 11 ++ include/pernix/arm64/sve/unpacking.h | 11 ++ include/pernix/arm64/sve2/compression.h | 59 +++++++ include/pernix/arm64/sve2/decompression.h | 59 +++++++ include/pernix/arm64/sve2/packing.h | 11 ++ include/pernix/arm64/sve2/unpacking.h | 11 ++ include/pernix/detection.h | 19 +- include/pernix/pernix.h | 167 +++++++++++++++++- include/pernix/simd_compat.h | 8 + src/CMakeLists.txt | 48 ++++- src/arm64/neon/compression.cpp | 21 +++ src/arm64/neon/decompression.cpp | 21 +++ src/arm64/sve/compression.cpp | 21 +++ src/arm64/sve/decompression.cpp | 21 +++ src/arm64/sve2/compression.cpp | 21 +++ src/arm64/sve2/decompression.cpp | 21 +++ src/pernix.cpp | 112 +++++++++++- tests/arm64/neon/.gitkeep | 0 tests/arm64/sve/.gitkeep | 0 tests/arm64/sve2/.gitkeep | 0 .../compression_tests.cpp} | 2 +- .../decompression_tests.cpp} | 2 +- .../edge_tests.cpp} | 0 tests/include/testset.h | 2 +- .../avx2/compression_tests.cpp} | 2 +- .../avx2/decompression_tests.cpp} | 2 +- .../avx512vbmi/compression_tests.cpp} | 0 .../avx512vbmi/decompression_tests.cpp} | 2 +- .../bmi2/compression_tests.cpp} | 2 +- .../bmi2/decompression_tests.cpp} | 2 +- 38 files changed, 916 insertions(+), 21 deletions(-) create mode 100644 include/pernix/arm64/neon/compression.h create mode 100644 include/pernix/arm64/neon/decompression.h create mode 100644 include/pernix/arm64/neon/packing.h create mode 100644 include/pernix/arm64/neon/unpacking.h create mode 100644 include/pernix/arm64/sve/compression.h create mode 100644 include/pernix/arm64/sve/decompression.h create mode 100644 include/pernix/arm64/sve/packing.h create mode 100644 include/pernix/arm64/sve/unpacking.h create mode 100644 include/pernix/arm64/sve2/compression.h create mode 100644 include/pernix/arm64/sve2/decompression.h create mode 100644 include/pernix/arm64/sve2/packing.h create mode 100644 include/pernix/arm64/sve2/unpacking.h create mode 100644 src/arm64/neon/compression.cpp create mode 100644 src/arm64/neon/decompression.cpp create mode 100644 src/arm64/sve/compression.cpp create mode 100644 src/arm64/sve/decompression.cpp create mode 100644 src/arm64/sve2/compression.cpp create mode 100644 src/arm64/sve2/decompression.cpp create mode 100644 tests/arm64/neon/.gitkeep create mode 100644 tests/arm64/sve/.gitkeep create mode 100644 tests/arm64/sve2/.gitkeep rename tests/{compression/fallback_compression_tests.cpp => fallback/compression_tests.cpp} (97%) rename tests/{decompression/fallback_decompression_tests.cpp => fallback/decompression_tests.cpp} (97%) rename tests/{fallback_edge_tests.cpp => fallback/edge_tests.cpp} (100%) rename tests/{compression/avx2_compression_tests.cpp => x86/avx2/compression_tests.cpp} (97%) rename tests/{decompression/avx2_decompression_tests.cpp => x86/avx2/decompression_tests.cpp} (97%) rename tests/{compression/avx512vbmi_compression_tests.cpp => x86/avx512vbmi/compression_tests.cpp} (100%) rename tests/{decompression/avx512vbmi_decompression_tests.cpp => x86/avx512vbmi/decompression_tests.cpp} (97%) rename tests/{compression/bmi2_compression_tests.cpp => x86/bmi2/compression_tests.cpp} (97%) rename tests/{decompression/bmi2_decompression_tests.cpp => x86/bmi2/decompression_tests.cpp} (97%) diff --git a/CMakeLists.txt b/CMakeLists.txt index edf6c78..3fa49be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,9 +13,17 @@ option(PERNIX_DISABLE_AVX2 "Disable AVX2 optimizations" off) option(PERNIX_DISABLE_AVX512 "Disable AVX512 optimizations" off) option(PERNIX_USE_SIMDE "Use SIMDe library for portable SIMD support" off) +set(PERNIX_ARCH_BACKEND "AUTO" CACHE STRING "Pernix architecture backend (AUTO, FALLBACK, X86, ARM64_NEON, ARM64_SVE, ARM64_SVE2)") +set_property(CACHE PERNIX_ARCH_BACKEND PROPERTY STRINGS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) option(PERNIX_ENABLE_FORTRAN_BINDINGS "Build Fortran bindings for pernix" off) +string(TOUPPER "${PERNIX_ARCH_BACKEND}" PERNIX_ARCH_BACKEND) +set(PERNIX_VALID_ARCH_BACKENDS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) +if (NOT PERNIX_ARCH_BACKEND IN_LIST PERNIX_VALID_ARCH_BACKENDS) + message(FATAL_ERROR "Unsupported PERNIX_ARCH_BACKEND='${PERNIX_ARCH_BACKEND}'. Expected one of: ${PERNIX_VALID_ARCH_BACKENDS}") +endif () + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") if (PERNIX_USE_SIMDE) @@ -97,4 +105,4 @@ endif () if (PERNIX_ENABLE_TESTS) enable_testing() add_subdirectory(tests) -endif () \ No newline at end of file +endif () diff --git a/external/simde b/external/simde index 1747b24..1a1ca5e 160000 --- a/external/simde +++ b/external/simde @@ -1 +1 @@ -Subproject commit 1747b2482589fe894d49989159421da08c2a8bcd +Subproject commit 1a1ca5ee71518d8a115234dad1e2d871421953b7 diff --git a/include/pernix/arm64/neon/compression.h b/include/pernix/arm64/neon/compression.h new file mode 100644 index 0000000..65ea786 --- /dev/null +++ b/include/pernix/arm64/neon/compression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_NEON_COMPRESSION_H +#define PERNIX_ARM64_NEON_COMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool neon_compression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_block(const float_t*, float_t, uint8_t*) { + static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_block(const double_t*, double_t, uint8_t*) { + static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { + static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { + static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int neon_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); +int neon_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); +int neon_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, + uint32_t blocks); +int neon_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_NEON_COMPRESSION_H diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h new file mode 100644 index 0000000..0f8d79e --- /dev/null +++ b/include/pernix/arm64/neon/decompression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_NEON_DECOMPRESSION_H +#define PERNIX_ARM64_NEON_DECOMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool neon_decompression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_block(const uint8_t*, float_t, float_t*) { + static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_block(const uint8_t*, double_t, double_t*) { + static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { + static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int neon_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { + static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int neon_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); +int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); +int neon_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); +int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/neon/packing.h b/include/pernix/arm64/neon/packing.h new file mode 100644 index 0000000..c1c8119 --- /dev/null +++ b/include/pernix/arm64/neon/packing.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_NEON_PACKING_H +#define PERNIX_ARM64_NEON_PACKING_H + +#include + +namespace pernix::arm64::neon::internal { +template +inline constexpr bool packing_unimplemented_v = false; +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_PACKING_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h new file mode 100644 index 0000000..32dc5de --- /dev/null +++ b/include/pernix/arm64/neon/unpacking.h @@ -0,0 +1,20 @@ +#ifndef PERNIX_ARM64_NEON_UNPACKING_H +#define PERNIX_ARM64_NEON_UNPACKING_H + +#include + +namespace pernix::arm64::neon::internal { + +template +inline constexpr bool unpacking_unimplemented_v = false; + +namespace b64 { + +} // namespace b64 + + +} // namespace pernix::arm64::neon::internal + + + +#endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve/compression.h b/include/pernix/arm64/sve/compression.h new file mode 100644 index 0000000..f8abfed --- /dev/null +++ b/include/pernix/arm64/sve/compression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE_COMPRESSION_H +#define PERNIX_ARM64_SVE_COMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve_compression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_block(const float_t*, float_t, uint8_t*) { + static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_block(const double_t*, double_t, uint8_t*) { + static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { + static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { + static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); +int sve_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); +int sve_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, + uint32_t blocks); +int sve_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE_COMPRESSION_H diff --git a/include/pernix/arm64/sve/decompression.h b/include/pernix/arm64/sve/decompression.h new file mode 100644 index 0000000..784c0c8 --- /dev/null +++ b/include/pernix/arm64/sve/decompression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE_DECOMPRESSION_H +#define PERNIX_ARM64_SVE_DECOMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve_decompression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_block(const uint8_t*, float_t, float_t*) { + static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_block(const uint8_t*, double_t, double_t*) { + static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { + static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { + static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); +int sve_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); +int sve_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); +int sve_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve/packing.h b/include/pernix/arm64/sve/packing.h new file mode 100644 index 0000000..fce21ca --- /dev/null +++ b/include/pernix/arm64/sve/packing.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE_PACKING_H +#define PERNIX_ARM64_SVE_PACKING_H + +#include + +namespace pernix::arm64::sve::internal { +template +inline constexpr bool packing_unimplemented_v = false; +} // namespace pernix::arm64::sve::internal + +#endif // PERNIX_ARM64_SVE_PACKING_H diff --git a/include/pernix/arm64/sve/unpacking.h b/include/pernix/arm64/sve/unpacking.h new file mode 100644 index 0000000..3de49ca --- /dev/null +++ b/include/pernix/arm64/sve/unpacking.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE_UNPACKING_H +#define PERNIX_ARM64_SVE_UNPACKING_H + +#include + +namespace pernix::arm64::sve::internal { +template +inline constexpr bool unpacking_unimplemented_v = false; +} // namespace pernix::arm64::sve::internal + +#endif // PERNIX_ARM64_SVE_UNPACKING_H diff --git a/include/pernix/arm64/sve2/compression.h b/include/pernix/arm64/sve2/compression.h new file mode 100644 index 0000000..4e4627d --- /dev/null +++ b/include/pernix/arm64/sve2/compression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE2_COMPRESSION_H +#define PERNIX_ARM64_SVE2_COMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve2_compression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_block(const float_t*, float_t, uint8_t*) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_block(const double_t*, double_t, uint8_t*) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { + static_assert(internal::sve2_compression_unimplemented_v, "ARM64 SVE2 compression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve2_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); +int sve2_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); +int sve2_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, + uint32_t blocks); +int sve2_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE2_COMPRESSION_H diff --git a/include/pernix/arm64/sve2/decompression.h b/include/pernix/arm64/sve2/decompression.h new file mode 100644 index 0000000..c27f08a --- /dev/null +++ b/include/pernix/arm64/sve2/decompression.h @@ -0,0 +1,59 @@ +#ifndef PERNIX_ARM64_SVE2_DECOMPRESSION_H +#define PERNIX_ARM64_SVE2_DECOMPRESSION_H + +#include + +#include +#include + +namespace pernix { +namespace internal { +template +inline constexpr bool sve2_decompression_unimplemented_v = false; +} // namespace internal + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_block(const uint8_t*, float_t, float_t*) { + static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_block(const uint8_t*, double_t, double_t*) { + static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { + static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { + static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + return -1; +} + +#ifdef __cplusplus +extern "C" { +#endif + +int sve2_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); +int sve2_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); +int sve2_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, + uint32_t blocks); +int sve2_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); + +#ifdef __cplusplus +} +#endif +} // namespace pernix + +#endif // PERNIX_ARM64_SVE2_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve2/packing.h b/include/pernix/arm64/sve2/packing.h new file mode 100644 index 0000000..789b4d7 --- /dev/null +++ b/include/pernix/arm64/sve2/packing.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE2_PACKING_H +#define PERNIX_ARM64_SVE2_PACKING_H + +#include + +namespace pernix::arm64::sve2::internal { +template +inline constexpr bool packing_unimplemented_v = false; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_PACKING_H diff --git a/include/pernix/arm64/sve2/unpacking.h b/include/pernix/arm64/sve2/unpacking.h new file mode 100644 index 0000000..d654b5e --- /dev/null +++ b/include/pernix/arm64/sve2/unpacking.h @@ -0,0 +1,11 @@ +#ifndef PERNIX_ARM64_SVE2_UNPACKING_H +#define PERNIX_ARM64_SVE2_UNPACKING_H + +#include + +namespace pernix::arm64::sve2::internal { +template +inline constexpr bool unpacking_unimplemented_v = false; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/include/pernix/detection.h b/include/pernix/detection.h index edecb6c..fa9cd44 100644 --- a/include/pernix/detection.h +++ b/include/pernix/detection.h @@ -10,6 +10,19 @@ #define PERNIX_MACHINE_ID_V4 3 #define PERNIX_MACHINE_ID_V4_VBMI 4 +#if defined(PERNIX_BACKEND_ARM64_NEON) +#define PERNIX_ARM64_NEON_ENABLED +#endif + +#if defined(PERNIX_BACKEND_ARM64_SVE) +#define PERNIX_ARM64_SVE_ENABLED +#endif + +#if defined(PERNIX_BACKEND_ARM64_SVE2) +#define PERNIX_ARM64_SVE2_ENABLED +#endif + +#if defined(PERNIX_BACKEND_X86) // Map the compiler's enabled ISA set to the highest supported Pernix target level. #if (__SSE3__ && __SSE4_1__ && __SSE4_2__) #if (__AVX__ && __AVX2__ && __FMA__ && __BMI__ && __BMI2__) @@ -32,6 +45,10 @@ #define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC #endif +#else +#define PERNIX_MACHINE_ID PERNIX_MACHINE_ID_GENERIC +#endif + // Feature-selection macros consumed by the public headers. #if (PERNIX_MACHINE_ID >= PERNIX_MACHINE_ID_V2) #define PERNIX_SSE_ENABLED @@ -47,7 +64,7 @@ #define PERNIX_AVX512_VBMI_ENABLED #endif -#ifdef PERNIX_USE_SIMDE +#if defined(PERNIX_USE_SIMDE) && defined(PERNIX_BACKEND_X86) #define PERNIX_SSE_ENABLED #define PERNIX_AVX2_ENABLED #define PERNIX_BMI2_ENABLED diff --git a/include/pernix/pernix.h b/include/pernix/pernix.h index 55133bb..4998a60 100644 --- a/include/pernix/pernix.h +++ b/include/pernix/pernix.h @@ -5,7 +5,7 @@ // Include architecture-specific headers based on detected capabilities // AVX2 -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #include #include @@ -21,7 +21,22 @@ #include #endif // PERNIX_AVX512_VBMI_ENABLED -#endif // PERNIX_AVX2_ENABLED +#endif // PERNIX_BACKEND_X86 && PERNIX_AVX2_ENABLED + +#ifdef PERNIX_BACKEND_ARM64_NEON +#include +#include +#endif + +#ifdef PERNIX_BACKEND_ARM64_SVE +#include +#include +#endif + +#ifdef PERNIX_BACKEND_ARM64_SVE2 +#include +#include +#endif // Fallback (non-SIMD) implementations #include @@ -167,7 +182,7 @@ template int decompress_blocks(const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, uint32_t blocks); // Use the best available implementation based on detected CPU features at compile time. -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #ifdef PERNIX_AVX512_VBMI_ENABLED template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) @@ -265,6 +280,150 @@ int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, d return mm256_decompress_blocks_avx2(input, scale, output, blocks); } #endif +#elif defined(PERNIX_BACKEND_ARM64_NEON) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return neon_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return neon_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return neon_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return neon_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return neon_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return neon_decompress_blocks(input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return sve_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return sve_decompress_blocks(input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE2) +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_block(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve2_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int compress_blocks(const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, const uint32_t blocks) { + return sve2_compress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve2_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve2_decompress_block(input, scale, output); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + return sve2_decompress_blocks(input, scale, output, blocks); +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + return sve2_decompress_blocks(input, scale, output, blocks); +} #else template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) @@ -420,4 +579,4 @@ int decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, } // namespace pernix #endif -#endif // PERNIX_H \ No newline at end of file +#endif // PERNIX_H diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index c95c1ee..8a66ee7 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -8,9 +8,15 @@ #define SIMDE_ENABLE_NATIVE_ALIASES #undef SIMDE_X86_AVX512FP16_NATIVE // #define SIMDE_NO_NATIVE +#if defined(PERNIX_BACKEND_X86) #include #include #include +#elif defined(PERNIX_BACKEND_ARM64_NEON) +#include +#elif defined(PERNIX_BACKEND_ARM64_SVE) || defined(PERNIX_BACKEND_ARM64_SVE2) +#include +#endif // #ifndef __mmask8 // typedef uint8_t __mmask8; @@ -27,6 +33,8 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #include +#elif defined(__aarch64__) +#include #endif #ifndef __always_inline diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b7e8f23..f0c7e3e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,18 +9,46 @@ file(GLOB_RECURSE set(PERNIX_SOURCES ${PERNIX_COMMON_SOURCES}) -set(PERNIX_TARGET_IS_X86 OFF) -if (PERNIX_USE_SIMDE OR CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") - set(PERNIX_TARGET_IS_X86 ON) +set(PERNIX_SELECTED_ARCH_BACKEND "${PERNIX_ARCH_BACKEND}") +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "AUTO") + if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") + set(PERNIX_SELECTED_ARCH_BACKEND "X86") + elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(PERNIX_SELECTED_ARCH_BACKEND "ARM64_NEON") + else () + set(PERNIX_SELECTED_ARCH_BACKEND "FALLBACK") + endif () endif () +message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") -if (PERNIX_TARGET_IS_X86) +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") file(GLOB_RECURSE PERNIX_X86_SOURCES ./x86/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/x86/*.h ) list(APPEND PERNIX_SOURCES ${PERNIX_X86_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") + file(GLOB_RECURSE + PERNIX_ARM64_NEON_SOURCES + ./arm64/neon/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/neon/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_NEON_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") + file(GLOB_RECURSE + PERNIX_ARM64_SVE_SOURCES + ./arm64/sve/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") + file(GLOB_RECURSE + PERNIX_ARM64_SVE2_SOURCES + ./arm64/sve2/*.cpp + ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve2/*.h + ) + list(APPEND PERNIX_SOURCES ${PERNIX_ARM64_SVE2_SOURCES}) endif () add_library(pernix SHARED ${PERNIX_SOURCES}) @@ -37,6 +65,18 @@ if (PERNIX_USE_SIMDE) target_compile_definitions(pernix PUBLIC PERNIX_USE_SIMDE=1) endif () +target_compile_definitions(pernix PUBLIC "PERNIX_BACKEND_${PERNIX_SELECTED_ARCH_BACKEND}=1") + +if (PERNIX_DISABLE_BMI2) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_BMI2=1) +endif () +if (PERNIX_DISABLE_AVX2) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_AVX2=1) +endif () +if (PERNIX_DISABLE_AVX512) + target_compile_definitions(pernix PUBLIC PERNIX_DISABLE_AVX512=1) +endif () + set_target_properties(pernix PROPERTIES LINKER_LANGUAGE CXX) configure_file( diff --git a/src/arm64/neon/compression.cpp b/src/arm64/neon/compression.cpp new file mode 100644 index 0000000..5968f79 --- /dev/null +++ b/src/arm64/neon/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int neon_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int neon_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int neon_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int neon_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/neon/decompression.cpp b/src/arm64/neon/decompression.cpp new file mode 100644 index 0000000..3cb7fc7 --- /dev/null +++ b/src/arm64/neon/decompression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int neon_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { + return -1; +} + +int neon_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { + return -1; +} + +int neon_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { + return -1; +} + +int neon_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve/compression.cpp b/src/arm64/sve/compression.cpp new file mode 100644 index 0000000..e973183 --- /dev/null +++ b/src/arm64/sve/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int sve_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int sve_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int sve_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve/decompression.cpp b/src/arm64/sve/decompression.cpp new file mode 100644 index 0000000..c6d84be --- /dev/null +++ b/src/arm64/sve/decompression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { + return -1; +} + +int sve_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { + return -1; +} + +int sve_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { + return -1; +} + +int sve_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve2/compression.cpp b/src/arm64/sve2/compression.cpp new file mode 100644 index 0000000..0a55f16 --- /dev/null +++ b/src/arm64/sve2/compression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve2_compress_block(uint8_t, const float_t*, float_t, uint8_t*) { + return -1; +} + +int sve2_compress_block_f64(uint8_t, const double_t*, double_t, uint8_t*) { + return -1; +} + +int sve2_compress_blocks(uint8_t, const float_t*, float_t, uint8_t*, uint32_t) { + return -1; +} + +int sve2_compress_blocks_f64(uint8_t, const double_t*, double_t, uint8_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp new file mode 100644 index 0000000..8d170f6 --- /dev/null +++ b/src/arm64/sve2/decompression.cpp @@ -0,0 +1,21 @@ +#include + +namespace pernix { +extern "C" { +int sve2_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { + return -1; +} + +int sve2_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { + return -1; +} + +int sve2_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { + return -1; +} + +int sve2_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { + return -1; +} +} +} // namespace pernix diff --git a/src/pernix.cpp b/src/pernix.cpp index 87ccf9d..94d9d14 100644 --- a/src/pernix.cpp +++ b/src/pernix.cpp @@ -6,7 +6,7 @@ extern "C" { #endif // Use the best available implementation based on detected CPU features at compile time -#ifdef PERNIX_AVX2_ENABLED +#if defined(PERNIX_BACKEND_X86) && defined(PERNIX_AVX2_ENABLED) #ifdef PERNIX_AVX512_VBMI_ENABLED int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { return mm512_compress_block_avx512vbmi(bit_width, input, scale, output); @@ -80,6 +80,114 @@ int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ i return mm256_decompress_blocks_f64_avx2(bit_width, input, scale, output, blocks); } #endif +#elif defined(PERNIX_BACKEND_ARM64_NEON) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return neon_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return neon_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return neon_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return neon_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return neon_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return neon_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return neon_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return neon_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return sve_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return sve_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} +#elif defined(PERNIX_BACKEND_ARM64_SVE2) +int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block(bit_width, input, scale, output); +} + +int compress_block_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output) { + return sve2_compress_block_f64(bit_width, input, scale, output); +} + +int compress_blocks(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve2_compress_blocks(bit_width, input, scale, output, blocks); +} + +int compress_blocks_f64(const uint8_t bit_width, const double_t* __restrict__ input, const double_t scale, uint8_t* __restrict__ output, + const uint32_t blocks) { + return sve2_compress_blocks_f64(bit_width, input, scale, output, blocks); +} + +int decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + return sve2_decompress_block(bit_width, input, scale, output); +} + +int decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + return sve2_decompress_block_f64(bit_width, input, scale, output); +} + +int decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + return sve2_decompress_blocks(bit_width, input, scale, output, blocks); +} + +int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + return sve2_decompress_blocks_f64(bit_width, input, scale, output, blocks); +} #else int compress_block(const uint8_t bit_width, const float_t* __restrict__ input, const float_t scale, uint8_t* __restrict__ output) { return compress_block_fallback(bit_width, input, scale, output); @@ -121,4 +229,4 @@ int decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ i #ifdef __cplusplus } } // namespace pernix -#endif // __cplusplus \ No newline at end of file +#endif // __cplusplus diff --git a/tests/arm64/neon/.gitkeep b/tests/arm64/neon/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/arm64/sve/.gitkeep b/tests/arm64/sve/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/arm64/sve2/.gitkeep b/tests/arm64/sve2/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/compression/fallback_compression_tests.cpp b/tests/fallback/compression_tests.cpp similarity index 97% rename from tests/compression/fallback_compression_tests.cpp rename to tests/fallback/compression_tests.cpp index 9b50109..78249d7 100644 --- a/tests/compression/fallback_compression_tests.cpp +++ b/tests/fallback/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include TYPED_TEST(CompressionTest, FallbackCompressBlock) { diff --git a/tests/decompression/fallback_decompression_tests.cpp b/tests/fallback/decompression_tests.cpp similarity index 97% rename from tests/decompression/fallback_decompression_tests.cpp rename to tests/fallback/decompression_tests.cpp index 3c5dc1f..08c26c4 100644 --- a/tests/decompression/fallback_decompression_tests.cpp +++ b/tests/fallback/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include TYPED_TEST(DecompressionTest, FallbackDecompressBlock) { diff --git a/tests/fallback_edge_tests.cpp b/tests/fallback/edge_tests.cpp similarity index 100% rename from tests/fallback_edge_tests.cpp rename to tests/fallback/edge_tests.cpp diff --git a/tests/include/testset.h b/tests/include/testset.h index 6535957..5ef4fa1 100644 --- a/tests/include/testset.h +++ b/tests/include/testset.h @@ -1,7 +1,7 @@ #ifndef PERNIX_TESTSET_H #define PERNIX_TESTSET_H -#include <../../include/pernix/pernix.h> +#include #include #include diff --git a/tests/compression/avx2_compression_tests.cpp b/tests/x86/avx2/compression_tests.cpp similarity index 97% rename from tests/compression/avx2_compression_tests.cpp rename to tests/x86/avx2/compression_tests.cpp index 1c2892b..bd7f683 100644 --- a/tests/compression/avx2_compression_tests.cpp +++ b/tests/x86/avx2/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX2_ENABLED diff --git a/tests/decompression/avx2_decompression_tests.cpp b/tests/x86/avx2/decompression_tests.cpp similarity index 97% rename from tests/decompression/avx2_decompression_tests.cpp rename to tests/x86/avx2/decompression_tests.cpp index e0f039f..a6fc2c5 100644 --- a/tests/decompression/avx2_decompression_tests.cpp +++ b/tests/x86/avx2/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX2_ENABLED diff --git a/tests/compression/avx512vbmi_compression_tests.cpp b/tests/x86/avx512vbmi/compression_tests.cpp similarity index 100% rename from tests/compression/avx512vbmi_compression_tests.cpp rename to tests/x86/avx512vbmi/compression_tests.cpp diff --git a/tests/decompression/avx512vbmi_decompression_tests.cpp b/tests/x86/avx512vbmi/decompression_tests.cpp similarity index 97% rename from tests/decompression/avx512vbmi_decompression_tests.cpp rename to tests/x86/avx512vbmi/decompression_tests.cpp index 446443a..f44dd8d 100644 --- a/tests/decompression/avx512vbmi_decompression_tests.cpp +++ b/tests/x86/avx512vbmi/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_AVX512_VBMI_ENABLED diff --git a/tests/compression/bmi2_compression_tests.cpp b/tests/x86/bmi2/compression_tests.cpp similarity index 97% rename from tests/compression/bmi2_compression_tests.cpp rename to tests/x86/bmi2/compression_tests.cpp index 85d3cac..b7fc2fd 100644 --- a/tests/compression/bmi2_compression_tests.cpp +++ b/tests/x86/bmi2/compression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_BMI2_ENABLED diff --git a/tests/decompression/bmi2_decompression_tests.cpp b/tests/x86/bmi2/decompression_tests.cpp similarity index 97% rename from tests/decompression/bmi2_decompression_tests.cpp rename to tests/x86/bmi2/decompression_tests.cpp index 11a8efb..dd7efc1 100644 --- a/tests/decompression/bmi2_decompression_tests.cpp +++ b/tests/x86/bmi2/decompression_tests.cpp @@ -1,4 +1,4 @@ -#include <../../include/pernix/pernix.h> +#include #include #ifdef PERNIX_BMI2_ENABLED From 22f156223ebf5cbaee7d974e3518cc576bdaf958 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 17:09:03 +0200 Subject: [PATCH 02/12] Refactor AVX512 compression and decompression headers: improve include organization, address minor formatting inconsistencies, and move utility functions to a shared `utils` file for better maintainability. --- include/pernix/simd_compat.h | 23 +-------- include/pernix/x86/avx512vbmi/compression.h | 51 ++++++++++--------- include/pernix/x86/avx512vbmi/decompression.h | 49 +++++++++--------- include/pernix/x86/utils.h | 16 ++++++ 4 files changed, 70 insertions(+), 69 deletions(-) create mode 100644 include/pernix/x86/utils.h diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index 8a66ee7..e96ec30 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -10,8 +10,8 @@ // #define SIMDE_NO_NATIVE #if defined(PERNIX_BACKEND_X86) #include -#include #include +#include #elif defined(PERNIX_BACKEND_ARM64_NEON) #include #elif defined(PERNIX_BACKEND_ARM64_SVE) || defined(PERNIX_BACKEND_ARM64_SVE2) @@ -47,25 +47,4 @@ #endif #endif -template - requires(std::is_integral_v && sizeof(T) <= 8) -static constexpr T tail_mask(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - if (tail_bytes == 0u) { - return static_cast(0); - } - if (tail_bytes >= 64u) { - return static_cast(~uint64_t{0}); - } - const uint64_t mask = (uint64_t{1} << tail_bytes) - 1u; - return static_cast(mask); -} - -static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { - const uint32_t tail_bits = remaining_elements * bit_width; - const uint32_t tail_bytes = (tail_bits + 7u) / 8u; - return tail_bytes; -} - #endif // PERNIX_SIMD_COMPAT_H diff --git a/include/pernix/x86/avx512vbmi/compression.h b/include/pernix/x86/avx512vbmi/compression.h index bc5a375..e1cb4f0 100644 --- a/include/pernix/x86/avx512vbmi/compression.h +++ b/include/pernix/x86/avx512vbmi/compression.h @@ -1,13 +1,16 @@ #ifndef PERNIX_AVX512VBMI_COMPRESSION_H #define PERNIX_AVX512VBMI_COMPRESSION_H +#include #include -#include #include -#include +#include +#include #include +using namespace pernix::x86::internal; + namespace pernix { namespace internal { template @@ -132,7 +135,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - input += 64; + input += 64; output += 8 * BIT_WIDTH; } } @@ -150,7 +153,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri const __m256i packed = m256::mm256_pack_epi8_avx512vbmi_1to8(make_m256i_from_2x128(converted1, converted2)); mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } @@ -162,7 +165,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(converted); mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -207,7 +210,7 @@ template const __m512i packed = m512::mm512_pack_epi16_avx512vbmi_9to16(make_m512i_from_2x256(converted1, converted2)); mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } } @@ -220,7 +223,7 @@ template const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16(converted); mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -232,7 +235,7 @@ template const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -269,7 +272,7 @@ template const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(packed_input); mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } } @@ -281,14 +284,14 @@ template const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } if constexpr (remaining_elements > 0) { const __m256 source = mm256_loadu_elements_ps(remaining_elements, input); const __m256i packed_input = mm256_clamp_signed_epi32_avx512(mm256_quantize_ps_epi32(source, scale_v256)); - const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); + const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); mm256_storeu_elements_epi8(output, tail_bytes(BIT_WIDTH, remaining_elements), packed); } @@ -344,7 +347,7 @@ template mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - input += 64; + input += 64; output += 8 * BIT_WIDTH; } } @@ -370,7 +373,7 @@ template mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } @@ -387,7 +390,7 @@ template mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -447,7 +450,7 @@ template mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } } @@ -465,7 +468,7 @@ template mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -477,7 +480,7 @@ template const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -517,7 +520,7 @@ template const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(make_m512i_from_2x256(quantized1, quantized2)); mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } } @@ -529,7 +532,7 @@ template const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(quantized); mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -543,7 +546,7 @@ template return 0; } -} // namespace internal +} // namespace internal /** * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. @@ -618,7 +621,7 @@ int mm512_compress_blocks_avx512vbmi(const float_t* __restrict__ input, const fl for (uint32_t block = 0; block < blocks; ++block) { mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; block_output += BLOCK_SIZE; } @@ -644,13 +647,13 @@ int mm512_compress_blocks_avx512vbmi(const double_t* __restrict__ input, const d for (uint32_t block = 0; block < blocks; ++block) { mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; block_output += BLOCK_SIZE; } return 0; } -} // namespace pernix +} // namespace pernix #ifdef __cplusplus namespace pernix { @@ -716,7 +719,7 @@ int mm512_compress_blocks_f64_avx512vbmi(uint8_t bit_width, const double_t* __re #ifdef __cplusplus } -} // namespace pernix +} // namespace pernix #endif #endif // PERNIX_AVX512VBMI_COMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/decompression.h b/include/pernix/x86/avx512vbmi/decompression.h index 61280c9..6320240 100644 --- a/include/pernix/x86/avx512vbmi/decompression.h +++ b/include/pernix/x86/avx512vbmi/decompression.h @@ -1,13 +1,16 @@ #ifndef PERNIX_AVX512VBMI_DECOMPRESSION_H #define PERNIX_AVX512VBMI_DECOMPRESSION_H +#include #include -#include #include -#include +#include +#include #include +using namespace pernix::x86::internal; + namespace pernix { namespace internal { /** @@ -58,7 +61,7 @@ template _mm512_storeu_ps(output + 48, dequantized4); output += 64; - input += 8 * BIT_WIDTH; + input += 8 * BIT_WIDTH; } } @@ -76,7 +79,7 @@ template _mm512_storeu_ps(output + 16, dequantized2); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } if constexpr (iterations_16 > 0) { @@ -90,7 +93,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -159,7 +162,7 @@ template _mm512_storeu_pd(output + 56, dequantized8); output += 64; - input += 8 * BIT_WIDTH; + input += 8 * BIT_WIDTH; } if constexpr (iterations_32 > 0) { @@ -185,7 +188,7 @@ template _mm512_storeu_pd(output + 24, dequantized4); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } if constexpr (iterations_16 > 0) { @@ -202,7 +205,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -255,7 +258,7 @@ template _mm512_storeu_ps(output + 16, dequantized2); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } } @@ -269,7 +272,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (iterations_8 > 0) { @@ -282,7 +285,7 @@ template _mm256_storeu_ps(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -333,7 +336,7 @@ template _mm512_storeu_pd(output + 24, dequantized4); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } } @@ -351,7 +354,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (iterations_8 > 0) { @@ -365,7 +368,7 @@ template _mm512_storeu_pd(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -406,7 +409,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } } @@ -419,7 +422,7 @@ template _mm256_storeu_ps(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -462,7 +465,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } } @@ -477,7 +480,7 @@ template _mm512_storeu_pd(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -493,7 +496,7 @@ template return 0; } -} // namespace internal +} // namespace internal /** * @brief Decompress a single 512\-bit block using AVX-512 and AVX-512-VBMI instructions. @@ -569,7 +572,7 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const for (uint32_t block = 0; block < blocks; ++block) { mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } @@ -596,12 +599,12 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const for (uint32_t block = 0; block < blocks; ++block) { mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; } -} // namespace pernix +} // namespace pernix #ifdef __cplusplus namespace pernix { @@ -666,7 +669,7 @@ int mm512_decompress_blocks_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __r #ifdef __cplusplus } -} // namespace pernix +} // namespace pernix #endif #endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/include/pernix/x86/utils.h b/include/pernix/x86/utils.h new file mode 100644 index 0000000..185e0a2 --- /dev/null +++ b/include/pernix/x86/utils.h @@ -0,0 +1,16 @@ +#ifndef PERNIX_X86_UTILS_H +#define PERNIX_X86_UTILS_H + +#include + +namespace pernix::x86::internal { + +static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { + const uint32_t tail_bits = remaining_elements * bit_width; + const uint32_t tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; +} + +} // namespace pernix::x86::internal + +#endif // PERNIX_X86_UTILS_H From 37b3a725576985aa6db47107cddb7c539fc67c13 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 17:26:44 +0200 Subject: [PATCH 03/12] Replace SIMDe submodule with FetchContent for streamlined dependency management. --- .gitmodules | 3 --- CMakeLists.txt | 16 +++++++++++++++- external/simde | 1 - 3 files changed, 15 insertions(+), 5 deletions(-) delete mode 100644 .gitmodules delete mode 160000 external/simde diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index ae538ec..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "external/simde"] - path = external/simde - url = https://github.com/simd-everywhere/simde diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fa49be..8394abf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,21 @@ endif () list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") if (PERNIX_USE_SIMDE) - add_subdirectory(external/simde EXCLUDE_FROM_ALL) + include(FetchContent) + FetchContent_Declare( + simde + GIT_REPOSITORY https://github.com/simd-everywhere/simde.git + GIT_TAG master + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(simde) + + if (NOT TARGET simde::simde AND DEFINED simde_SOURCE_DIR AND EXISTS "${simde_SOURCE_DIR}/simde") + add_library(simde::simde INTERFACE IMPORTED GLOBAL) + target_include_directories(simde::simde INTERFACE "${simde_SOURCE_DIR}") + endif () endif () include(CTest) diff --git a/external/simde b/external/simde deleted file mode 160000 index 1a1ca5e..0000000 --- a/external/simde +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1a1ca5ee71518d8a115234dad1e2d871421953b7 From 11779ea1ff10cd012eada302fb5f70041bf17039 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 19:07:56 +0200 Subject: [PATCH 04/12] Integrate `CONFIGURE_DEPENDS` in source file globbing, improve CMake configuration with target aliases, install rules, and LTO support, enhance SIMDe handling with flexible provider selection and bundling, and refine compiler flag management for better compatibility. --- CMakeLists.txt | 90 ++++++++++++++++++++++++------------ include/pernix/simd_compat.h | 3 ++ src/CMakeLists.txt | 50 ++++++++++++++++++-- 3 files changed, 109 insertions(+), 34 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8394abf..b8c8208 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,8 @@ option(PERNIX_DISABLE_AVX2 "Disable AVX2 optimizations" off) option(PERNIX_DISABLE_AVX512 "Disable AVX512 optimizations" off) option(PERNIX_USE_SIMDE "Use SIMDe library for portable SIMD support" off) +set(PERNIX_SIMDE_PROVIDER "AUTO" CACHE STRING "SIMDe provider when PERNIX_USE_SIMDE is enabled (AUTO, PACKAGE, FETCH)") +set_property(CACHE PERNIX_SIMDE_PROVIDER PROPERTY STRINGS AUTO PACKAGE FETCH) set(PERNIX_ARCH_BACKEND "AUTO" CACHE STRING "Pernix architecture backend (AUTO, FALLBACK, X86, ARM64_NEON, ARM64_SVE, ARM64_SVE2)") set_property(CACHE PERNIX_ARCH_BACKEND PROPERTY STRINGS AUTO FALLBACK X86 ARM64_NEON ARM64_SVE ARM64_SVE2) @@ -24,24 +26,42 @@ if (NOT PERNIX_ARCH_BACKEND IN_LIST PERNIX_VALID_ARCH_BACKENDS) message(FATAL_ERROR "Unsupported PERNIX_ARCH_BACKEND='${PERNIX_ARCH_BACKEND}'. Expected one of: ${PERNIX_VALID_ARCH_BACKENDS}") endif () +string(TOUPPER "${PERNIX_SIMDE_PROVIDER}" PERNIX_SIMDE_PROVIDER) +set(PERNIX_VALID_SIMDE_PROVIDERS AUTO PACKAGE FETCH) +if (NOT PERNIX_SIMDE_PROVIDER IN_LIST PERNIX_VALID_SIMDE_PROVIDERS) + message(FATAL_ERROR "Unsupported PERNIX_SIMDE_PROVIDER='${PERNIX_SIMDE_PROVIDER}'. Expected one of: ${PERNIX_VALID_SIMDE_PROVIDERS}") +endif () + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") +set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL OFF) if (PERNIX_USE_SIMDE) - include(FetchContent) - FetchContent_Declare( - simde - GIT_REPOSITORY https://github.com/simd-everywhere/simde.git - GIT_TAG master - GIT_SHALLOW TRUE - GIT_PROGRESS TRUE - EXCLUDE_FROM_ALL - ) - FetchContent_MakeAvailable(simde) + if (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "PACKAGE") + find_package(simde CONFIG QUIET) + endif () + + if (NOT TARGET simde::simde AND (PERNIX_SIMDE_PROVIDER STREQUAL "AUTO" OR PERNIX_SIMDE_PROVIDER STREQUAL "FETCH")) + include(FetchContent) + set(SIMDE_TEST_CMAKE_PACKAGING OFF CACHE BOOL "Test SIMDe CMake packaging" FORCE) + FetchContent_Declare( + simde + GIT_REPOSITORY https://github.com/simd-everywhere/simde.git + GIT_TAG f3e8262173b7089db9a9d57a9ecef8dd07ad9c97 + GIT_PROGRESS TRUE + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(simde) + set(PERNIX_BUNDLE_SIMDE_FOR_INSTALL ON) + endif () if (NOT TARGET simde::simde AND DEFINED simde_SOURCE_DIR AND EXISTS "${simde_SOURCE_DIR}/simde") add_library(simde::simde INTERFACE IMPORTED GLOBAL) target_include_directories(simde::simde INTERFACE "${simde_SOURCE_DIR}") endif () + + if (NOT TARGET simde::simde) + message(FATAL_ERROR "PERNIX_USE_SIMDE is enabled, but simde::simde was not found. Set PERNIX_SIMDE_PROVIDER=FETCH or install SIMDe's CMake package.") + endif () endif () include(CTest) @@ -62,28 +82,42 @@ else () endif () message(STATUS "Pernix version: ${VERSION}, normalized to ${NORMALIZED_VERSION}") -set(BENCHMARK_CXX_STANDARD 20) - -set(CMAKE_CXX_STANDARD ${BENCHMARK_CXX_STANDARD}) -set(CMAKE_CXX_STANDARD_REQUIRED YES) -set(CMAKE_CXX_EXTENSIONS OFF) - -include(AddCXXCompilerFlag) if (MSVC) message(FATAL_ERROR "MSVC compiler is not supported") else () - add_cxx_compiler_flag(-Wall) - add_cxx_compiler_flag(-Wextra) - add_cxx_compiler_flag(-Wshadow) - add_cxx_compiler_flag(-Wfloat-equal) - add_cxx_compiler_flag(-Wold-style-cast) - add_cxx_compiler_flag(-Wconversion) - add_cxx_compiler_flag(-fstrict-aliasing) - add_cxx_compiler_flag(-Wno-ignored-attributes) + include(CheckCXXCompilerFlag) + set(PERNIX_PRIVATE_COMPILE_OPTIONS) + foreach (PERNIX_CXX_FLAG + -Wall + -Wextra + -Wshadow + -Wfloat-equal + -Wold-style-cast + -Wconversion + -fstrict-aliasing + -Wno-ignored-attributes + ) + string(MAKE_C_IDENTIFIER "PERNIX_HAS_CXX_FLAG_${PERNIX_CXX_FLAG}" PERNIX_CXX_FLAG_VARIABLE) + check_cxx_compiler_flag("${PERNIX_CXX_FLAG}" "${PERNIX_CXX_FLAG_VARIABLE}") + if (${PERNIX_CXX_FLAG_VARIABLE}) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "${PERNIX_CXX_FLAG}") + else () + message(STATUS "Compiler flag not supported: ${PERNIX_CXX_FLAG}") + endif () + endforeach () if (PERNIX_ENABLE_LTO) - add_cxx_compiler_flag(-flto=auto) - add_cxx_compiler_flag(-Wno-lto-type-mismatch) + include(CheckIPOSupported) + check_ipo_supported(RESULT PERNIX_IPO_SUPPORTED OUTPUT PERNIX_IPO_ERROR) + if (NOT PERNIX_IPO_SUPPORTED) + message(FATAL_ERROR "PERNIX_ENABLE_LTO is enabled, but IPO/LTO is not supported: ${PERNIX_IPO_ERROR}") + endif () + + check_cxx_compiler_flag("-Wno-lto-type-mismatch" PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + if (PERNIX_HAS_CXX_FLAG_WNO_LTO_TYPE_MISMATCH) + list(APPEND PERNIX_PRIVATE_COMPILE_OPTIONS "-Wno-lto-type-mismatch") + endif () + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") find_program(GCC_AR gcc-ar) if (GCC_AR) @@ -106,8 +140,6 @@ else () endif () endif () -include_directories(${PROJECT_SOURCE_DIR}/include) - add_subdirectory(src) if (PERNIX_ENABLE_FORTRAN_BINDINGS) diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index e96ec30..f96a84f 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -7,6 +7,9 @@ #if defined(PERNIX_USE_SIMDE) #define SIMDE_ENABLE_NATIVE_ALIASES #undef SIMDE_X86_AVX512FP16_NATIVE +#if defined(__clang__) +#define SIMDE_X86_AVX512BF16_NATIVE +#endif // #define SIMDE_NO_NATIVE #if defined(PERNIX_BACKEND_X86) #include diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f0c7e3e..8c79ab6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,7 +1,9 @@ include(GNUInstallDirs) +include(CMakePackageConfigHelpers) file(GLOB_RECURSE PERNIX_COMMON_SOURCES + CONFIGURE_DEPENDS ./fallback/*.cpp ./pernix.cpp ${PROJECT_SOURCE_DIR}/include/pernix/*.h @@ -24,6 +26,7 @@ message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") file(GLOB_RECURSE PERNIX_X86_SOURCES + CONFIGURE_DEPENDS ./x86/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/x86/*.h ) @@ -31,6 +34,7 @@ if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") file(GLOB_RECURSE PERNIX_ARM64_NEON_SOURCES + CONFIGURE_DEPENDS ./arm64/neon/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/arm64/neon/*.h ) @@ -38,6 +42,7 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") file(GLOB_RECURSE PERNIX_ARM64_SVE_SOURCES + CONFIGURE_DEPENDS ./arm64/sve/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve/*.h ) @@ -45,6 +50,7 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") file(GLOB_RECURSE PERNIX_ARM64_SVE2_SOURCES + CONFIGURE_DEPENDS ./arm64/sve2/*.cpp ${PROJECT_SOURCE_DIR}/include/pernix/arm64/sve2/*.h ) @@ -52,13 +58,20 @@ elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") endif () add_library(pernix SHARED ${PERNIX_SOURCES}) +add_library(pernix::pernix ALIAS pernix) set_target_properties(pernix PROPERTIES OUTPUT_NAME "pernix" VERSION ${NORMALIZED_VERSION} ) +target_compile_features(pernix PUBLIC cxx_std_20) +target_compile_options(pernix PRIVATE ${PERNIX_PRIVATE_COMPILE_OPTIONS}) target_include_directories(pernix PUBLIC $ + $ ) +if (PERNIX_ENABLE_LTO) + set_target_properties(pernix PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE) +endif () if (PERNIX_USE_SIMDE) target_link_libraries(pernix PUBLIC simde::simde) @@ -85,23 +98,50 @@ configure_file( if (PERNIX_ENABLE_INSTALL) install(TARGETS pernix + EXPORT pernixTargets LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) - install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/pernix ${PROJECT_BINARY_DIR}/include/pernix + install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/pernix" DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} FILES_MATCHING PATTERN "*.*h" ) + if (PERNIX_USE_SIMDE AND PERNIX_BUNDLE_SIMDE_FOR_INSTALL) + install(DIRECTORY "${simde_SOURCE_DIR}/simde" + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + FILES_MATCHING PATTERN "*.h" + ) + endif () + + configure_package_config_file( + "${PROJECT_SOURCE_DIR}/cmake/pernixConfig.cmake.in" + "${PROJECT_BINARY_DIR}/pernixConfig.cmake" + INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) + write_basic_package_version_file( + "${PROJECT_BINARY_DIR}/pernixConfigVersion.cmake" + VERSION ${NORMALIZED_VERSION} + COMPATIBILITY SameMajorVersion + ) + install( + FILES + "${PROJECT_BINARY_DIR}/pernixConfig.cmake" + "${PROJECT_BINARY_DIR}/pernixConfigVersion.cmake" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) + install( + EXPORT pernixTargets + FILE pernixTargets.cmake + NAMESPACE pernix:: + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/pernix" + ) install( FILES ${PROJECT_BINARY_DIR}/pernix.pc DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig" ) - - add_custom_target(uninstall COMMAND xargs rm -vf < ${PROJECT_BINARY_DIR}/install_manifest.txt) endif () if (PERNIX_ENABLE_DOXYGEN) @@ -127,7 +167,7 @@ if (PERNIX_ENABLE_DOXYGEN) WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} COMMENT "Generating documentation with Doxygen." ) - if (BENCHMARK_ENABLE_INSTALL AND BENCHMARK_INSTALL_DOCS) + if (PERNIX_ENABLE_INSTALL AND PERNIX_INSTALL_DOCS) install(DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/html/" DESTINATION ${CMAKE_INSTALL_DOCDIR}) endif () From 17807c8ca2a62158d86e2154214e05ed0acb5736 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 20:11:11 +0200 Subject: [PATCH 05/12] Refactor inline attributes in AVX512 compression and decompression headers for consistency and clarity --- include/pernix/x86/avx512vbmi/compression.h | 62 ++++++++-------- include/pernix/x86/avx512vbmi/decompression.h | 72 +++++++++---------- include/pernix/x86/avx512vbmi/packing.h | 18 ++--- include/pernix/x86/avx512vbmi/tables.h | 66 +++++++++-------- 4 files changed, 108 insertions(+), 110 deletions(-) diff --git a/include/pernix/x86/avx512vbmi/compression.h b/include/pernix/x86/avx512vbmi/compression.h index e1cb4f0..9621e12 100644 --- a/include/pernix/x86/avx512vbmi/compression.h +++ b/include/pernix/x86/avx512vbmi/compression.h @@ -135,7 +135,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - input += 64; + input += 64; output += 8 * BIT_WIDTH; } } @@ -153,7 +153,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri const __m256i packed = m256::mm256_pack_epi8_avx512vbmi_1to8(make_m256i_from_2x128(converted1, converted2)); mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } @@ -165,7 +165,7 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri const __m128i packed = m128::mm_pack_epi8_avx512vbmi_1to8(converted); mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -183,8 +183,8 @@ __always_inline int mm512_compress_block_avx512vbmi_1to8(const float_t* __restri template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_9to16(const float_t* __restrict__ input, const float_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -210,7 +210,7 @@ template const __m512i packed = m512::mm512_pack_epi16_avx512vbmi_9to16(make_m512i_from_2x256(converted1, converted2)); mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } } @@ -223,7 +223,7 @@ template const __m256i packed = m256::mm256_pack_epi16_avx512vbmi_9to16(converted); mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -235,7 +235,7 @@ template const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -253,8 +253,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_17to24(const float_t* __restrict__ input, const float_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -272,7 +272,7 @@ template const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(packed_input); mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } } @@ -284,7 +284,7 @@ template const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(packed_input); mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -301,8 +301,8 @@ template template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_1to8(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_64 = elements_per_block / 64; @@ -347,7 +347,7 @@ template mm512_storeu_elements_epi64(output, BIT_WIDTH, packed); - input += 64; + input += 64; output += 8 * BIT_WIDTH; } } @@ -373,7 +373,7 @@ template mm256_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } @@ -390,7 +390,7 @@ template mm_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -416,8 +416,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_9to16(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -450,7 +450,7 @@ template mm512_storeu_elements_epi32(output, BIT_WIDTH, packed); - input += 32; + input += 32; output += 4 * BIT_WIDTH; } } @@ -468,7 +468,7 @@ template mm256_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } @@ -480,7 +480,7 @@ template const __m128i packed = m128::mm_pack_epi16_avx512vbmi_9to16(converted); mm_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -498,8 +498,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, - uint8_t* __restrict__ output) { +__always_inline int mm512_compress_block_avx512vbmi_17to24(const double_t* __restrict__ input, const double_t scale, + uint8_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -520,7 +520,7 @@ template const __m512i packed = m512::mm512_pack_epi32_avx512vbmi_17to24(make_m512i_from_2x256(quantized1, quantized2)); mm512_storeu_elements_epi16(output, BIT_WIDTH, packed); - input += 16; + input += 16; output += 2 * BIT_WIDTH; } } @@ -532,7 +532,7 @@ template const __m256i packed = m256::mm256_pack_epi32_avx512vbmi_17to24(quantized); mm256_storeu_elements_epi8(output, BIT_WIDTH, packed); - input += 8; + input += 8; output += BIT_WIDTH; } @@ -546,7 +546,7 @@ template return 0; } -} // namespace internal +} // namespace internal /** * @brief Compress a single 512-bit block using AVX-512 and AVX-512-VBMI instructions. @@ -621,7 +621,7 @@ int mm512_compress_blocks_avx512vbmi(const float_t* __restrict__ input, const fl for (uint32_t block = 0; block < blocks; ++block) { mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; block_output += BLOCK_SIZE; } @@ -647,13 +647,13 @@ int mm512_compress_blocks_avx512vbmi(const double_t* __restrict__ input, const d for (uint32_t block = 0; block < blocks; ++block) { mm512_compress_block_avx512vbmi(block_input, scale, block_output); - block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; + block_input += (BLOCK_SIZE * 8) / BIT_WIDTH; block_output += BLOCK_SIZE; } return 0; } -} // namespace pernix +} // namespace pernix #ifdef __cplusplus namespace pernix { @@ -719,7 +719,7 @@ int mm512_compress_blocks_f64_avx512vbmi(uint8_t bit_width, const double_t* __re #ifdef __cplusplus } -} // namespace pernix +} // namespace pernix #endif #endif // PERNIX_AVX512VBMI_COMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/decompression.h b/include/pernix/x86/avx512vbmi/decompression.h index 6320240..08abc35 100644 --- a/include/pernix/x86/avx512vbmi/decompression.h +++ b/include/pernix/x86/avx512vbmi/decompression.h @@ -16,20 +16,20 @@ namespace internal { /** * @brief Dequantize sixteen integer values to floats. */ -[[gnu::always_inline]] inline __m512 mm512_dequantize_epi32(const __m512i& input, const __m512& scale) { +__always_inline __m512 mm512_dequantize_epi32(const __m512i& input, const __m512& scale) { const __m512 converted = _mm512_cvtepi32_ps(input); return _mm512_mul_ps(converted, scale); } -[[gnu::always_inline]] inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m512d& scale) { +__always_inline __m512d mm512_dequantize_epi64(const __m512i& input, const __m512d& scale) { const __m512d converted = _mm512_cvtepi64_pd(input); return _mm512_mul_pd(converted, scale); } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const uint32_t iterations_64 = elements_per_block / 64; @@ -61,7 +61,7 @@ template _mm512_storeu_ps(output + 48, dequantized4); output += 64; - input += 8 * BIT_WIDTH; + input += 8 * BIT_WIDTH; } } @@ -79,7 +79,7 @@ template _mm512_storeu_ps(output + 16, dequantized2); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } if constexpr (iterations_16 > 0) { @@ -93,7 +93,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -112,8 +112,8 @@ template template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; const uint32_t iterations_64 = elements_per_block / 64; @@ -162,7 +162,7 @@ template _mm512_storeu_pd(output + 56, dequantized8); output += 64; - input += 8 * BIT_WIDTH; + input += 8 * BIT_WIDTH; } if constexpr (iterations_32 > 0) { @@ -188,7 +188,7 @@ template _mm512_storeu_pd(output + 24, dequantized4); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } if constexpr (iterations_16 > 0) { @@ -205,7 +205,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -230,8 +230,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -258,7 +258,7 @@ template _mm512_storeu_ps(output + 16, dequantized2); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } } @@ -272,7 +272,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (iterations_8 > 0) { @@ -285,7 +285,7 @@ template _mm256_storeu_ps(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -303,8 +303,8 @@ template template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_32 = elements_per_block / 32; @@ -336,7 +336,7 @@ template _mm512_storeu_pd(output + 24, dequantized4); output += 32; - input += 4 * BIT_WIDTH; + input += 4 * BIT_WIDTH; } } @@ -354,7 +354,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } if constexpr (iterations_8 > 0) { @@ -368,7 +368,7 @@ template _mm512_storeu_pd(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -387,8 +387,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -409,7 +409,7 @@ template _mm512_storeu_ps(output, dequantized); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } } @@ -422,7 +422,7 @@ template _mm256_storeu_ps(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -439,8 +439,8 @@ template template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -[[gnu::always_inline]] inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int mm512_decompress_block_avx512vbmi_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -465,7 +465,7 @@ template _mm512_storeu_pd(output + 8, dequantized2); output += 16; - input += 2 * BIT_WIDTH; + input += 2 * BIT_WIDTH; } } @@ -480,7 +480,7 @@ template _mm512_storeu_pd(output, dequantized); output += 8; - input += BIT_WIDTH; + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { @@ -496,7 +496,7 @@ template return 0; } -} // namespace internal +} // namespace internal /** * @brief Decompress a single 512\-bit block using AVX-512 and AVX-512-VBMI instructions. @@ -572,7 +572,7 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const for (uint32_t block = 0; block < blocks; ++block) { mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } @@ -599,15 +599,15 @@ int mm512_decompress_blocks_avx512vbmi(const uint8_t* __restrict__ input, const for (uint32_t block = 0; block < blocks; ++block) { mm512_decompress_block_avx512vbmi(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; } -} // namespace pernix +} // namespace pernix -#ifdef __cplusplus namespace pernix { +#ifdef __cplusplus extern "C" { #endif /** @@ -669,7 +669,7 @@ int mm512_decompress_blocks_f64_avx512vbmi(uint8_t bit_width, const uint8_t* __r #ifdef __cplusplus } -} // namespace pernix #endif +} // namespace pernix #endif // PERNIX_AVX512VBMI_DECOMPRESSION_H diff --git a/include/pernix/x86/avx512vbmi/packing.h b/include/pernix/x86/avx512vbmi/packing.h index c9f9db9..ba3b132 100644 --- a/include/pernix/x86/avx512vbmi/packing.h +++ b/include/pernix/x86/avx512vbmi/packing.h @@ -11,7 +11,7 @@ namespace m128 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i& input) { +__always_inline __m128i mm_pack_epi16_avx512vbmi_9to16(const __m128i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -48,7 +48,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i& input) { +__always_inline __m128i mm_pack_epi8_avx512vbmi_1to8(const __m128i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -93,7 +93,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i& input) { +__always_inline __m128i mm_pack_epi32_avx512vbmi_17to24(const __m128i& input) { using tables = pack_tables_avx512_24; const __m128i maskv = _mm_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); @@ -117,7 +117,7 @@ namespace m256 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i& input) { +__always_inline __m256i mm256_pack_epi16_avx512vbmi_9to16(const __m256i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -154,7 +154,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i& input) { +__always_inline __m256i mm256_pack_epi8_avx512vbmi_1to8(const __m256i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -199,7 +199,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i& input) { +__always_inline __m256i mm256_pack_epi32_avx512vbmi_17to24(const __m256i& input) { using tables = pack_tables_avx512_24; const __m256i maskv = _mm256_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); @@ -223,7 +223,7 @@ namespace m512 { */ template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -[[gnu::always_inline]] inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i& input) { +__always_inline __m512i mm512_pack_epi16_avx512vbmi_9to16(const __m512i& input) { if constexpr (BIT_WIDTH == 16) { return input; } else { @@ -260,7 +260,7 @@ template */ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -[[gnu::always_inline]] inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i& input) { +__always_inline __m512i mm512_pack_epi8_avx512vbmi_1to8(const __m512i& input) { if constexpr (BIT_WIDTH == 8) { return input; } else { @@ -305,7 +305,7 @@ template */ template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -[[gnu::always_inline]] inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i& input) { +__always_inline __m512i mm512_pack_epi32_avx512vbmi_17to24(const __m512i& input) { using tables = pack_tables_avx512_24; const __m512i maskv = _mm512_set1_epi32(static_cast((1u << BIT_WIDTH) - 1u)); diff --git a/include/pernix/x86/avx512vbmi/tables.h b/include/pernix/x86/avx512vbmi/tables.h index 9115625..4d66727 100644 --- a/include/pernix/x86/avx512vbmi/tables.h +++ b/include/pernix/x86/avx512vbmi/tables.h @@ -9,9 +9,8 @@ #include namespace pernix::internal { - template -[[gnu::always_inline]] static inline Vec load_table(const std::array& table) { +static __always_inline Vec load_table(const std::array& table) { static_assert(sizeof(table) >= sizeof(Vec), "table is smaller than requested SIMD vector"); if constexpr (std::is_same_v) { return _mm512_load_si512(static_cast(table.data())); @@ -529,13 +528,13 @@ struct pack_tables_avx512_16 { // clang-format on } - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_permute3() { return load_table(permute3); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } - [[gnu::always_inline]] static inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } }; template @@ -591,7 +590,7 @@ struct pack_tables_avx512_24 { return plan; } - static inline constexpr std::array word_plans = [] { + static constexpr std::array word_plans = [] { std::array plans{}; for (uint32_t i = 0; i < 16; ++i) { plans[i] = create_plan(i); @@ -600,7 +599,7 @@ struct pack_tables_avx512_24 { }(); template - [[gnu::always_inline]] static constexpr std::array make_table(Getter getter) { + static __always_inline constexpr std::array make_table(Getter getter) { std::array values{}; for (uint32_t i = 0; i < 16; ++i) { values[i] = getter(word_plans[i]); @@ -608,26 +607,26 @@ struct pack_tables_avx512_24 { return values; } - alignas(64) static inline constexpr auto permute1 = make_table([](const word_plan& p) { return p.left_index1; }); + alignas(64) static constexpr auto permute1 = make_table([](const word_plan& p) { return p.left_index1; }); - alignas(64) static inline constexpr auto permute2 = make_table([](const word_plan& p) { return p.left_index2; }); + alignas(64) static constexpr auto permute2 = make_table([](const word_plan& p) { return p.left_index2; }); - alignas(64) static inline constexpr auto permute3 = make_table([](const word_plan& p) { return p.right_index; }); + alignas(64) static constexpr auto permute3 = make_table([](const word_plan& p) { return p.right_index; }); - alignas(64) static inline constexpr auto shift1 = make_table([](const word_plan& p) { return p.left_shift1; }); + alignas(64) static constexpr auto shift1 = make_table([](const word_plan& p) { return p.left_shift1; }); - alignas(64) static inline constexpr auto shift2 = make_table([](const word_plan& p) { return p.left_shift2; }); + alignas(64) static constexpr auto shift2 = make_table([](const word_plan& p) { return p.left_shift2; }); - alignas(64) static inline constexpr auto shift3 = make_table([](const word_plan& p) { return p.right_shift; }); + alignas(64) static constexpr auto shift3 = make_table([](const word_plan& p) { return p.right_shift; }); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_permute3() { return load_table(permute3); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute3() { return load_table(permute3); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } - [[gnu::always_inline]] static inline Vec get_shift3() { return load_table(shift3); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift3() { return load_table(shift3); } }; template @@ -693,11 +692,11 @@ struct unpack_tables_avx512_8 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } }; template @@ -768,11 +767,11 @@ struct unpack_tables_avx512_16 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute1() { return load_table(permute1); } - [[gnu::always_inline]] static inline Vec get_permute2() { return load_table(permute2); } + static __always_inline Vec get_permute1() { return load_table(permute1); } + static __always_inline Vec get_permute2() { return load_table(permute2); } - [[gnu::always_inline]] static inline Vec get_shift1() { return load_table(shift1); } - [[gnu::always_inline]] static inline Vec get_shift2() { return load_table(shift2); } + static __always_inline Vec get_shift1() { return load_table(shift1); } + static __always_inline Vec get_shift2() { return load_table(shift2); } }; template @@ -813,10 +812,9 @@ struct unpack_tables_avx512_24 { }(); public: - [[gnu::always_inline]] static inline Vec get_permute() { return load_table(permute); } - [[gnu::always_inline]] static inline Vec get_shift() { return load_table(shift); } + static __always_inline Vec get_permute() { return load_table(permute); } + static __always_inline Vec get_shift() { return load_table(shift); } }; - -} // namespace pernix::internal +} // namespace pernix::internal #endif // PERNIX_AVX512VBMI_TABLES_H From 552cea37204a02b36085becafc9da57be3a74b64 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 20:18:14 +0200 Subject: [PATCH 06/12] Refactor ARM64 NEON and SVE compression/decompression headers: organize includes, introduce namespaced utility functions, add templates for bit-width handling, and update function signatures for consistency. --- include/pernix/arm64/neon/compression.h | 124 ++++++++++++++++++---- include/pernix/arm64/neon/decompression.h | 120 +++++++++++++++++---- include/pernix/arm64/neon/packing.h | 2 - include/pernix/arm64/neon/unpacking.h | 10 -- include/pernix/arm64/sve/compression.h | 124 ++++++++++++++++++---- include/pernix/arm64/sve/decompression.h | 120 +++++++++++++++++---- include/pernix/arm64/sve/packing.h | 2 - include/pernix/arm64/sve/unpacking.h | 2 - 8 files changed, 408 insertions(+), 96 deletions(-) diff --git a/include/pernix/arm64/neon/compression.h b/include/pernix/arm64/neon/compression.h index 65ea786..6e49348 100644 --- a/include/pernix/arm64/neon/compression.h +++ b/include/pernix/arm64/neon/compression.h @@ -2,58 +2,140 @@ #define PERNIX_ARM64_NEON_COMPRESSION_H #include +#include #include #include -namespace pernix { +namespace pernix::arm64::neon { namespace internal { -template -inline constexpr bool neon_compression_unimplemented_v = false; +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} } // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_block(const float_t*, float_t, uint8_t*) { - static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); - return -1; +__always_inline int neon_compress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_block(const double_t*, double_t, uint8_t*) { - static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); - return -1; +__always_inline int neon_compress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_compress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { - static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); - return -1; +int neon_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { - static_assert(internal::neon_compression_unimplemented_v, "ARM64 NEON compression is not implemented yet"); - return -1; +int neon_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } #ifdef __cplusplus extern "C" { #endif -int neon_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); -int neon_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); -int neon_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, +int neon_compress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + + +int neon_compress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + +int neon_compress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int neon_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); + +int neon_compress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_COMPRESSION_H diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index 0f8d79e..44744ce 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -2,42 +2,119 @@ #define PERNIX_ARM64_NEON_DECOMPRESSION_H #include +#include #include #include -namespace pernix { +namespace pernix::arm64::neon { namespace internal { -template -inline constexpr bool neon_decompression_unimplemented_v = false; +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} } // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_block(const uint8_t*, float_t, float_t*) { - static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); - return -1; +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_decompress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_block(const uint8_t*, double_t, double_t*) { - static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); - return -1; +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::neon_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::neon_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::neon_decompress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { - static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); - return -1; +int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { - static_assert(internal::neon_decompression_unimplemented_v, "ARM64 NEON decompression is not implemented yet"); - return -1; +int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + neon_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } #ifdef __cplusplus @@ -45,15 +122,20 @@ extern "C" { #endif int neon_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); -int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); + + +int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + int neon_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); + +int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/neon/packing.h b/include/pernix/arm64/neon/packing.h index c1c8119..538b5a8 100644 --- a/include/pernix/arm64/neon/packing.h +++ b/include/pernix/arm64/neon/packing.h @@ -4,8 +4,6 @@ #include namespace pernix::arm64::neon::internal { -template -inline constexpr bool packing_unimplemented_v = false; } // namespace pernix::arm64::neon::internal #endif // PERNIX_ARM64_NEON_PACKING_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index 32dc5de..ea22b24 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -4,17 +4,7 @@ #include namespace pernix::arm64::neon::internal { - -template -inline constexpr bool unpacking_unimplemented_v = false; - -namespace b64 { - -} // namespace b64 - - } // namespace pernix::arm64::neon::internal - #endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve/compression.h b/include/pernix/arm64/sve/compression.h index f8abfed..cf83ce0 100644 --- a/include/pernix/arm64/sve/compression.h +++ b/include/pernix/arm64/sve/compression.h @@ -2,58 +2,140 @@ #define PERNIX_ARM64_SVE_COMPRESSION_H #include +#include #include #include -namespace pernix { +namespace pernix::arm64::sve { namespace internal { -template -inline constexpr bool sve_compression_unimplemented_v = false; +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_compress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} } // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_block(const float_t*, float_t, uint8_t*) { - static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); - return -1; +__always_inline int sve_compress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_compress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_block(const double_t*, double_t, uint8_t*) { - static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); - return -1; +__always_inline int sve_compress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_compress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_compress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_compress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_blocks(const float_t*, float_t, uint8_t*, uint32_t) { - static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); - return -1; +int sve_compress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_compress_blocks(const double_t*, double_t, uint8_t*, uint32_t) { - static_assert(internal::sve_compression_unimplemented_v, "ARM64 SVE compression is not implemented yet"); - return -1; +int sve_compress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_compress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } #ifdef __cplusplus extern "C" { #endif -int sve_compress_block(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output); -int sve_compress_block_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output); -int sve_compress_blocks(uint8_t bit_width, const float_t* __restrict__ input, float_t scale, uint8_t* __restrict__ output, +int sve_compress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + + +int sve_compress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + +int sve_compress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int sve_compress_blocks_f64(uint8_t bit_width, const double_t* __restrict__ input, double_t scale, uint8_t* __restrict__ output, - uint32_t blocks); + +int sve_compress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::sve #endif // PERNIX_ARM64_SVE_COMPRESSION_H diff --git a/include/pernix/arm64/sve/decompression.h b/include/pernix/arm64/sve/decompression.h index 784c0c8..052a3e4 100644 --- a/include/pernix/arm64/sve/decompression.h +++ b/include/pernix/arm64/sve/decompression.h @@ -2,42 +2,119 @@ #define PERNIX_ARM64_SVE_DECOMPRESSION_H #include +#include #include #include -namespace pernix { +namespace pernix::arm64::sve { namespace internal { -template -inline constexpr bool sve_decompression_unimplemented_v = false; +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + static_assert(true, "Not yet implemented"); + return -1; +} } // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_block(const uint8_t*, float_t, float_t*) { - static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); - return -1; +__always_inline int sve_decompress_block(const uint8_t* __restrict__ input, const float_t scale, + float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_decompress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_block(const uint8_t*, double_t, double_t*) { - static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); - return -1; +__always_inline int sve_decompress_block(const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve_decompress_block_17to24(input, scale, output); + } + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { - static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); - return -1; +int sve_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { - static_assert(internal::sve_decompression_unimplemented_v, "ARM64 SVE decompression is not implemented yet"); - return -1; +int sve_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, + const uint32_t blocks) { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + return 0; } #ifdef __cplusplus @@ -45,15 +122,20 @@ extern "C" { #endif int sve_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); -int sve_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); + + +int sve_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output); + int sve_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int sve_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, - uint32_t blocks); + +int sve_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, + double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::sve #endif // PERNIX_ARM64_SVE_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve/packing.h b/include/pernix/arm64/sve/packing.h index fce21ca..ab57b4f 100644 --- a/include/pernix/arm64/sve/packing.h +++ b/include/pernix/arm64/sve/packing.h @@ -4,8 +4,6 @@ #include namespace pernix::arm64::sve::internal { -template -inline constexpr bool packing_unimplemented_v = false; } // namespace pernix::arm64::sve::internal #endif // PERNIX_ARM64_SVE_PACKING_H diff --git a/include/pernix/arm64/sve/unpacking.h b/include/pernix/arm64/sve/unpacking.h index 3de49ca..2565ab7 100644 --- a/include/pernix/arm64/sve/unpacking.h +++ b/include/pernix/arm64/sve/unpacking.h @@ -4,8 +4,6 @@ #include namespace pernix::arm64::sve::internal { -template -inline constexpr bool unpacking_unimplemented_v = false; } // namespace pernix::arm64::sve::internal #endif // PERNIX_ARM64_SVE_UNPACKING_H From 23553012600b6866bdda71d4a7c551310d29ea10 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 19 May 2026 21:49:41 +0200 Subject: [PATCH 07/12] WIP: Implement NEON decompression functions with common utilities and templates --- include/pernix/arm64/neon/common.h | 106 ++++++++++++++++++++++ include/pernix/arm64/neon/decompression.h | 106 +++++++++++++++++++--- include/pernix/arm64/neon/unpacking.h | 20 +++- include/pernix/simd_compat.h | 2 +- 4 files changed, 219 insertions(+), 15 deletions(-) create mode 100644 include/pernix/arm64/neon/common.h diff --git a/include/pernix/arm64/neon/common.h b/include/pernix/arm64/neon/common.h new file mode 100644 index 0000000..6efa777 --- /dev/null +++ b/include/pernix/arm64/neon/common.h @@ -0,0 +1,106 @@ +#ifndef PERNIX_ARM64_NEON_COMMON_H +#define PERNIX_ARM64_NEON_COMMON_H + +#include +#include + +namespace pernix::arm64::neon::internal { +static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { + const uint32_t tail_bits = remaining_elements * bit_width; + const uint32_t tail_bytes = (tail_bits + 7u) / 8u; + return tail_bytes; +} + +__always_inline int32x4x4_t neon_convert_int8x16_int32x4x2_t(const int8x16_t& input) { + const int16x8_t s16_lo = vmovl_s8(vget_low_s8(input)); + const int16x8_t s16_hi = vmovl_s8(vget_high_s8(input)); + + return {{ + vmovl_s16(vget_low_s16(s16_lo)), + vmovl_s16(vget_high_s16(s16_lo)), + vmovl_s16(vget_low_s16(s16_hi)), + vmovl_s16(vget_high_s16(s16_hi)), + }}; +} + +__always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t& input, const float32x4_t& scale) { + return {{ + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[2]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[3]), scale), + }}; +} + +__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_elements) { + uint8_t buffer[16] = {0}; + std::memcpy(buffer, input, tail_elements * sizeof(uint8_t)); + return vld1q_u8(buffer); +} + +__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t* input, const uint32_t tail_elements) { + uint16_t buffer[8] = {0}; + std::memcpy(buffer, input, tail_elements * sizeof(uint16_t)); + return vld1q_u16(buffer); +} + +__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t* input, const uint32_t tail_elements) { + uint32_t buffer[4] = {0}; + std::memcpy(buffer, input, tail_elements * sizeof(uint32_t)); + return vld1q_u32(buffer); +} + +__always_inline float32x4_t neon_load_tail_elements_f32(const uint8_t* input, const uint32_t tail_elements) { + float32_t buffer[4] = {0.0f}; + std::memcpy(buffer, input, tail_elements * sizeof(float32_t)); + return vld1q_f32(buffer); +} + +__always_inline float64x2_t neon_load_tail_elements_f64(const uint8_t* input, const uint32_t tail_elements) { + float64_t buffer[2] = {0.0}; + std::memcpy(buffer, input, tail_elements * sizeof(float64_t)); + return vld1q_f64(buffer); +} + +__always_inline void neon_store_tail_elements_int8(uint8_t* output, const uint8x16x4_t& data, const uint32_t tail_elements) { + uint8_t buffer[16 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u8(buffer + i * 16, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint8_t)); +} + +__always_inline void neon_store_tail_elements_int16(uint16_t* output, const uint16x8x4_t& data, const uint32_t tail_elements) { + uint16_t buffer[8 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u16(buffer + i * 8, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint16_t)); +} + +__always_inline void neon_store_tail_elements_int32(uint32_t* output, const uint32x4x4_t& data, const uint32_t tail_elements) { + uint32_t buffer[4 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_u32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(uint32_t)); +} + +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4x4_t& data, const uint32_t tail_elements) { + float32_t buffer[16 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x4_t& data, const uint32_t tail_elements) { + float64_t buffer[8 * 4]; + for (uint32_t i = 0; i < 4; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} +} // namespace pernix::arm64::neon::internal + +#endif //PERNIX_ARM64_NEON_COMMON_H \ No newline at end of file diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index 44744ce..90ed78b 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -13,48 +14,129 @@ template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_16 = elements_per_block / 16; + constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16; + + const float32x4_t scale_v = vdupq_n_f32(scale); + + for (uint32_t i = 0; i < iterations_16; ++i) { + const uint8x16_t source = vld1q_u8(input); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); + + const int32x4x4_t converted = neon_convert_int8x16_int32x4x2_t(unpacked); + const float32x4x4_t dequantized = neon_dequantize_epi32(converted, scale_v); + + for (uint32_t j = 0; j < 4; ++j) { + vst1q_f32(output, dequantized.val[j]); + output += 4; + } + + input += 2 * BIT_WIDTH; + } + + if constexpr (remaining_elements > 0) { + const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); + + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x2_t(tail_unpacked); + const float32x4x4_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); + } + + return 0; } template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + + for (uint32_t i = 0; i < iterations_8; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + for (uint32_t i = 0; i < iterations_4; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + + for (uint32_t i = 0; i < iterations_8; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + for (uint32_t i = 0; i < iterations_4; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { - static_assert(true, "Not yet implemented"); - return -1; + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + constexpr uint32_t iterations_2 = elements_per_block / 2; + constexpr uint32_t remaining_elements = elements_per_block - iterations_2 * 2; + + for (uint32_t i = 0; i < iterations_2; ++i) { + static_assert(true, "Not yet implemented"); + } + + if constexpr (remaining_elements > 0) { + static_assert(true, "Not yet implemented"); + } } } // namespace internal diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index ea22b24..c1ae78f 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -3,8 +3,24 @@ #include -namespace pernix::arm64::neon::internal { -} // namespace pernix::arm64::neon::internal +namespace pernix::arm64::neon::internal::b128 { +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline int8x16_t neon_unpack_epi8_1to8(const int8x16_t& input) { + return input; +} +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline int16x8_t neon_unpack_epi8_9to16(const int16x8_t& input) { + return input; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline int32x4_t neon_unpack_epi8_17to24(const int32x4_t& input) { + return input; +} +} // namespace pernix::arm64::neon::internal::b128 #endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index f96a84f..07d5110 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -36,7 +36,7 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #include -#elif defined(__aarch64__) +#elif defined(__aarch64__) || defined(__arm64ec__) #include #endif From 50d14a6dcdd647bf46e6ed0bc37e03a57926c342 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Mon, 25 May 2026 18:37:03 +0200 Subject: [PATCH 08/12] WIP: implement NEON decompression functions --- tests/arm64/neon/.gitkeep => .clangd | 0 CMakeLists.txt | 12 ++ include/pernix/arm64/neon/common.h | 46 ++++- include/pernix/arm64/neon/decompression.h | 78 +++++++- include/pernix/arm64/neon/unpacking.h | 81 ++++++++- include/pernix/arm64/tables.h | 212 ++++++++++++++++++++++ src/CMakeLists.txt | 12 -- tests/CMakeLists.txt | 42 ++++- tests/arm64/neon/decompression_tests.cpp | 38 ++++ 9 files changed, 487 insertions(+), 34 deletions(-) rename tests/arm64/neon/.gitkeep => .clangd (100%) create mode 100644 include/pernix/arm64/tables.h create mode 100644 tests/arm64/neon/decompression_tests.cpp diff --git a/tests/arm64/neon/.gitkeep b/.clangd similarity index 100% rename from tests/arm64/neon/.gitkeep rename to .clangd diff --git a/CMakeLists.txt b/CMakeLists.txt index b8c8208..1cbead6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,18 @@ if (NOT PERNIX_ARCH_BACKEND IN_LIST PERNIX_VALID_ARCH_BACKENDS) message(FATAL_ERROR "Unsupported PERNIX_ARCH_BACKEND='${PERNIX_ARCH_BACKEND}'. Expected one of: ${PERNIX_VALID_ARCH_BACKENDS}") endif () +set(PERNIX_SELECTED_ARCH_BACKEND "${PERNIX_ARCH_BACKEND}") +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "AUTO") + if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") + set(PERNIX_SELECTED_ARCH_BACKEND "X86") + elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(PERNIX_SELECTED_ARCH_BACKEND "ARM64_NEON") + else () + set(PERNIX_SELECTED_ARCH_BACKEND "FALLBACK") + endif () +endif () +message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") + string(TOUPPER "${PERNIX_SIMDE_PROVIDER}" PERNIX_SIMDE_PROVIDER) set(PERNIX_VALID_SIMDE_PROVIDERS AUTO PACKAGE FETCH) if (NOT PERNIX_SIMDE_PROVIDER IN_LIST PERNIX_VALID_SIMDE_PROVIDERS) diff --git a/include/pernix/arm64/neon/common.h b/include/pernix/arm64/neon/common.h index 6efa777..f843908 100644 --- a/include/pernix/arm64/neon/common.h +++ b/include/pernix/arm64/neon/common.h @@ -11,7 +11,7 @@ static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t rem return tail_bytes; } -__always_inline int32x4x4_t neon_convert_int8x16_int32x4x2_t(const int8x16_t& input) { +__always_inline int32x4x4_t neon_convert_int8x16_int32x4x4(const int8x16_t& input) { const int16x8_t s16_lo = vmovl_s8(vget_low_s8(input)); const int16x8_t s16_hi = vmovl_s8(vget_high_s8(input)); @@ -23,6 +23,13 @@ __always_inline int32x4x4_t neon_convert_int8x16_int32x4x2_t(const int8x16_t& in }}; } +__always_inline int32x4x2_t neon_convert_int16x8_int32x4x2(const int16x8_t& input) { + return {{ + vmovl_s16(vget_low_s16(input)), + vmovl_s16(vget_high_s16(input)), + }}; +} + __always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t& input, const float32x4_t& scale) { return {{ vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), @@ -32,21 +39,32 @@ __always_inline float32x4x4_t neon_dequantize_epi32(const int32x4x4_t& input, co }}; } -__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_elements) { +__always_inline float32x4x2_t neon_dequantize_epi32(const int32x4x2_t& input, const float32x4_t& scale) { + return {{ + vmulq_f32(vcvtq_f32_s32(input.val[0]), scale), + vmulq_f32(vcvtq_f32_s32(input.val[1]), scale), + }}; +} + +__always_inline float32x4_t neon_dequantize_epi32(const int32x4_t& input, const float32x4_t& scale) { + return vmulq_f32(vcvtq_f32_s32(input), scale); +} + +__always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_bytes_count) { uint8_t buffer[16] = {0}; - std::memcpy(buffer, input, tail_elements * sizeof(uint8_t)); + std::memcpy(buffer, input, tail_bytes_count); return vld1q_u8(buffer); } -__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t* input, const uint32_t tail_elements) { +__always_inline uint16x8_t neon_load_tail_elements_int16(const uint8_t* input, const uint32_t tail_bytes_count) { uint16_t buffer[8] = {0}; - std::memcpy(buffer, input, tail_elements * sizeof(uint16_t)); + std::memcpy(buffer, input, tail_bytes_count); return vld1q_u16(buffer); } -__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t* input, const uint32_t tail_elements) { +__always_inline uint32x4_t neon_load_tail_elements_int32(const uint8_t* input, const uint32_t tail_bytes_count) { uint32_t buffer[4] = {0}; - std::memcpy(buffer, input, tail_elements * sizeof(uint32_t)); + std::memcpy(buffer, input, tail_bytes_count); return vld1q_u32(buffer); } @@ -94,6 +112,20 @@ __always_inline void neon_store_tail_elements_f32(float32_t* output, const float std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); } +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4x2_t& data, const uint32_t tail_elements) { + float32_t buffer[8 * 2]; + for (uint32_t i = 0; i < 2; ++i) { + vst1q_f32(buffer + i * 4, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + +__always_inline void neon_store_tail_elements_f32(float32_t* output, const float32x4_t& data, const uint32_t tail_elements) { + float32_t buffer[4]; + vst1q_f32(buffer, data); + std::memcpy(output, buffer, tail_elements * sizeof(float32_t)); +} + __always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x4_t& data, const uint32_t tail_elements) { float64_t buffer[8 * 4]; for (uint32_t i = 0; i < 4; ++i) { diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index 90ed78b..6c46f1a 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -25,7 +25,7 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input const uint8x16_t source = vld1q_u8(input); const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); - const int32x4x4_t converted = neon_convert_int8x16_int32x4x2_t(unpacked); + const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); const float32x4x4_t dequantized = neon_dequantize_epi32(converted, scale_v); for (uint32_t j = 0; j < 4; ++j) { @@ -40,7 +40,7 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); - const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x2_t(tail_unpacked); + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); const float32x4x4_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); @@ -58,13 +58,34 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu constexpr uint32_t iterations_8 = elements_per_block / 8; constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + const float32x4_t scale_v = vdupq_n_f32(scale); + for (uint32_t i = 0; i < iterations_8; ++i) { - static_assert(true, "Not yet implemented"); + const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); + const int16x8_t unpacked = b128::neon_unpack_epi8_9to16(source); + + const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); + const float32x4x2_t dequantized = neon_dequantize_epi32(converted, scale_v); + + for (uint32_t j = 0; j < 2; ++j) { + vst1q_f32(output, dequantized.val[j]); + output += 4; + } + + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi8_9to16(tail_source); + + const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); + const float32x4x2_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); } + + return 0; } template @@ -76,13 +97,52 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp constexpr uint32_t iterations_4 = elements_per_block / 4; constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + const float32x4_t scale_v = vdupq_n_f32(scale); + for (uint32_t i = 0; i < iterations_4; ++i) { - static_assert(true, "Not yet implemented"); + const uint32_t group_bit_start = i * 4u * BIT_WIDTH; + const uint8_t* group_input = input + group_bit_start / 8u; + const uint32x4_t source = vld1q_u32(reinterpret_cast(group_input)); + + int32x4_t unpacked; + if constexpr (BIT_WIDTH % 2 == 0) { + unpacked = b128::neon_unpack_epi8_17to24(source); + } else { + if (i % 2 == 0) { + unpacked = b128::neon_unpack_epi8_17to24(source); + } else { + unpacked = b128::neon_unpack_epi8_17to24(source); + } + } + + const float32x4_t dequantized = neon_dequantize_epi32(unpacked, scale_v); + + vst1q_f32(output, dequantized); + + output += 4; } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; + + constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + + int32x4_t tail_unpacked; + if constexpr (tail_bit_offset == 0) { + tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + } else { + tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + } + + const float32x4_t tail_dequantized = neon_dequantize_epi32(tail_unpacked, scale_v); + + neon_store_tail_elements_f32(output, tail_dequantized, remaining_elements); } + + return 0; } template @@ -101,6 +161,8 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input if constexpr (remaining_elements > 0) { static_assert(true, "Not yet implemented"); } + + return 0; } template @@ -119,6 +181,8 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu if constexpr (remaining_elements > 0) { static_assert(true, "Not yet implemented"); } + + return 0; } template @@ -137,6 +201,8 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp if constexpr (remaining_elements > 0) { static_assert(true, "Not yet implemented"); } + + return 0; } } // namespace internal diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index c1ae78f..663eb9d 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -2,24 +2,91 @@ #define PERNIX_ARM64_NEON_UNPACKING_H #include +#include + +using namespace pernix::arm64::internal; namespace pernix::arm64::neon::internal::b128 { template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) -__always_inline int8x16_t neon_unpack_epi8_1to8(const int8x16_t& input) { - return input; +__always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t& input) { + if constexpr (BIT_WIDTH == 8) { + return vreinterpretq_s8_u8(input); + } else { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + + uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute2.data())); + + shifted = vorrq_u8(shifted, vshlq_u8(permuted2_u8, vld1q_s8(tables::shift2.data()))); + } + + constexpr int shift = 8 - BIT_WIDTH; + shifted = vshlq_n_u8(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s8(vreinterpretq_s8_u8(shifted), vdupq_n_s8(-shift)); + } else { + return vreinterpretq_s8_u8(vshlq_u8(shifted, vdupq_n_s8(-shift))); + } + } } template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline int16x8_t neon_unpack_epi8_9to16(const int16x8_t& input) { - return input; +__always_inline int16x8_t neon_unpack_epi8_9to16(const uint16x8_t& input) { + if constexpr (BIT_WIDTH == 16) { + return vreinterpretq_s16_u16(input); + } else { + using tables = table_unpacking; + + const uint8x16_t input_u8 = vreinterpretq_u8_u16(input); + + const uint8x16_t permuted1_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute1.data())); + + uint16x8_t shifted = vshlq_u16(vreinterpretq_u16_u8(permuted1_u8), vld1q_s16(tables::shift1.data())); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const uint8x16_t permuted2_u8 = vqtbl1q_u8(input_u8, vld1q_u8(tables::permute2.data())); + + const uint16x8_t shifted2 = vshlq_u16(vreinterpretq_u16_u8(permuted2_u8), vld1q_s16(tables::shift2.data())); + + shifted = vorrq_u16(shifted, shifted2); + } + + constexpr int shift = 16 - BIT_WIDTH; + shifted = vshlq_n_u16(shifted, shift); + + if constexpr (SIGN_VALUES) { + return vshlq_s16(vreinterpretq_s16_u16(shifted), vdupq_n_s16(-shift)); + } else { + return vreinterpretq_s16_u16(vshlq_u16(shifted, vdupq_n_s16(-shift))); + } + } } -template +template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline int32x4_t neon_unpack_epi8_17to24(const int32x4_t& input) { - return input; +__always_inline int32x4_t neon_unpack_epi8_17to24(const uint32x4_t& input) { + using tables = table_unpacking; + + const uint8x16_t input_8 = vreinterpretq_u8_u32(input); + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input_8, vld1q_u8(tables::permute.data())); + + const uint32x4_t value = vshlq_u32(vreinterpretq_u32_u8(permuted_u8), vld1q_s32(tables::shift.data())); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return vshrq_n_s32(vreinterpretq_s32_u32(vshlq_n_u32(value, sign_shift)), sign_shift); + } else { + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); + } } } // namespace pernix::arm64::neon::internal::b128 diff --git a/include/pernix/arm64/tables.h b/include/pernix/arm64/tables.h new file mode 100644 index 0000000..233bbc4 --- /dev/null +++ b/include/pernix/arm64/tables.h @@ -0,0 +1,212 @@ +#ifndef PERNIX_ARM64_TABLES_H +#define PERNIX_ARM64_TABLES_H + +#include +#include +#include + +namespace pernix::arm64::internal { +namespace detail { +inline constexpr std::size_t neon_vector_width = 128; +inline constexpr uint8_t inactive_lane = 0xff; + +template +constexpr bool table_indices_are_valid(const std::array& table) { + for (const uint8_t index : table) { + if (index != inactive_lane && index >= Elements) { + return false; + } + } + + return true; +} + +template +constexpr std::array make_primary_permute() { + static_assert(LANE_BITS % 8 == 0); + + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t base = entry * lane_bytes; + + for (std::size_t lane_byte = 0; lane_byte < lane_bytes; ++lane_byte) { + table[base + lane_byte] = static_cast(first_byte + lane_byte); + } + } + + return table; +} + +template +constexpr std::array make_spill_permute() { + static_assert(LANE_BITS % 8 == 0); + + constexpr std::size_t lane_bytes = LANE_BITS / 8; + static_assert(ELEMENTS % lane_bytes == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / lane_bytes; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t first_byte = bit_start / 8; + const std::size_t bit_offset = bit_start % 8; + const std::size_t base = entry * lane_bytes; + + if (bit_offset + BIT_WIDTH > LANE_BITS) { + table[base] = static_cast(first_byte + lane_bytes); + } + } + + return table; +} + +template +constexpr std::array make_shift_right() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + + table[entry] = -static_cast(bit_offset); + } + + return table; +} + +template +constexpr std::array make_shift_left_for_spill() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = entry * BIT_WIDTH; + const std::size_t bit_offset = bit_start % 8u; + const bool spills = bit_offset + BIT_WIDTH > LANE_BITS; + + table[entry] = spills ? static_cast(LANE_BITS - bit_offset) : 0; + } + + return table; +} + +template +constexpr std::array make_contiguous_permute_32() { + static_assert(ELEMENTS % 4 == 0); + + std::array table{}; + table.fill(inactive_lane); + + for (std::size_t entry = 0; entry < ELEMENTS / 4; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + const std::size_t bit_end = bit_start + BIT_WIDTH - 1; + const std::size_t first_byte = bit_start / 8; + const std::size_t last_byte = bit_end / 8; + const std::size_t base = entry * 4; + + for (std::size_t byte = first_byte; byte <= last_byte; ++byte) { + table[base + (byte - first_byte)] = static_cast(byte); + } + } + + return table; +} + +template +constexpr std::array make_shift_right_32() { + std::array table{}; + table.fill(0); + + for (std::size_t entry = 0; entry < ELEMENTS; ++entry) { + const std::size_t bit_start = START_BIT_OFFSET + entry * BIT_WIDTH; + + table[entry] = -static_cast(bit_start % 8u); + } + + return table; +} +} // namespace detail + +template +struct table_unpacking; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8 && VECTOR_WIDTH == detail::neon_vector_width) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 8; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 16); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16 && VECTOR_WIDTH == detail::neon_vector_width) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 16; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute1 = + detail::make_primary_permute(); + alignas(64) static constexpr std::array permute2 = + detail::make_spill_permute(); + alignas(64) static constexpr std::array shift1 = + detail::make_shift_right(); + alignas(64) static constexpr std::array shift2 = + detail::make_shift_left_for_spill(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 8); + static_assert(detail::table_indices_are_valid(permute1)); + static_assert(detail::table_indices_are_valid(permute2)); +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && VECTOR_WIDTH == detail::neon_vector_width && START_BIT_OFFSET < 8) +struct table_unpacking { +private: + static constexpr std::size_t PERMUTE_ELEMENTS = VECTOR_WIDTH / 8; + static constexpr std::size_t SHIFT_ELEMENTS = VECTOR_WIDTH / 32; + +public: + static constexpr uint8_t bit_width = BIT_WIDTH; + + alignas(64) static constexpr std::array permute = + detail::make_contiguous_permute_32(); + alignas(64) static constexpr std::array shift = + detail::make_shift_right_32(); + + static_assert(PERMUTE_ELEMENTS == 16); + static_assert(SHIFT_ELEMENTS == 4); + static_assert(detail::table_indices_are_valid(permute)); +}; +} // namespace pernix::arm64::internal + +#endif // PERNIX_ARM64_TABLES_H diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8c79ab6..57d2f0a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -11,18 +11,6 @@ file(GLOB_RECURSE set(PERNIX_SOURCES ${PERNIX_COMMON_SOURCES}) -set(PERNIX_SELECTED_ARCH_BACKEND "${PERNIX_ARCH_BACKEND}") -if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "AUTO") - if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86|i686)$") - set(PERNIX_SELECTED_ARCH_BACKEND "X86") - elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") - set(PERNIX_SELECTED_ARCH_BACKEND "ARM64_NEON") - else () - set(PERNIX_SELECTED_ARCH_BACKEND "FALLBACK") - endif () -endif () -message(STATUS "Pernix architecture backend: ${PERNIX_SELECTED_ARCH_BACKEND}") - if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") file(GLOB_RECURSE PERNIX_X86_SOURCES diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 370ec3b..b5ffe57 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,12 +2,50 @@ find_package(PkgConfig) pkg_search_module(GTEST REQUIRED gtest) include(CheckCXXCompilerFlag) +file(GLOB + PERNIX_ROOT_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp +) + file(GLOB_RECURSE - SOURCE_FILES + PERNIX_FALLBACK_TEST_SOURCES CONFIGURE_DEPENDS - *.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fallback/*.cpp ) +set(SOURCE_FILES ${PERNIX_ROOT_TEST_SOURCES} ${PERNIX_FALLBACK_TEST_SOURCES}) + +if (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "X86") + file(GLOB_RECURSE + PERNIX_X86_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/x86/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_X86_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_NEON") + file(GLOB_RECURSE + PERNIX_ARM64_NEON_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/neon/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_NEON_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE") + file(GLOB_RECURSE + PERNIX_ARM64_SVE_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/sve/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_SVE_TEST_SOURCES}) +elseif (PERNIX_SELECTED_ARCH_BACKEND STREQUAL "ARM64_SVE2") + file(GLOB_RECURSE + PERNIX_ARM64_SVE2_TEST_SOURCES + CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/arm64/sve2/*.cpp + ) + list(APPEND SOURCE_FILES ${PERNIX_ARM64_SVE2_TEST_SOURCES}) +endif () + file(GLOB_RECURSE HEADER_FILES CONFIGURE_DEPENDS diff --git a/tests/arm64/neon/decompression_tests.cpp b/tests/arm64/neon/decompression_tests.cpp new file mode 100644 index 0000000..69229be --- /dev/null +++ b/tests/arm64/neon/decompression_tests.cpp @@ -0,0 +1,38 @@ +#include +#include + +#ifdef PERNIX_BACKEND_ARM64_NEON + +using namespace pernix::arm64::neon; + +TYPED_TEST(DecompressionTest, NeonDecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + neon_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +TYPED_TEST(DecompressionTest64, NeonDecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + neon_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +#endif \ No newline at end of file From 9eca7046bce1aa7be90f45420e698a4984f62af2 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Tue, 26 May 2026 00:04:55 +0200 Subject: [PATCH 09/12] WIP: implement NEON decompression functions --- include/pernix/arm64/neon/common.h | 68 ++++++++- include/pernix/arm64/neon/decompression.h | 162 +++++++++++++++------- include/pernix/arm64/neon/unpacking.h | 11 +- include/pernix/arm64/tables.h | 4 +- 4 files changed, 189 insertions(+), 56 deletions(-) diff --git a/include/pernix/arm64/neon/common.h b/include/pernix/arm64/neon/common.h index f843908..8e517fa 100644 --- a/include/pernix/arm64/neon/common.h +++ b/include/pernix/arm64/neon/common.h @@ -2,9 +2,14 @@ #define PERNIX_ARM64_NEON_COMMON_H #include + #include namespace pernix::arm64::neon::internal { +struct float64x2x8_t { + float64x2_t val[8]; +}; + static constexpr uint32_t tail_bytes(const uint8_t bit_width, const uint32_t remaining_elements) { const uint32_t tail_bits = remaining_elements * bit_width; const uint32_t tail_bytes = (tail_bits + 7u) / 8u; @@ -50,6 +55,47 @@ __always_inline float32x4_t neon_dequantize_epi32(const int32x4_t& input, const return vmulq_f32(vcvtq_f32_s32(input), scale); } +__always_inline float64x2_t neon_dequantize_epi32_f64(const int32x2_t& input, const float64x2_t& scale) { + return vmulq_f64(vcvtq_f64_s64(vmovl_s32(input)), scale); +} + +__always_inline float64x2x2_t neon_dequantize_epi32_f64(const int32x4_t& input, const float64x2_t& scale) { + return {{ + neon_dequantize_epi32_f64(vget_low_s32(input), scale), + neon_dequantize_epi32_f64(vget_high_s32(input), scale), + }}; +} + +__always_inline float64x2x4_t neon_dequantize_epi32_f64(const int32x4x2_t& input, const float64x2_t& scale) { + const float64x2x2_t dequantized_low = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized_high = neon_dequantize_epi32_f64(input.val[1], scale); + + return {{ + dequantized_low.val[0], + dequantized_low.val[1], + dequantized_high.val[0], + dequantized_high.val[1], + }}; +} + +__always_inline float64x2x8_t neon_dequantize_epi32_f64(const int32x4x4_t& input, const float64x2_t& scale) { + const float64x2x2_t dequantized0 = neon_dequantize_epi32_f64(input.val[0], scale); + const float64x2x2_t dequantized1 = neon_dequantize_epi32_f64(input.val[1], scale); + const float64x2x2_t dequantized2 = neon_dequantize_epi32_f64(input.val[2], scale); + const float64x2x2_t dequantized3 = neon_dequantize_epi32_f64(input.val[3], scale); + + return {{ + dequantized0.val[0], + dequantized0.val[1], + dequantized1.val[0], + dequantized1.val[1], + dequantized2.val[0], + dequantized2.val[1], + dequantized3.val[0], + dequantized3.val[1], + }}; +} + __always_inline uint8x16_t neon_load_tail_elements_int8(const uint8_t* input, const uint32_t tail_bytes_count) { uint8_t buffer[16] = {0}; std::memcpy(buffer, input, tail_bytes_count); @@ -127,12 +173,28 @@ __always_inline void neon_store_tail_elements_f32(float32_t* output, const float } __always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x4_t& data, const uint32_t tail_elements) { - float64_t buffer[8 * 4]; + float64_t buffer[2 * 4]; for (uint32_t i = 0; i < 4; ++i) { vst1q_f64(buffer + i * 2, data.val[i]); } std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); } -} // namespace pernix::arm64::neon::internal -#endif //PERNIX_ARM64_NEON_COMMON_H \ No newline at end of file +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x2_t& data, const uint32_t tail_elements) { + float64_t buffer[2 * 2]; + for (uint32_t i = 0; i < 2; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} + +__always_inline void neon_store_tail_elements_f64(float64_t* output, const float64x2x8_t& data, const uint32_t tail_elements) { + float64_t buffer[2 * 8]; + for (uint32_t i = 0; i < 8; ++i) { + vst1q_f64(buffer + i * 2, data.val[i]); + } + std::memcpy(output, buffer, tail_elements * sizeof(float64_t)); +} +} // namespace pernix::arm64::neon::internal + +#endif // PERNIX_ARM64_NEON_COMMON_H diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index 6c46f1a..cfc051e 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -1,9 +1,9 @@ #ifndef PERNIX_ARM64_NEON_DECOMPRESSION_H #define PERNIX_ARM64_NEON_DECOMPRESSION_H -#include -#include #include +#include +#include #include #include @@ -12,8 +12,7 @@ namespace pernix::arm64::neon { namespace internal { template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_16 = elements_per_block / 16; @@ -51,8 +50,7 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_8 = elements_per_block / 8; @@ -90,8 +88,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; constexpr uint32_t iterations_4 = elements_per_block / 4; @@ -123,12 +120,12 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp } if constexpr (remaining_elements > 0) { - constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; - const uint8_t* tail_input = input + tail_bit_start / 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; - const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); int32x4_t tail_unpacked; if constexpr (tail_bit_offset == 0) { @@ -147,19 +144,37 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_8 = elements_per_block / 8; - constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; + constexpr uint32_t iterations_16 = elements_per_block / 16; + constexpr uint32_t remaining_elements = elements_per_block - iterations_16 * 16; - for (uint32_t i = 0; i < iterations_8; ++i) { - static_assert(true, "Not yet implemented"); + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_16; ++i) { + const uint8x16_t source = vld1q_u8(input); + const int8x16_t unpacked = b128::neon_unpack_epi8_1to8(source); + + const int32x4x4_t converted = neon_convert_int8x16_int32x4x4(unpacked); + const float64x2x8_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); + + for (uint32_t j = 0; j < 8; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + + input += 2 * BIT_WIDTH; } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + const uint8x16_t tail_source = neon_load_tail_elements_int8(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int8x16_t tail_unpacked = b128::neon_unpack_epi8_1to8(tail_source); + + const int32x4x4_t tail_converted = neon_convert_int8x16_int32x4x4(tail_unpacked); + const float64x2x8_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); } return 0; @@ -167,19 +182,37 @@ __always_inline int neon_decompress_block_1to8(const uint8_t* __restrict__ input template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_4 = elements_per_block / 4; - constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + constexpr uint32_t iterations_8 = elements_per_block / 8; + constexpr uint32_t remaining_elements = elements_per_block - iterations_8 * 8; - for (uint32_t i = 0; i < iterations_4; ++i) { - static_assert(true, "Not yet implemented"); + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_8; ++i) { + const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); + const int16x8_t unpacked = b128::neon_unpack_epi8_9to16(source); + + const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); + const float64x2x4_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); + + for (uint32_t j = 0; j < 4; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } + + input += BIT_WIDTH; } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); + const int16x8_t tail_unpacked = b128::neon_unpack_epi8_9to16(tail_source); + + const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); + const float64x2x4_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); } return 0; @@ -187,29 +220,65 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; - constexpr uint32_t iterations_2 = elements_per_block / 2; - constexpr uint32_t remaining_elements = elements_per_block - iterations_2 * 2; + constexpr uint32_t iterations_4 = elements_per_block / 4; + constexpr uint32_t remaining_elements = elements_per_block - iterations_4 * 4; + + const float64x2_t scale_v = vdupq_n_f64(scale); + + for (uint32_t i = 0; i < iterations_4; ++i) { + const uint32_t group_bit_start = i * 4u * BIT_WIDTH; + const uint8_t* group_input = input + group_bit_start / 8u; + const uint32x4_t source = vld1q_u32(reinterpret_cast(group_input)); + + int32x4_t unpacked; + if constexpr (BIT_WIDTH % 2 == 0) { + unpacked = b128::neon_unpack_epi8_17to24(source); + } else { + if (i % 2 == 0) { + unpacked = b128::neon_unpack_epi8_17to24(source); + } else { + unpacked = b128::neon_unpack_epi8_17to24(source); + } + } + + const float64x2x2_t dequantized = neon_dequantize_epi32_f64(unpacked, scale_v); - for (uint32_t i = 0; i < iterations_2; ++i) { - static_assert(true, "Not yet implemented"); + for (uint32_t j = 0; j < 2; ++j) { + vst1q_f64(output, dequantized.val[j]); + output += 2; + } } if constexpr (remaining_elements > 0) { - static_assert(true, "Not yet implemented"); + constexpr uint32_t tail_bit_start = iterations_4 * 4u * BIT_WIDTH; + constexpr uint32_t tail_bit_offset = tail_bit_start % 8u; + const uint8_t* tail_input = input + tail_bit_start / 8u; + + constexpr uint32_t tail_bytes_count = (tail_bit_offset + remaining_elements * BIT_WIDTH + 7u) / 8u; + const uint32x4_t tail_source = neon_load_tail_elements_int32(tail_input, tail_bytes_count); + + int32x4_t tail_unpacked; + if constexpr (tail_bit_offset == 0) { + tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + } else { + tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + } + + const float64x2x2_t tail_dequantized = neon_dequantize_epi32_f64(tail_unpacked, scale_v); + + neon_store_tail_elements_f64(output, tail_dequantized, remaining_elements); } return 0; } -} // namespace internal +} // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const float_t scale, - float_t* __restrict__ output) { +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return internal::neon_decompress_block_1to8(input, scale, output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { @@ -222,8 +291,7 @@ __always_inline int neon_decompress_block(const uint8_t* __restrict__ input, con template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const double_t scale, - double_t* __restrict__ output) { +__always_inline int neon_decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { return internal::neon_decompress_block_1to8(input, scale, output); } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { @@ -236,14 +304,13 @@ __always_inline int neon_decompress_block(const uint8_t* __restrict__ input, con template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, - const uint32_t blocks) { +int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { const uint8_t* block_input = input; float_t* block_output = output; for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } @@ -252,14 +319,13 @@ int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, - const uint32_t blocks) { +int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { const uint8_t* block_input = input; double_t* block_output = output; for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; @@ -271,19 +337,17 @@ extern "C" { int neon_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); - -int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output); +int neon_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); int neon_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); -int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, - double_t* __restrict__ output, uint32_t blocks); +int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, + uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix::arm64::neon +} // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index 663eb9d..b1f8585 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -1,8 +1,8 @@ #ifndef PERNIX_ARM64_NEON_UNPACKING_H #define PERNIX_ARM64_NEON_UNPACKING_H -#include #include +#include using namespace pernix::arm64::internal; @@ -12,6 +12,13 @@ template __always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t& input) { if constexpr (BIT_WIDTH == 8) { return vreinterpretq_s8_u8(input); + } else if constexpr (BIT_WIDTH == 1) { + using tables = table_unpacking; + + const uint8x16_t permuted_u8 = vqtbl1q_u8(input, vld1q_u8(tables::permute1.data())); + const uint8x16_t shifted = vshlq_u8(permuted_u8, vld1q_s8(tables::shift1.data())); + + return vreinterpretq_s8_u8(vandq_u8(shifted, vdupq_n_u8(1))); } else { using tables = table_unpacking; @@ -88,6 +95,6 @@ __always_inline int32x4_t neon_unpack_epi8_17to24(const uint32x4_t& input) { return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); } } -} // namespace pernix::arm64::neon::internal::b128 +} // namespace pernix::arm64::neon::internal::b128 #endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/tables.h b/include/pernix/arm64/tables.h index 233bbc4..60e1dfe 100644 --- a/include/pernix/arm64/tables.h +++ b/include/pernix/arm64/tables.h @@ -134,7 +134,7 @@ constexpr std::array make_shift_right_32() { return table; } -} // namespace detail +} // namespace detail template struct table_unpacking; @@ -207,6 +207,6 @@ struct table_unpacking { static_assert(SHIFT_ELEMENTS == 4); static_assert(detail::table_indices_are_valid(permute)); }; -} // namespace pernix::arm64::internal +} // namespace pernix::arm64::internal #endif // PERNIX_ARM64_TABLES_H From af90ba8ac32167d689c220feb0dca14314affdf8 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Thu, 28 May 2026 13:22:46 +0200 Subject: [PATCH 10/12] WIP: Implement SVE2 decompression functions and update related headers for ARM64 --- include/pernix/arm64/neon/decompression.h | 36 +-- include/pernix/arm64/{ => neon}/tables.h | 26 +- include/pernix/arm64/neon/unpacking.h | 10 +- include/pernix/arm64/sve2/decompression.h | 346 ++++++++++++++++++++-- include/pernix/arm64/sve2/tables.h | 143 +++++++++ include/pernix/arm64/sve2/unpacking.h | 87 +++++- include/pernix/simd_compat.h | 5 + src/arm64/neon/decompression.cpp | 140 ++++++++- src/arm64/sve2/decompression.cpp | 76 ++++- tests/arm64/sve2/.gitkeep | 0 tests/arm64/sve2/decompression_tests.cpp | 38 +++ 11 files changed, 833 insertions(+), 74 deletions(-) rename include/pernix/arm64/{ => neon}/tables.h (94%) create mode 100644 include/pernix/arm64/sve2/tables.h delete mode 100644 tests/arm64/sve2/.gitkeep create mode 100644 tests/arm64/sve2/decompression_tests.cpp diff --git a/include/pernix/arm64/neon/decompression.h b/include/pernix/arm64/neon/decompression.h index cfc051e..583948f 100644 --- a/include/pernix/arm64/neon/decompression.h +++ b/include/pernix/arm64/neon/decompression.h @@ -60,7 +60,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu for (uint32_t i = 0; i < iterations_8; ++i) { const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); - const int16x8_t unpacked = b128::neon_unpack_epi8_9to16(source); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16(source); const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); const float32x4x2_t dequantized = neon_dequantize_epi32(converted, scale_v); @@ -75,7 +75,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu if constexpr (remaining_elements > 0) { const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); - const int16x8_t tail_unpacked = b128::neon_unpack_epi8_9to16(tail_source); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16(tail_source); const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); const float32x4x2_t tail_dequantized = neon_dequantize_epi32(tail_converted, scale_v); @@ -103,12 +103,12 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp int32x4_t unpacked; if constexpr (BIT_WIDTH % 2 == 0) { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } else { if (i % 2 == 0) { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } else { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } } @@ -129,9 +129,9 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp int32x4_t tail_unpacked; if constexpr (tail_bit_offset == 0) { - tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } else { - tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } const float32x4_t tail_dequantized = neon_dequantize_epi32(tail_unpacked, scale_v); @@ -192,7 +192,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu for (uint32_t i = 0; i < iterations_8; ++i) { const uint16x8_t source = vld1q_u16(reinterpret_cast(input)); - const int16x8_t unpacked = b128::neon_unpack_epi8_9to16(source); + const int16x8_t unpacked = b128::neon_unpack_epi16_9to16(source); const int32x4x2_t converted = neon_convert_int16x8_int32x4x2(unpacked); const float64x2x4_t dequantized = neon_dequantize_epi32_f64(converted, scale_v); @@ -207,7 +207,7 @@ __always_inline int neon_decompress_block_9to16(const uint8_t* __restrict__ inpu if constexpr (remaining_elements > 0) { const uint16x8_t tail_source = neon_load_tail_elements_int16(input, tail_bytes(BIT_WIDTH, remaining_elements)); - const int16x8_t tail_unpacked = b128::neon_unpack_epi8_9to16(tail_source); + const int16x8_t tail_unpacked = b128::neon_unpack_epi16_9to16(tail_source); const int32x4x2_t tail_converted = neon_convert_int16x8_int32x4x2(tail_unpacked); const float64x2x4_t tail_dequantized = neon_dequantize_epi32_f64(tail_converted, scale_v); @@ -235,12 +235,12 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp int32x4_t unpacked; if constexpr (BIT_WIDTH % 2 == 0) { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } else { if (i % 2 == 0) { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } else { - unpacked = b128::neon_unpack_epi8_17to24(source); + unpacked = b128::neon_unpack_epi32_17to24(source); } } @@ -262,9 +262,9 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp int32x4_t tail_unpacked; if constexpr (tail_bit_offset == 0) { - tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } else { - tail_unpacked = b128::neon_unpack_epi8_17to24(tail_source); + tail_unpacked = b128::neon_unpack_epi32_17to24(tail_source); } const float64x2x2_t tail_dequantized = neon_dequantize_epi32_f64(tail_unpacked, scale_v); @@ -274,7 +274,7 @@ __always_inline int neon_decompress_block_17to24(const uint8_t* __restrict__ inp return 0; } -} // namespace internal +} // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) @@ -310,7 +310,7 @@ int neon_decompress_blocks(const uint8_t* __restrict__ input, const float_t scal for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } @@ -325,7 +325,7 @@ int neon_decompress_blocks(const uint8_t* __restrict__ input, const double_t sca for (uint32_t block = 0; block < blocks; ++block) { neon_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; + block_input += BLOCK_SIZE; block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } return 0; @@ -348,6 +348,6 @@ int neon_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ in #ifdef __cplusplus } #endif -} // namespace pernix::arm64::neon +} // namespace pernix::arm64::neon #endif // PERNIX_ARM64_NEON_DECOMPRESSION_H diff --git a/include/pernix/arm64/tables.h b/include/pernix/arm64/neon/tables.h similarity index 94% rename from include/pernix/arm64/tables.h rename to include/pernix/arm64/neon/tables.h index 60e1dfe..d085551 100644 --- a/include/pernix/arm64/tables.h +++ b/include/pernix/arm64/neon/tables.h @@ -1,24 +1,21 @@ -#ifndef PERNIX_ARM64_TABLES_H -#define PERNIX_ARM64_TABLES_H +#ifndef PERNIX_ARM64_NEON_TABLES_H +#define PERNIX_ARM64_NEON_TABLES_H +#include #include #include #include -namespace pernix::arm64::internal { +namespace pernix::arm64::neon::internal { namespace detail { inline constexpr std::size_t neon_vector_width = 128; inline constexpr uint8_t inactive_lane = 0xff; template constexpr bool table_indices_are_valid(const std::array& table) { - for (const uint8_t index : table) { - if (index != inactive_lane && index >= Elements) { - return false; - } - } - - return true; + return std::ranges::all_of(table, [](const uint8_t index) { + return index == inactive_lane || index < Elements; + }); } template @@ -134,7 +131,7 @@ constexpr std::array make_shift_right_32() { return table; } -} // namespace detail +} // namespace detail template struct table_unpacking; @@ -207,6 +204,9 @@ struct table_unpacking { static_assert(SHIFT_ELEMENTS == 4); static_assert(detail::table_indices_are_valid(permute)); }; -} // namespace pernix::arm64::internal -#endif // PERNIX_ARM64_TABLES_H +template +struct table_packing; +} // namespace pernix::arm64::internal + +#endif // PERNIX_ARM64_NEON_TABLES_H diff --git a/include/pernix/arm64/neon/unpacking.h b/include/pernix/arm64/neon/unpacking.h index b1f8585..6ac0e20 100644 --- a/include/pernix/arm64/neon/unpacking.h +++ b/include/pernix/arm64/neon/unpacking.h @@ -1,10 +1,10 @@ #ifndef PERNIX_ARM64_NEON_UNPACKING_H #define PERNIX_ARM64_NEON_UNPACKING_H -#include +#include #include -using namespace pernix::arm64::internal; +using namespace pernix::arm64::neon::internal; namespace pernix::arm64::neon::internal::b128 { template @@ -45,7 +45,7 @@ __always_inline int8x16_t neon_unpack_epi8_1to8(const uint8x16_t& input) { template requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) -__always_inline int16x8_t neon_unpack_epi8_9to16(const uint16x8_t& input) { +__always_inline int16x8_t neon_unpack_epi16_9to16(const uint16x8_t& input) { if constexpr (BIT_WIDTH == 16) { return vreinterpretq_s16_u16(input); } else { @@ -78,7 +78,7 @@ __always_inline int16x8_t neon_unpack_epi8_9to16(const uint16x8_t& input) { template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) -__always_inline int32x4_t neon_unpack_epi8_17to24(const uint32x4_t& input) { +__always_inline int32x4_t neon_unpack_epi32_17to24(const uint32x4_t& input) { using tables = table_unpacking; const uint8x16_t input_8 = vreinterpretq_u8_u32(input); @@ -95,6 +95,6 @@ __always_inline int32x4_t neon_unpack_epi8_17to24(const uint32x4_t& input) { return vreinterpretq_s32_u32(vandq_u32(value, vdupq_n_u32(mask))); } } -} // namespace pernix::arm64::neon::internal::b128 +} // namespace pernix::arm64::neon::internal::b128 #endif // PERNIX_ARM64_NEON_UNPACKING_H diff --git a/include/pernix/arm64/sve2/decompression.h b/include/pernix/arm64/sve2/decompression.h index c27f08a..198a8fc 100644 --- a/include/pernix/arm64/sve2/decompression.h +++ b/include/pernix/arm64/sve2/decompression.h @@ -1,43 +1,352 @@ #ifndef PERNIX_ARM64_SVE2_DECOMPRESSION_H #define PERNIX_ARM64_SVE2_DECOMPRESSION_H +#include +#include #include +#include #include #include +#include -namespace pernix { +namespace pernix::arm64::sve2 { namespace internal { -template -inline constexpr bool sve2_decompression_unimplemented_v = false; -} // namespace internal +template +[[nodiscard]] __always_inline constexpr uint32_t packed_bytes(const uint32_t elements) { + return (elements * BIT_WIDTH + 7) / 8; +} + +[[nodiscard]] __always_inline svuint8_t sve2_load_packed_bytes(const uint8_t* __restrict__ input, const uint32_t bytes) { + const svbool_t pg = svwhilelt_b8(uint64_t{0}, static_cast(bytes)); + return svld1_u8(pg, input); +} + +template +__always_inline void sve2_store_dequantized_i8_f32(svint8_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + alignas(64) std::vector temp(svcntb()); + + svst1_s8(svptrue_b8(), temp.data(), values); + + uint32_t offset = 0; + while (offset < count) { + const svbool_t pg = svwhilelt_b32(static_cast(offset), static_cast(count)); + + svfloat32_t dequantized; + + if constexpr (SIGN_VALUES) { + const svint32_t widened = svld1sb_s32(pg, temp.data() + offset); + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); + } else { + const svuint32_t widened = svld1ub_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + } + + svst1_f32(pg, output + offset, dequantized); + + offset += static_cast(svcntw()); + } +} + +template +__always_inline void sve2_store_dequantized_i8_f64(svint8_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcntb()); + + svst1_s8(svptrue_b8(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template +__always_inline void sve2_store_dequantized_i16_f32(svint16_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + alignas(64) std::vector temp(svcnth()); + + svst1_s16(svptrue_b16(), temp.data(), values); + + uint32_t offset = 0; + while (offset < count) { + const svbool_t pg = svwhilelt_b32(static_cast(offset), static_cast(count)); + + svfloat32_t dequantized; + if constexpr (SIGN_VALUES) { + const svint32_t widened = svld1sh_s32(pg, temp.data() + offset); + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, widened), scale_v); + } else { + const svuint32_t widened = svld1uh_u32(pg, reinterpret_cast(temp.data() + offset)); + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, widened), scale_v); + } + + svst1_f32(pg, output + offset, dequantized); + + offset += static_cast(svcntw()); + } +} + +template +__always_inline void sve2_store_dequantized_i16_f64(svint16_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcnth()); + + svst1_s16(svptrue_b16(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + +template +__always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfloat32_t scale_v, float_t* __restrict__ output, + const uint32_t count) { + const svbool_t pg = svwhilelt_b32(uint64_t{0}, static_cast(count)); + + svfloat32_t dequantized; + if constexpr (SIGN_VALUES) { + dequantized = svmul_f32_x(pg, svcvt_f32_s32_x(pg, values), scale_v); + } else { + dequantized = svmul_f32_x(pg, svcvt_f32_u32_x(pg, svreinterpret_u32_s32(values)), scale_v); + } + + svst1_f32(pg, output, dequantized); +} template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_block(const uint8_t*, float_t, float_t*) { - static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntb()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint8_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint8_t spill_shift = svdup_n_u8(0); + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint8_t unpacked = sve2_unpack_epi8_1to8(source, permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i8_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcnth()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint16_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint16_t spill_shift = svdup_n_u16(0); + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint16_t unpacked = + sve2_unpack_epi16_9to16(svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i16_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntw()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const svfloat32_t scale_v = svdup_n_f32(scale); + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint8_t* chunk_input = input + input_bit_offset / 8; + const uint32_t bit_offset = input_bit_offset % 8; + const uint32_t bytes = (bit_offset + count * BIT_WIDTH + 7u) / 8u; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + svint32_t unpacked; + if (bit_offset == 0) { + unpacked = sve2_unpack_epi32_17to24(source); + } else { + unpacked = sve2_unpack_epi32_17to24(source); + } + + sve2_store_dequantized_i32_f32(unpacked, scale_v, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__, double_t, double_t* __restrict__) { return -1; } template - requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_block(const uint8_t*, double_t, double_t*) { - static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcnth()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint16_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint16_t spill_shift = svdup_n_u16(0); + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint16_t unpacked = + sve2_unpack_epi16_9to16(svreinterpret_u16_u8(source), permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i16_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__, double_t, double_t* __restrict__) { return -1; } +} // namespace internal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_blocks(const uint8_t*, float_t, float_t*, uint32_t) { - static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); - return -1; +int sve2_decompress_block(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve2_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve2_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve2_decompress_block_17to24(input, scale, output); + } } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -int sve2_decompress_blocks(const uint8_t*, double_t, double_t*, uint32_t) { - static_assert(internal::sve2_decompression_unimplemented_v, "ARM64 SVE2 decompression is not implemented yet"); - return -1; +int sve2_decompress_block(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + if constexpr (BIT_WIDTH >= 1 && BIT_WIDTH <= 8) { + return internal::sve2_decompress_block_1to8(input, scale, output); + } else if constexpr (BIT_WIDTH >= 9 && BIT_WIDTH <= 16) { + return internal::sve2_decompress_block_9to16(input, scale, output); + } else if constexpr (BIT_WIDTH >= 17 && BIT_WIDTH <= 24) { + return internal::sve2_decompress_block_17to24(input, scale, output); + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { + if constexpr (BIT_WIDTH > 8) { + return -1; + } else { + const uint8_t* block_input = input; + float_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } +} + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) +int sve2_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { + if constexpr (BIT_WIDTH > 8) { + return -1; + } else { + const uint8_t* block_input = input; + double_t* block_output = output; + + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; + } + + return 0; + } } #ifdef __cplusplus @@ -45,15 +354,18 @@ extern "C" { #endif int sve2_decompress_block(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output); + int sve2_decompress_block_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output); + int sve2_decompress_blocks(uint8_t bit_width, const uint8_t* __restrict__ input, float_t scale, float_t* __restrict__ output, uint32_t blocks); + int sve2_decompress_blocks_f64(uint8_t bit_width, const uint8_t* __restrict__ input, double_t scale, double_t* __restrict__ output, uint32_t blocks); #ifdef __cplusplus } #endif -} // namespace pernix +} // namespace pernix::arm64::sve2 #endif // PERNIX_ARM64_SVE2_DECOMPRESSION_H diff --git a/include/pernix/arm64/sve2/tables.h b/include/pernix/arm64/sve2/tables.h new file mode 100644 index 0000000..813e042 --- /dev/null +++ b/include/pernix/arm64/sve2/tables.h @@ -0,0 +1,143 @@ +#ifndef PERNIX_ARM64_SVE2_TABLES_H +#define PERNIX_ARM64_SVE2_TABLES_H + +#include + +#include +#include + +namespace pernix::arm64::sve2::internal { +template +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svbool_t pg_b8() { return svptrue_b8(); } + + static svbool_t pg_b16() { return svptrue_b16(); } + + static svbool_t pg_b32() { return svptrue_b32(); } +}; + +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast((lane * BIT_WIDTH) / 8u); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint8_t spill_permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast((lane * BIT_WIDTH) / 8u + 1u); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint8_t shift() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast((lane * BIT_WIDTH) % 8u); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint8_t spill_shift() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast(8u - ((lane * BIT_WIDTH) % 8u)); + } + + return svld1_u8(svptrue_b8(), table.data()); + } +}; + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + const uint32_t element = lane / 2u; + const uint32_t byte = lane % 2u; + const uint32_t first = (element * BIT_WIDTH) / 8u; + + table[lane] = static_cast(first + byte); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint8_t spill_permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + const uint32_t element = lane / 2u; + const uint32_t byte = lane % 2u; + const uint32_t first = (element * BIT_WIDTH) / 8u; + + table[lane] = static_cast(first + 2u + byte); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint16_t shift() { + std::vector table(svcnth()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = static_cast((lane * BIT_WIDTH) % 8u); + } + + return svld1_u16(svptrue_b16(), table.data()); + } + + static svuint16_t spill_shift() { + std::vector table(svcnth()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + const uint32_t bit_offset = (lane * BIT_WIDTH) % 8u; + table[lane] = bit_offset + BIT_WIDTH > 16u ? static_cast(16u - bit_offset) : uint16_t{16}; + } + + return svld1_u16(svptrue_b16(), table.data()); + } +}; + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24 && START_BIT_OFFSET < 8) +struct table_unpacking { + static constexpr uint8_t bit_width = BIT_WIDTH; + + static svuint8_t permute() { + std::vector table(svcntb()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + const uint32_t element = lane / 4u; + const uint32_t byte = lane % 4u; + const uint32_t first = (START_BIT_OFFSET + element * BIT_WIDTH) / 8u; + + table[lane] = static_cast(first + byte); + } + + return svld1_u8(svptrue_b8(), table.data()); + } + + static svuint32_t shift() { + std::vector table(svcntw()); + for (uint32_t lane = 0; lane < table.size(); ++lane) { + table[lane] = (START_BIT_OFFSET + lane * BIT_WIDTH) % 8u; + } + + return svld1_u32(svptrue_b32(), table.data()); + } +}; +} // namespace pernix::arm64::sve2::internal + +#endif // PERNIX_ARM64_SVE2_TABLES_H diff --git a/include/pernix/arm64/sve2/unpacking.h b/include/pernix/arm64/sve2/unpacking.h index d654b5e..326901f 100644 --- a/include/pernix/arm64/sve2/unpacking.h +++ b/include/pernix/arm64/sve2/unpacking.h @@ -3,9 +3,90 @@ #include +#include "tables.h" + namespace pernix::arm64::sve2::internal { -template -inline constexpr bool unpacking_unimplemented_v = false; -} // namespace pernix::arm64::sve2::internal +template + requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) +__always_inline svint8_t sve2_unpack_epi8_1to8(const svuint8_t input, const svuint8_t permute, const svuint8_t shift, + const svuint8_t spill_permute, const svuint8_t spill_shift) { + if constexpr (BIT_WIDTH == 8) { + return svreinterpret_s8(input); + } else { + const svbool_t pg = svptrue_b8(); + + const svuint8_t permuted = svtbl_u8(input, permute); + svuint8_t unpacked = svlsr_u8_x(pg, permuted, shift); + + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + const svuint8_t spill_permuted_values = svtbl_u8(input, spill_permute); + const svuint8_t spill_shifted = svlsl_u8_x(pg, spill_permuted_values, spill_shift); + unpacked = svorr_u8_x(pg, unpacked, spill_shifted); + } + + if constexpr (BIT_WIDTH == 1) { + unpacked = svand_n_u8_x(pg, unpacked, 1); + return svreinterpret_s8(unpacked); + } else { + constexpr int sign_shift = 8 - BIT_WIDTH; + + unpacked = svlsl_n_u8_x(pg, unpacked, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s8_x(pg, svreinterpret_s8_u8(unpacked), sign_shift); + } else { + return svreinterpret_s8_u8(svlsr_n_u8_x(pg, unpacked, sign_shift)); + } + } + } +} + +template + requires(BIT_WIDTH >= 9 && BIT_WIDTH <= 16) +__always_inline svint16_t sve2_unpack_epi16_9to16(const svuint16_t input, const svuint8_t permute, const svuint16_t shift, + const svuint8_t spill_permute, const svuint16_t spill_shift) { + if constexpr (BIT_WIDTH == 16) { + return svreinterpret_s16(input); + } else { + const svbool_t pg = svptrue_b16(); + + const svuint8_t permuted = svtbl_u8(svreinterpret_u8_u16(input), permute); + svuint16_t shifted = svlsr_u16_x(pg, svreinterpret_u16_u8(permuted), shift); + + if constexpr (BIT_WIDTH == 11 || BIT_WIDTH == 13 || BIT_WIDTH == 14 || BIT_WIDTH == 15) { + const svuint8_t spill_permuted_values = svtbl_u8(svreinterpret_u8_u16(input), spill_permute); + const svuint16_t spill_shifted = svlsl_u16_x(pg, svreinterpret_u16_u8(spill_permuted_values), spill_shift); + shifted = svorr_u16_x(pg, shifted, spill_shifted); + } + + constexpr int sign_shift = 16 - BIT_WIDTH; + shifted = svlsl_n_u16_x(pg, shifted, sign_shift); + + if constexpr (SIGN_VALUES) { + return svasr_n_s16_x(pg, svreinterpret_s16_u16(shifted), sign_shift); + } else { + return svreinterpret_s16_u16(svlsr_n_u16_x(pg, shifted, sign_shift)); + } + } +} + +template + requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) +__always_inline svint32_t sve2_unpack_epi32_17to24(const svuint8_t input) { + using table = table_unpacking; + + const svbool_t pg = svptrue_b32(); + const svuint8_t permuted = svtbl_u8(input, table::permute()); + const svuint32_t unpacked = svlsr_u32_x(pg, svreinterpret_u32_u8(permuted), table::shift()); + + if constexpr (SIGN_VALUES) { + constexpr int sign_shift = 32 - BIT_WIDTH; + return svasr_n_s32_x(pg, svreinterpret_s32_u32(svlsl_n_u32_x(pg, unpacked, sign_shift)), sign_shift); + } else { + constexpr uint32_t mask = (uint32_t{1} << BIT_WIDTH) - 1u; + return svreinterpret_s32_u32(svand_n_u32_x(pg, unpacked, mask)); + } +} +} // namespace pernix::arm64::sve2::internal #endif // PERNIX_ARM64_SVE2_UNPACKING_H diff --git a/include/pernix/simd_compat.h b/include/pernix/simd_compat.h index 07d5110..509eb13 100644 --- a/include/pernix/simd_compat.h +++ b/include/pernix/simd_compat.h @@ -37,8 +37,13 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) #include #elif defined(__aarch64__) || defined(__arm64ec__) +#ifdef __ARM_FEATURE_SVE +#include +#endif +#ifdef __ARM_NEON #include #endif +#endif #ifndef __always_inline #if defined(__GNUC__) || defined(__clang__) diff --git a/src/arm64/neon/decompression.cpp b/src/arm64/neon/decompression.cpp index 3cb7fc7..a89f763 100644 --- a/src/arm64/neon/decompression.cpp +++ b/src/arm64/neon/decompression.cpp @@ -2,20 +2,142 @@ namespace pernix { extern "C" { -int neon_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { - return -1; +#define PERNIX_NEON_DECOMPRESS_BLOCK_CASE(N) \ + case N: \ + return arm64::neon::neon_decompress_block(input, scale, output); + +#define PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(N) \ + case N: \ + return arm64::neon::neon_decompress_blocks(input, scale, output, blocks); + +int neon_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(24) + default: + return -1; + } } -int neon_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { - return -1; +int neon_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCK_CASE(24) + default: + return -1; + } } -int neon_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { - return -1; +int neon_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(24) + default: + return -1; + } } -int neon_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { - return -1; +int neon_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output, const uint32_t blocks) { + switch (bit_width) { + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_NEON_DECOMPRESS_BLOCKS_CASE(24) + default: + return -1; + } } + +#undef PERNIX_NEON_DECOMPRESS_BLOCK_CASE +#undef PERNIX_NEON_DECOMPRESS_BLOCKS_CASE } -} // namespace pernix +} // namespace pernix diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp index 8d170f6..66bbd97 100644 --- a/src/arm64/sve2/decompression.cpp +++ b/src/arm64/sve2/decompression.cpp @@ -2,20 +2,78 @@ namespace pernix { extern "C" { -int sve2_decompress_block(uint8_t, const uint8_t*, float_t, float_t*) { - return -1; +#define PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(N) \ + case N: \ + return arm64::sve2::sve2_decompress_block(input, scale, output); + +#define PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(N) \ + case N: \ + return arm64::sve2::sve2_decompress_blocks(input, scale, output, blocks); + +int sve2_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + default: + return -1; + } } -int sve2_decompress_block_f64(uint8_t, const uint8_t*, double_t, double_t*) { - return -1; +int sve2_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + default: + return -1; + } } -int sve2_decompress_blocks(uint8_t, const uint8_t*, float_t, float_t*, uint32_t) { - return -1; +int sve2_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, + const uint32_t blocks) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + default: + return -1; + } } -int sve2_decompress_blocks_f64(uint8_t, const uint8_t*, double_t, double_t*, uint32_t) { - return -1; +int sve2_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restrict__ input, const double_t scale, + double_t* __restrict__ output, const uint32_t blocks) { + switch (bit_width) { + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(1) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(2) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(3) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(4) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(5) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + default: + return -1; + } } + +#undef PERNIX_SVE2_DECOMPRESS_BLOCK_CASE +#undef PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE } -} // namespace pernix +} // namespace pernix diff --git a/tests/arm64/sve2/.gitkeep b/tests/arm64/sve2/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/tests/arm64/sve2/decompression_tests.cpp b/tests/arm64/sve2/decompression_tests.cpp new file mode 100644 index 0000000..82cb11f --- /dev/null +++ b/tests/arm64/sve2/decompression_tests.cpp @@ -0,0 +1,38 @@ +#include +#include + +#ifdef PERNIX_BACKEND_ARM64_SVE2 + +using namespace pernix::arm64::sve2; + +TYPED_TEST(DecompressionTest, SVE2DecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + sve2_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +TYPED_TEST(DecompressionTest64, SVE2DecompressBlock) { + std::vector > decompressedData(this->testSet.numberOfBlocks); + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + decompressedData[block].resize(this->testSet.elementsPerBlock); + + sve2_decompress_block( + this->testSet.getCompressedData()[block].data(), this->testSet.getScales()[block], decompressedData[block].data()); + } + + for (uint32_t block = 0; block < this->testSet.numberOfBlocks; block++) { + expectDecompressedBlockNearSource(*this, decompressedData[block], block); + } +} + +#endif \ No newline at end of file From 0f0c1782b66c6687948714a6a53f6183b633d972 Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Thu, 28 May 2026 21:32:10 +0200 Subject: [PATCH 11/12] WIP: Implement SVE2 decompression functions --- include/pernix/arm64/sve2/decompression.h | 122 +++++++++++++++++----- 1 file changed, 94 insertions(+), 28 deletions(-) diff --git a/include/pernix/arm64/sve2/decompression.h b/include/pernix/arm64/sve2/decompression.h index 198a8fc..2128ff2 100644 --- a/include/pernix/arm64/sve2/decompression.h +++ b/include/pernix/arm64/sve2/decompression.h @@ -122,6 +122,22 @@ __always_inline void sve2_store_dequantized_i32_f32(svint32_t values, const svfl svst1_f32(pg, output, dequantized); } +template +__always_inline void sve2_store_dequantized_i32_f64(svint32_t values, const double_t scale, double_t* __restrict__ output, + const uint32_t count) { + std::vector temp(svcntw()); + + svst1_s32(svptrue_b32(), temp.data(), values); + + for (uint32_t i = 0; i < count; ++i) { + if constexpr (SIGN_VALUES) { + output[i] = static_cast(temp[i]) * scale; + } else { + output[i] = static_cast(static_cast(temp[i])) * scale; + } + } +} + template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) __always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output) { @@ -238,8 +254,39 @@ __always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ inp template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 8) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__, double_t, double_t* __restrict__) { - return -1; +__always_inline int sve2_decompress_block_1to8(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntb()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + const table_unpacking table; + const svuint8_t permute = table.permute(); + const svuint8_t shift = table.shift(); + svuint8_t spill_permute = svdup_n_u8(0); + svuint8_t spill_shift = svdup_n_u8(0); + if constexpr (BIT_WIDTH == 3 || BIT_WIDTH == 5 || BIT_WIDTH == 6 || BIT_WIDTH == 7) { + spill_permute = table.spill_permute(); + spill_shift = table.spill_shift(); + } + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint32_t bytes = packed_bytes(count); + const uint8_t* chunk_input = input + input_bit_offset / 8; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + const svint8_t unpacked = sve2_unpack_epi8_1to8(source, permute, shift, spill_permute, spill_shift); + + sve2_store_dequantized_i8_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; } template @@ -282,8 +329,35 @@ __always_inline int sve2_decompress_block_9to16(const uint8_t* __restrict__ inpu template requires(BIT_WIDTH >= 17 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) -__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__, double_t, double_t* __restrict__) { - return -1; +__always_inline int sve2_decompress_block_17to24(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output) { + constexpr uint32_t elements_per_block = (BLOCK_SIZE * 8) / BIT_WIDTH; + + const auto lanes = static_cast(svcntw()); + uint32_t input_bit_offset = 0; + uint32_t processed_elements = 0; + + while (processed_elements < elements_per_block) { + const uint32_t count = std::min(elements_per_block - processed_elements, lanes); + + const uint8_t* chunk_input = input + input_bit_offset / 8; + const uint32_t bit_offset = input_bit_offset % 8; + const uint32_t bytes = (bit_offset + count * BIT_WIDTH + 7u) / 8u; + + const svuint8_t source = sve2_load_packed_bytes(chunk_input, bytes); + svint32_t unpacked; + if (bit_offset == 0) { + unpacked = sve2_unpack_epi32_17to24(source); + } else { + unpacked = sve2_unpack_epi32_17to24(source); + } + + sve2_store_dequantized_i32_f64(unpacked, scale, output + processed_elements, count); + + processed_elements += count; + input_bit_offset += count * BIT_WIDTH; + } + + return 0; } } // namespace internal @@ -314,39 +388,31 @@ int sve2_decompress_block(const uint8_t* __restrict__ input, const double_t scal template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int sve2_decompress_blocks(const uint8_t* __restrict__ input, const float_t scale, float_t* __restrict__ output, const uint32_t blocks) { - if constexpr (BIT_WIDTH > 8) { - return -1; - } else { - const uint8_t* block_input = input; - float_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - sve2_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + const uint8_t* block_input = input; + float_t* block_output = output; - return 0; + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } + + return 0; } template requires(BIT_WIDTH >= 1 && BIT_WIDTH <= 24) && (BLOCK_SIZE % 32 == 0) int sve2_decompress_blocks(const uint8_t* __restrict__ input, const double_t scale, double_t* __restrict__ output, const uint32_t blocks) { - if constexpr (BIT_WIDTH > 8) { - return -1; - } else { - const uint8_t* block_input = input; - double_t* block_output = output; - - for (uint32_t block = 0; block < blocks; ++block) { - sve2_decompress_block(block_input, scale, block_output); - block_input += BLOCK_SIZE; - block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; - } + const uint8_t* block_input = input; + double_t* block_output = output; - return 0; + for (uint32_t block = 0; block < blocks; ++block) { + sve2_decompress_block(block_input, scale, block_output); + block_input += BLOCK_SIZE; + block_output += (BLOCK_SIZE * 8) / BIT_WIDTH; } + + return 0; } #ifdef __cplusplus From 95c6475d8092f96dd3417c376bcda8e03ba9e94e Mon Sep 17 00:00:00 2001 From: Felix Sternberg Date: Thu, 28 May 2026 21:59:45 +0200 Subject: [PATCH 12/12] WIP: Extend SVE2 decompression functions --- include/pernix/arm64/sve2/tables.h | 115 +++++++++++------------------ src/arm64/sve2/decompression.cpp | 64 ++++++++++++++++ 2 files changed, 109 insertions(+), 70 deletions(-) diff --git a/include/pernix/arm64/sve2/tables.h b/include/pernix/arm64/sve2/tables.h index 813e042..897fa9b 100644 --- a/include/pernix/arm64/sve2/tables.h +++ b/include/pernix/arm64/sve2/tables.h @@ -4,7 +4,6 @@ #include #include -#include namespace pernix::arm64::sve2::internal { template @@ -24,39 +23,23 @@ struct table_unpacking { static constexpr uint8_t bit_width = BIT_WIDTH; static svuint8_t permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast((lane * BIT_WIDTH) / 8u); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svlsr_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 3); } static svuint8_t spill_permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast((lane * BIT_WIDTH) / 8u + 1u); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 1); } static svuint8_t shift() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast((lane * BIT_WIDTH) % 8u); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svand_n_u8_x(pg, svindex_u8(0, BIT_WIDTH), 7); } static svuint8_t spill_shift() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast(8u - ((lane * BIT_WIDTH) % 8u)); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svsub_u8_x(pg, svdup_n_u8(8), shift()); } }; @@ -66,48 +49,39 @@ struct table_unpacking { static constexpr uint8_t bit_width = BIT_WIDTH; static svuint8_t permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - const uint32_t element = lane / 2u; - const uint32_t byte = lane % 2u; - const uint32_t first = (element * BIT_WIDTH) / 8u; - - table[lane] = static_cast(first + byte); + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 1); + const svuint8_t byte = svand_n_u8_x(pg, lane, 1); + + svuint8_t first; + if constexpr (BIT_WIDTH == 16) { + first = svlsl_n_u8_x(pg, elem, 1); + } else { + constexpr uint8_t extra_bits = BIT_WIDTH - 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low = svlsr_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), 3); + first = svadd_u8_x(pg, elem, svadd_u8_x(pg, high, low)); } - return svld1_u8(svptrue_b8(), table.data()); + return svadd_u8_x(pg, first, byte); } static svuint8_t spill_permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - const uint32_t element = lane / 2u; - const uint32_t byte = lane % 2u; - const uint32_t first = (element * BIT_WIDTH) / 8u; - - table[lane] = static_cast(first + 2u + byte); - } - - return svld1_u8(svptrue_b8(), table.data()); + const svbool_t pg = svptrue_b8(); + return svadd_n_u8_x(pg, permute(), 2); } static svuint16_t shift() { - std::vector table(svcnth()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = static_cast((lane * BIT_WIDTH) % 8u); - } - - return svld1_u16(svptrue_b16(), table.data()); + const svbool_t pg = svptrue_b16(); + return svand_n_u16_x(pg, svmul_n_u16_x(pg, svindex_u16(0, 1), BIT_WIDTH), 7); } static svuint16_t spill_shift() { - std::vector table(svcnth()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - const uint32_t bit_offset = (lane * BIT_WIDTH) % 8u; - table[lane] = bit_offset + BIT_WIDTH > 16u ? static_cast(16u - bit_offset) : uint16_t{16}; - } - - return svld1_u16(svptrue_b16(), table.data()); + const svbool_t pg = svptrue_b16(); + const svuint16_t bit_shift = shift(); + const svuint16_t spill = svsub_u16_x(pg, svdup_n_u16(16), bit_shift); + return svsel_u16(svcmpgt_n_u16(pg, bit_shift, 16u - BIT_WIDTH), spill, svdup_n_u16(16)); } }; @@ -117,25 +91,26 @@ struct table_unpacking { static constexpr uint8_t bit_width = BIT_WIDTH; static svuint8_t permute() { - std::vector table(svcntb()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - const uint32_t element = lane / 4u; - const uint32_t byte = lane % 4u; - const uint32_t first = (START_BIT_OFFSET + element * BIT_WIDTH) / 8u; - - table[lane] = static_cast(first + byte); + const svbool_t pg = svptrue_b8(); + const svuint8_t lane = svindex_u8(0, 1); + const svuint8_t elem = svlsr_n_u8_x(pg, lane, 2); + const svuint8_t byte = svand_n_u8_x(pg, lane, 3); + + svuint8_t first = svmul_n_u8_x(pg, elem, BIT_WIDTH / 8u); + if constexpr (BIT_WIDTH % 8u != 0) { + constexpr uint8_t extra_bits = BIT_WIDTH % 8u; + const svuint8_t high = svmul_n_u8_x(pg, svlsr_n_u8_x(pg, elem, 3), extra_bits); + const svuint8_t low_bits = + svadd_n_u8_x(pg, svmul_n_u8_x(pg, svand_n_u8_x(pg, elem, 7), extra_bits), START_BIT_OFFSET); + first = svadd_u8_x(pg, first, svadd_u8_x(pg, high, svlsr_n_u8_x(pg, low_bits, 3))); } - return svld1_u8(svptrue_b8(), table.data()); + return svadd_u8_x(pg, first, byte); } static svuint32_t shift() { - std::vector table(svcntw()); - for (uint32_t lane = 0; lane < table.size(); ++lane) { - table[lane] = (START_BIT_OFFSET + lane * BIT_WIDTH) % 8u; - } - - return svld1_u32(svptrue_b32(), table.data()); + const svbool_t pg = svptrue_b32(); + return svand_n_u32_x(pg, svadd_n_u32_x(pg, svmul_n_u32_x(pg, svindex_u32(0, 1), BIT_WIDTH), START_BIT_OFFSET), 7); } }; } // namespace pernix::arm64::sve2::internal diff --git a/src/arm64/sve2/decompression.cpp b/src/arm64/sve2/decompression.cpp index 66bbd97..8429e97 100644 --- a/src/arm64/sve2/decompression.cpp +++ b/src/arm64/sve2/decompression.cpp @@ -20,6 +20,22 @@ int sve2_decompress_block(const uint8_t bit_width, const uint8_t* __restrict__ i PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(24) default: return -1; } @@ -36,6 +52,22 @@ int sve2_decompress_block_f64(const uint8_t bit_width, const uint8_t* __restrict PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(6) PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(7) PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCK_CASE(24) default: return -1; } @@ -52,6 +84,22 @@ int sve2_decompress_blocks(const uint8_t bit_width, const uint8_t* __restrict__ PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(24) default: return -1; } @@ -68,6 +116,22 @@ int sve2_decompress_blocks_f64(const uint8_t bit_width, const uint8_t* __restric PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(6) PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(7) PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(8) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(9) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(10) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(11) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(12) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(13) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(14) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(15) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(16) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(17) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(18) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(19) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(20) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(21) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(22) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(23) + PERNIX_SVE2_DECOMPRESS_BLOCKS_CASE(24) default: return -1; }