diff --git a/primitive_data/extensions/simd/intel/avx512.yaml b/primitive_data/extensions/simd/intel/avx512.yaml index b5e71e8b..e735624c 100644 --- a/primitive_data/extensions/simd/intel/avx512.yaml +++ b/primitive_data/extensions/simd/intel/avx512.yaml @@ -3,7 +3,22 @@ description: "Definition of the SIMD TargetExtension avx512." vendor: "intel" extension_name: "avx512" #todo: these are not all lscpu flags but a sample -lscpu_flags: ["avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512er", "avx512vl", "avx512vbmi", "avx512vbmi2"] +lscpu_flags: + - "avx512f" + - "avx512bw" + - "avx512cd" + - "avx512dq" + - "avx512er" + - "avx512vl" + - "avx512vbmi" + - "avx512_vbmi2" + - "avx512_fp16" + - "avx512_vpopcntdq" +arch_flags: + "avx512_fp16": "avx512fp16" + "avx512_vpopcntdq": "avx512vpopcntdq" + "avx512_vbmi2": "avx512vbmi2" + includes: ['"immintrin.h"'] simdT_name: "avx512" simdT_default_size_in_bits: 512 diff --git a/primitive_data/primitives/binary.yaml b/primitive_data/primitives/binary.yaml index 787253cf..576b96b2 100644 --- a/primitive_data/primitives/binary.yaml +++ b/primitive_data/primitives/binary.yaml @@ -447,20 +447,52 @@ definitions: ctype: ["uint32_t", "uint64_t", "int32_t", "int64_t"] lscpu_flags: ["avx512f"] implementation: "return _mm512_slli_epi{{ intrin_tp[ctype][1] }}(data, shift);" + - target_extension: "avx512" + ctype: ["uint8_t", "int8_t", "uint16_t", "int16_t"] + lscpu_flags: ["avx512f"] + implementation: | + return _mm512_and_si512( + _mm512_slli_epi32(data, shift), + _mm512_set1_epi{{ intrin_tp[ctype][1] }}((-1ul)<(data); + auto shift_arr = tsl::to_array(shift); + for (size_t i = 0; i < Vec::vector_element_count(); ++i) { + data_arr[i] <<= shift_arr[i]; + } + return tsl::load(data_arr.data()); + /*auto const zero = _mm_setzero_si128(); + auto const mask0 = _mm_insert_epi64(zero, 0, 0xFFFFFFFFFFFFFFFF); + auto const mask1 = _mm_slli_si128(mask0, 8); + auto const shift1 = _mm_srli_si128(shift, 8); + auto const r0 = + _mm_and_si128( + _mm_sll_epi64(data, shift), + mask0 + ); + auto const r1 = + _mm_and_si128( + _mm_sll_epi64(data, shift1), + mask1 + ); + return _mm_or_si128(r0, r1);*/ - target_extension: "sse" ctype: ["uint32_t", "uint64_t", "int32_t", "int64_t"] - lscpu_flags: ["avx2"] + lscpu_flags: ["sse2", "avx2"] + implementation: "return _mm_sllv_epi{{ intrin_tp[ctype][1] }}(data, shift);" + - target_extension: "sse" + ctype: ["uint16_t", "int16_t"] + lscpu_flags: ["sse2", "avx512bw", "avx512vl"] implementation: "return _mm_sllv_epi{{ intrin_tp[ctype][1] }}(data, shift);" #ARM - NEON - target_extension: "neon" @@ -652,7 +877,99 @@ definitions: } - target_extension: "avx512" ctype: ["uint16_t", "int16_t"] - lscpu_flags: ["avx512bw"] + lscpu_flags: ["avx512f"] + includes: [""] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + __m512i const result1_16 = _mm512_and_si512(_mm512_srai_epi32(data, shift), _mm512_set1_epi32(0xFFFF0000)); + __m512i const result0_16 = _mm512_srli_epi32(_mm512_srai_epi32(_mm512_slli_epi32(data, 16), shift), 16); + return _mm512_or_si512(result0_16, result1_16); + } else { + __m512i const result0_16 = _mm512_srli_epi32(_mm512_and_si512(data, _mm512_set1_epi32(0xFFFF)), shift); + __m512i const result1_16 = _mm512_and_si512(_mm512_srli_epi32(data, shift), _mm512_set1_epi32(0xFFFF0000)); + return _mm512_or_si512(result0_16, result1_16); + } + alternative: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + // extract the lower and upper 256 bit + __m256i const lower256_16 = _mm512_extracti32x8_epi32(data, 
0); + __m256i const upper256_16 = _mm512_extracti32x8_epi32(data, 1); + // Sign extend packed 16-bit integers in a to packed 32-bit integers + __m512i const lower256_32 = _mm512_cvtepi16_epi32(lower256_16); + __m512i const upper256_32 = _mm512_cvtepi16_epi32(upper256_16); + // Shift the 32-bit integers to the right while preserving the sign + __m512i const lower256_32_shifted = _mm512_srai_epi32(lower256_32, shift); + __m512i const upper256_32_shifted = _mm512_srai_epi32(upper256_32, shift); + // Convert packed 32-bit integers in a to packed 16-bit integers with truncation + __m256i const lower256_16_shifted = _mm512_cvtepi32_epi16(lower256_32_shifted); + __m256i const upper256_16_shifted = _mm512_cvtepi32_epi16(upper256_32_shifted); + // Merge results and return + __m512i const result = _mm512_zextsi256_si512(lower256_16_shifted); + return _mm512_inserti32x8(result, upper256_16_shifted, 1); + } else { + return _mm512_and_si512( + _mm512_srli_epi32(data, shift), + _mm512_set1_epi16((static_cast(-1))>>shift) + ); + } + - target_extension: "avx512" + ctype: ["uint8_t", "int8_t"] + lscpu_flags: ["avx512f"] + includes: [""] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + __m512i const valid_bits = _mm512_set1_epi32(0xFF); + __m512i const result0_8 = _mm512_srli_epi32(_mm512_srai_epi32(_mm512_slli_epi32(data, 24), shift), 24); + __m512i const result1_8 = _mm512_and_si512(_mm512_srli_epi32(_mm512_srai_epi32(_mm512_slli_epi32(data, 16), shift), 16), _mm512_slli_epi32(valid_bits, 8)); + __m512i const result2_8 = _mm512_and_si512(_mm512_srli_epi32(_mm512_srai_epi32(_mm512_slli_epi32(data, 8), shift), 8), _mm512_slli_epi32(valid_bits, 16)); + __m512i const result3_8 = _mm512_and_si512(_mm512_srai_epi32(data, shift), _mm512_slli_epi32(valid_bits, 24)); + return _mm512_or_si512(_mm512_or_si512(result0_8, result1_8), _mm512_or_si512(result2_8, result3_8)); + } else { + return _mm512_and_si512( + _mm512_srli_epi32(data, shift), + _mm512_set1_epi8((static_cast(-1))>>shift) + ); + } + alternative: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + // extract the lower and upper 128 bit + __m128i const data0_128_8 = _mm512_extracti32x4_epi32(data, 0); + __m128i const data1_128_8 = _mm512_extracti32x4_epi32(data, 1); + __m128i const data2_128_8 = _mm512_extracti32x4_epi32(data, 2); + __m128i const data3_128_8 = _mm512_extracti32x4_epi32(data, 3); + + // Sign extend packed 8-bit integers in a to packed 32-bit integers + __m512i const data0_128_32 = _mm512_cvtepi8_epi32(data0_128_8); + __m512i const data1_128_32 = _mm512_cvtepi8_epi32(data1_128_8); + __m512i const data2_128_32 = _mm512_cvtepi8_epi32(data2_128_8); + __m512i const data3_128_32 = _mm512_cvtepi8_epi32(data3_128_8); + + // Shift the 32-bit integers to the right while preserving the sign + __m512i const data0_128_32_shifted = _mm512_srai_epi32(data0_128_32, shift); + __m512i const data1_128_32_shifted = _mm512_srai_epi32(data1_128_32, shift); + __m512i const data2_128_32_shifted = _mm512_srai_epi32(data2_128_32, shift); + __m512i const data3_128_32_shifted = _mm512_srai_epi32(data3_128_32, shift); + + // Convert packed 32-bit integers in a to packed 16-bit integers with truncation + __m128i const data0_128_8_shifted = _mm512_cvtepi32_epi8(data0_128_32_shifted); + __m128i const data1_128_8_shifted = _mm512_cvtepi32_epi8(data1_128_32_shifted); + __m128i const data2_128_8_shifted = _mm512_cvtepi32_epi8(data2_128_32_shifted); + __m128i const data3_128_8_shifted = _mm512_cvtepi32_epi8(data3_128_32_shifted); + + // 
Merge partial results + __m512i result = _mm512_zextsi128_si512(data0_128_8_shifted); + result = _mm512_inserti32x4(result, data1_128_8_shifted, 1); + result = _mm512_inserti32x4(result, data2_128_8_shifted, 2); + return _mm512_inserti32x4(result, data3_128_8_shifted, 3); + } else { + return _mm512_and_si512( + _mm512_srli_epi32(data, shift), + _mm512_set1_epi8((static_cast(-1))>>shift) + ); + } + - target_extension: "avx512" + ctype: ["uint16_t", "int16_t"] + lscpu_flags: ["avx512f", "avx512bw"] implementation: | if constexpr ((std::is_signed_v) && (PreserveSign)) { return _mm512_srai_epi{{ intrin_tp[ctype][1] }}(data, shift); @@ -661,50 +978,70 @@ definitions: } #Intel - AVX2 - target_extension: "avx2" - ctype: ["uint16_t", "uint32_t", "int16_t", "int32_t", "uint64_t"] + ctype: ["int64_t"] lscpu_flags: ["avx2"] implementation: | - if constexpr ((std::is_signed_v) && (PreserveSign)) { - return _mm256_srai_epi{{ intrin_tp[ctype][1] }}(data, shift); + if constexpr(PreserveSign) { + __m256i result = _mm256_srli_epi64(data, shift); + // Create a mask for the sign bit + __m256i sign_mask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), data); // Get the sign bit + sign_mask = _mm256_slli_epi64(sign_mask, 64 - shift); // Shift the sign bit mask + // Combine with the sign mask to achieve arithmetic shift + result = _mm256_or_si256(result, sign_mask); + return result; } else { - return _mm256_srli_epi{{ intrin_tp[ctype][1] }}(data, shift); + return _mm256_srli_epi64(data, shift); } + alternative: + - | + if constexpr(PreserveSign) { + auto const shifted = _mm256_srli_epi64(data, shift); + auto const msb_as_lsb = _mm256_srli_epi64(data, sizeof(int64_t)*CHAR_BIT-1); + auto const lsb_shifted = _mm256_slli_epi64(msb_as_lsb, shift); + auto const lsb_all_set = _mm256_sub_epi64(lsb_shifted, _mm256_set1_epi64x(1)); + auto const lsb_mask = _mm256_sub_epi64(_mm256_srli_epi64(data, sizeof(int64_t)*CHAR_BIT-1), _mm256_set1_epi64x(1)); + auto const result_msb_as_lsb = _mm256_andnot_si256(lsb_mask, lsb_all_set); + auto const result_msb = _mm256_slli_epi64(result_msb_as_lsb, sizeof(int64_t)*CHAR_BIT - shift); + return _mm256_or_si256(shifted, result_msb); + } else { + return _mm256_srli_epi{{ intrin_tp[ctype][1] }}(data, shift); + } - target_extension: "avx2" ctype: ["int64_t"] lscpu_flags: ["avx2", "avx512f", "avx512vl"] implementation: | if constexpr(PreserveSign) { - return _mm256_srai_epi{{ intrin_tp[ctype][1] }}(data, shift); + return _mm256_srai_epi64(data, shift); } else { - return _mm256_srli_epi{{ intrin_tp[ctype][1] }}(data, shift); + return _mm256_srli_epi64(data, shift); } - target_extension: "avx2" - ctype: ["int64_t"] + ctype: ["uint16_t", "uint32_t", "uint64_t", "int16_t", "int32_t"] lscpu_flags: ["avx2"] - includes: [""] implementation: | - if constexpr(PreserveSign) { - auto const shifted = _mm256_srli_epi64(data, shift); - auto const msb_as_lsb = _mm256_srli_epi64(data, sizeof(int64_t)*CHAR_BIT-1); - auto const lsb_shifted = _mm256_slli_epi64(msb_as_lsb, shift); - auto const lsb_all_set = _mm256_sub_epi64(lsb_shifted, _mm256_set1_epi64x(1)); - auto const lsb_mask = _mm256_sub_epi64(_mm256_srli_epi64(data, sizeof(int64_t)*CHAR_BIT-1), _mm256_set1_epi64x(1)); - auto const result_msb_as_lsb = _mm256_andnot_si256(lsb_mask, lsb_all_set); - auto const result_msb = _mm256_slli_epi64(result_msb_as_lsb, sizeof(int64_t)*CHAR_BIT - shift); - return _mm256_or_si256(shifted, result_msb); + if constexpr ((std::is_signed_v) && (PreserveSign)) { + return _mm256_srai_epi{{ intrin_tp[ctype][1] }}(data, shift); } 
else { return _mm256_srli_epi{{ intrin_tp[ctype][1] }}(data, shift); } - #Intel - SSE - - target_extension: "sse" - ctype: ["uint16_t", "uint32_t", "int16_t", "int32_t", "uint64_t"] - lscpu_flags: ["sse2"] + - target_extension: "avx2" + ctype: ["uint8_t", "int8_t"] + lscpu_flags: ["avx2"] implementation: | if constexpr ((std::is_signed_v) && (PreserveSign)) { - return _mm_srai_epi{{ intrin_tp[ctype][1] }}(data, shift); + __m256i const valid_bits = _mm256_set1_epi32(0xFF); + __m256i const result0_8 = _mm256_srli_epi32(_mm256_srai_epi32(_mm256_slli_epi32(data, 24), shift), 24); + __m256i const result1_8 = _mm256_and_si256( _mm256_srli_epi32(_mm256_srai_epi32(_mm256_slli_epi32(data, 16), shift), 16), _mm256_slli_epi32(valid_bits, 8)); + __m256i const result2_8 = _mm256_and_si256( _mm256_srli_epi32(_mm256_srai_epi32(_mm256_slli_epi32(data, 8), shift), 8), _mm256_slli_epi32(valid_bits, 16)); + __m256i const result3_8 = _mm256_and_si256( _mm256_srai_epi32(data, shift), _mm256_slli_epi32(valid_bits, 24)); + return _mm256_or_si256(_mm256_or_si256(result0_8, result1_8), _mm256_or_si256(result2_8, result3_8)); } else { - return _mm_srli_epi{{ intrin_tp[ctype][1] }}(data, shift); + return _mm256_and_si256( + _mm256_srli_epi32(data, shift), + _mm256_set1_epi8((static_cast(-1))>>shift) + ); } + #Intel - SSE - target_extension: "sse" ctype: ["int64_t"] lscpu_flags: ["sse2", "avx512f", "avx512vl"] @@ -718,6 +1055,18 @@ definitions: ctype: ["int64_t"] lscpu_flags: ["sse2", "avx2"] implementation: | + if constexpr(PreserveSign) { + __m128i result = _mm_srli_epi64(data, shift); + // Create a mask for the sign bit + __m128i sign_mask = _mm_cmpgt_epi64(_mm_setzero_si128(), data); // Get the sign bit + sign_mask = _mm_slli_epi64(sign_mask, 64 - shift); // Shift the sign bit mask + // Combine with the sign mask to achieve arithmetic shift + result = _mm_or_si128(result, sign_mask); + return result; + } else { + return _mm_srli_epi{{ intrin_tp[ctype][1] }}(data, shift); + } + alternative: | if constexpr(PreserveSign) { auto const shifted = _mm_srli_epi64(data, shift); auto const msb_as_lsb = _mm_srli_epi64(data, sizeof(int64_t)*CHAR_BIT-1); @@ -730,6 +1079,32 @@ definitions: } else { return _mm_srli_epi{{ intrin_tp[ctype][1] }}(data, shift); } + - target_extension: "sse" + ctype: ["uint16_t", "uint32_t", "int16_t", "int32_t", "uint64_t"] + lscpu_flags: ["sse2"] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + return _mm_srai_epi{{ intrin_tp[ctype][1] }}(data, shift); + } else { + return _mm_srli_epi{{ intrin_tp[ctype][1] }}(data, shift); + } + - target_extension: "sse" + ctype: ["uint8_t", "int8_t"] + lscpu_flags: ["sse2"] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + __m128i const valid_bits = _mm_set1_epi32(0xFF); + __m128i const result0_8 = _mm_srli_epi32(_mm_srai_epi32(_mm_slli_epi32(data, 24), shift), 24); + __m128i const result1_8 = _mm_and_si128( _mm_srli_epi32(_mm_srai_epi32(_mm_slli_epi32(data, 16), shift), 16), _mm_slli_epi32(valid_bits, 8)); + __m128i const result2_8 = _mm_and_si128( _mm_srli_epi32(_mm_srai_epi32(_mm_slli_epi32(data, 8), shift), 8), _mm_slli_epi32(valid_bits, 16)); + __m128i const result3_8 = _mm_and_si128( _mm_srai_epi32(data, shift), _mm_slli_epi32(valid_bits, 24)); + return _mm_or_si128(_mm_or_si128(result0_8, result1_8), _mm_or_si128(result2_8, result3_8)); + } else { + return _mm_and_si128( + _mm_srli_epi32(data, shift), + _mm_set1_epi8((static_cast(-1))>>shift) + ); + } # - target_extension: "sse" # ctype: ["float"] # 
lscpu_flags: ["sse2"] @@ -829,6 +1204,18 @@ testing: storeu(test_result, shift_right(vec, shift)); test_helper.synchronize(); allOk &= test_helper.validate(); + if (!allOk) { + std::cerr << "Error with " << tsl::type_name() << "(" << sizeof(typename Vec::base_type) << " * " << Vec::vector_element_count() << ")" << std::endl; + for (auto x = 0; x < Vec::vector_element_count(); ++x) { + if (test_helper.result_target()[x] != expected_result[x]) { + std::cerr << "==================================== START ====================================" << std::endl; + std::cerr << "data[" << x << "] = " << +(ref_data[i+x]) << " >> " << +shift << std::endl; + std::cerr << "expected_result[" << x << "] = " << +(expected_result[x]) << std::endl; + std::cerr << "test_result[" << x << "] = " << +(test_helper.result_target()[x]) << std::endl; + std::cerr << "==================================== END ====================================" << std::endl; + } + } + } } } return allOk; @@ -888,26 +1275,213 @@ definitions: } else { return _mm512_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); } - #"return _mm512_srav_epi{{ intrin_tp[ctype][1] }}(data, shift);" - target_extension: "avx512" ctype: ["uint16_t", "int16_t"] - lscpu_flags: ["avx512bw"] + lscpu_flags: ["avx512f"] implementation: | if constexpr ((std::is_signed_v) && (PreserveSign)) { - return _mm512_srav_epi{{ intrin_tp[ctype][1] }}(data, shift); + __m512i const shift0_16 = _mm512_and_si512(shift, _mm512_set1_epi32(0xFFFF)); + __m512i const shift1_16 = _mm512_srli_epi32(shift, 16); + __m512i const result1_16 = _mm512_and_si512(_mm512_srav_epi32(data, shift1_16), _mm512_set1_epi32(0xFFFF0000)); + __m512i const result0_16 = _mm512_srli_epi32(_mm512_srav_epi32(_mm512_slli_epi32(data, 16), shift0_16), 16); + return _mm512_or_si512(result0_16, result1_16); } else { - return _mm512_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); + __m512i const shift0_16 = _mm512_and_si512(shift, _mm512_set1_epi32(0xFFFF)); + __m512i const shift1_16 = _mm512_srli_epi32(shift, 16); + __m512i const result0_16 = _mm512_srlv_epi32(_mm512_and_si512(data, _mm512_set1_epi32(0xFFFF)), shift0_16); + __m512i const result1_16 = _mm512_and_si512(_mm512_srlv_epi32(data, shift1_16), _mm512_set1_epi32(0xFFFF0000)); + return _mm512_or_si512(result0_16, result1_16); } - #Intel - AVX2 - - target_extension: "avx2" + alternative: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + // extract the lower and upper 256 bit + __m256i const lower256_16 = _mm512_extracti32x8_epi32(data, 0); + __m256i const shift_lower_256_16 = _mm512_extracti32x8_epi32(shift, 0); + __m256i const upper256_16 = _mm512_extracti32x8_epi32(data, 1); + __m256i const shift_upper_256_16 = _mm512_extracti32x8_epi32(shift, 1); + // Sign extend packed 16-bit integers in a to packed 32-bit integers + __m512i const lower256_32 = _mm512_cvtepi16_epi32(lower256_16); + __m512i const shift_lower_256_32 = _mm512_cvtepi16_epi32(shift_lower_256_16); + __m512i const upper256_32 = _mm512_cvtepi16_epi32(upper256_16); + __m512i const shift_upper_256_32 = _mm512_cvtepi16_epi32(shift_upper_256_16); + // Shift the 32-bit integers to the right while preserving the sign + __m512i const lower256_32_shifted = _mm512_srav_epi32(lower256_32, shift_lower_256_32); + __m512i const upper256_32_shifted = _mm512_srav_epi32(upper256_32, shift_upper_256_32); + // Convert packed 32-bit integers in a to packed 16-bit integers with truncation + __m256i const lower256_16_shifted = _mm512_cvtepi32_epi16(lower256_32_shifted); + __m256i const upper256_16_shifted 
= _mm512_cvtepi32_epi16(upper256_32_shifted); + // Merge results and return + __m512i const result = _mm512_zextsi256_si512(lower256_16_shifted); + return _mm512_inserti32x8(result, upper256_16_shifted, 1); + } else { + __m512i data_low = _mm512_and_si512(data, _mm512_set1_epi32(0xFFFF)); + __m512i data_high = _mm512_srli_epi32(data, 16); + __m512i shift_values_low = _mm512_and_si512(shift, _mm512_set1_epi32(0xFFFF)); + __m512i shift_values_high = _mm512_srli_epi32(shift, 16); + // Perform logical right shift on each 32-bit element + __m512i result_low = _mm512_srlv_epi32(data_low, shift_values_low); + __m512i result_high = _mm512_srlv_epi32(data_high, shift_values_high); + + // Merge results and return + return _mm512_or_si512(result_low, _mm512_slli_epi32(result_high, 16)); + } + - target_extension: "avx512" + ctype: ["uint8_t", "int8_t"] + lscpu_flags: ["avx512f"] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + __m512i const valid_bits = _mm512_set1_epi32(0xFF); + __m512i const shift0_8 = _mm512_and_si512(shift, valid_bits); + __m512i const shift1_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 8), valid_bits); + __m512i const shift2_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 16), valid_bits); + __m512i const shift3_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 24), valid_bits); + __m512i const result3_8 = _mm512_and_si512(_mm512_srav_epi32(data, shift3_8), _mm512_slli_epi32(valid_bits, 24)); + __m512i const result2_8 = + _mm512_and_si512( + _mm512_srli_epi32( + _mm512_srav_epi32( + _mm512_slli_epi32(data, 8), + shift2_8 + ), + 8 + ), + _mm512_slli_epi32(valid_bits, 16) + ); + __m512i const result1_8 = + _mm512_and_si512( + _mm512_srli_epi32( + _mm512_srav_epi32( + _mm512_slli_epi32(data, 16), + shift1_8 + ), + 16 + ), + _mm512_slli_epi32(valid_bits, 8) + ); + __m512i const result0_8 = + _mm512_and_si512( + _mm512_srli_epi32( + _mm512_srav_epi32( + _mm512_slli_epi32(data, 24), + shift0_8 + ), + 24 + ), + valid_bits + ); + return _mm512_or_si512(_mm512_or_si512(result0_8, result1_8), _mm512_or_si512(result2_8, result3_8)); + } else { + __m512i const valid_bits = _mm512_set1_epi32(0xFF); + __m512i const shift0_8 = _mm512_and_si512(shift, valid_bits); + __m512i const shift1_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 8), valid_bits); + __m512i const shift2_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 16), valid_bits); + __m512i const shift3_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 24), valid_bits); + __m512i const result0_8 = + _mm512_and_si512( + _mm512_srlv_epi32( + _mm512_and_si512(data, valid_bits), + shift0_8 + ), + valid_bits + ); + __m512i const result1_8 = + _mm512_and_si512( + _mm512_srlv_epi32( + _mm512_and_si512( + data, + _mm512_slli_epi32(valid_bits, 8) + ), + shift1_8 + ), + _mm512_slli_epi32(valid_bits, 8) + ); + __m512i const result2_8 = + _mm512_and_si512( + _mm512_srlv_epi32( + _mm512_and_si512( + data, + _mm512_slli_epi32(valid_bits, 16) + ), + shift2_8 + ), + _mm512_slli_epi32(valid_bits, 16) + ); + __m512i const result3_8 = + _mm512_and_si512( + _mm512_srlv_epi32( + _mm512_and_si512( + data, + _mm512_slli_epi32(valid_bits, 24) + ), + shift3_8 + ), + _mm512_slli_epi32(valid_bits, 24) + ); + return _mm512_or_si512(_mm512_or_si512(result0_8, result1_8), _mm512_or_si512(result2_8, result3_8)); + } + alternative: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + // extract the four 128-bit lanes + __m128i const data0_128_8 = _mm512_extracti32x4_epi32(data, 0); + __m128i const shift0_128_8 = 
_mm512_extracti32x4_epi32(shift, 0); + __m128i const data1_128_8 = _mm512_extracti32x4_epi32(data, 1); + __m128i const shift1_128_8 = _mm512_extracti32x4_epi32(shift, 1); + __m128i const data2_128_8 = _mm512_extracti32x4_epi32(data, 2); + __m128i const shift2_128_8 = _mm512_extracti32x4_epi32(shift, 2); + __m128i const data3_128_8 = _mm512_extracti32x4_epi32(data, 3); + __m128i const shift3_128_8 = _mm512_extracti32x4_epi32(shift, 3); + + // Sign extend packed 8-bit integers in a to packed 32-bit integers + __m512i const data0_128_32 = _mm512_cvtepi8_epi32(data0_128_8); + __m512i const shift0_128_32 = _mm512_cvtepi8_epi32(shift0_128_8); + __m512i const data1_128_32 = _mm512_cvtepi8_epi32(data1_128_8); + __m512i const shift1_128_32 = _mm512_cvtepi8_epi32(shift1_128_8); + __m512i const data2_128_32 = _mm512_cvtepi8_epi32(data2_128_8); + __m512i const shift2_128_32 = _mm512_cvtepi8_epi32(shift2_128_8); + __m512i const data3_128_32 = _mm512_cvtepi8_epi32(data3_128_8); + __m512i const shift3_128_32 = _mm512_cvtepi8_epi32(shift3_128_8); + // Shift the 32-bit integers to the right while preserving the sign + __m512i const data0_128_32_shifted = _mm512_srav_epi32(data0_128_32, shift0_128_32); + __m512i const data1_128_32_shifted = _mm512_srav_epi32(data1_128_32, shift1_128_32); + __m512i const data2_128_32_shifted = _mm512_srav_epi32(data2_128_32, shift2_128_32); + __m512i const data3_128_32_shifted = _mm512_srav_epi32(data3_128_32, shift3_128_32); + // Convert packed 32-bit integers into packed 8-bit integers with truncation + __m128i const result0_128 = _mm512_cvtepi32_epi8(data0_128_32_shifted); + __m128i const result1_128 = _mm512_cvtepi32_epi8(data1_128_32_shifted); + __m128i const result2_128 = _mm512_cvtepi32_epi8(data2_128_32_shifted); + __m128i const result3_128 = _mm512_cvtepi32_epi8(data3_128_32_shifted); + // Merge results and return + __m256i lower256 = _mm256_zextsi128_si256(result0_128); + lower256 = _mm256_insertf128_si256(lower256, result1_128, 1); + __m256i upper256 = _mm256_zextsi128_si256(result2_128); + upper256 = _mm256_insertf128_si256(upper256, result3_128, 1); + __m512i result = _mm512_zextsi256_si512(lower256); + return _mm512_inserti32x8(result, upper256, 1); + } else { + __m512i data0_8 = _mm512_and_si512(data, _mm512_set1_epi32(0xFF)); + __m512i data1_8 = _mm512_and_si512(data, _mm512_set1_epi32(0xFF00)); + __m512i data2_8 = _mm512_and_si512(data, _mm512_set1_epi32(0xFF0000)); + __m512i data3_8 = _mm512_and_si512(data, _mm512_set1_epi32(0xFF000000)); + __m512i shift0_8 = _mm512_and_si512(shift, _mm512_set1_epi32(0xFF)); + __m512i shift1_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 8), _mm512_set1_epi32(0xFF)); + __m512i shift2_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 16), _mm512_set1_epi32(0xFF)); + __m512i shift3_8 = _mm512_and_si512(_mm512_srli_epi32(shift, 24), _mm512_set1_epi32(0xFF)); + __m512i result0 = _mm512_srlv_epi32(data0_8, shift0_8); + __m512i result1 = _mm512_and_si512(_mm512_srlv_epi32(data1_8, shift1_8), _mm512_set1_epi32(0xFF00)); + __m512i result2 = _mm512_and_si512(_mm512_srlv_epi32(data2_8, shift2_8), _mm512_set1_epi32(0xFF0000)); + __m512i result3 = _mm512_and_si512(_mm512_srlv_epi32(data3_8, shift3_8), _mm512_set1_epi32(0xFF000000)); + return _mm512_or_si512(_mm512_or_si512(result0, result1), _mm512_or_si512(result2, result3)); + } + - target_extension: "avx512" ctype: ["uint16_t", "int16_t"] - lscpu_flags: ["avx512bw", "avx512vl"] + lscpu_flags: ["avx512bw"] implementation: | if constexpr ((std::is_signed_v) && (PreserveSign)) { - return 
_mm256_srav_epi{{ intrin_tp[ctype][1] }}(data, shift); + return _mm512_srav_epi{{ intrin_tp[ctype][1] }}(data, shift); } else { - return _mm256_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); + return _mm512_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); } + #Intel - AVX2 - target_extension: "avx2" ctype: ["uint32_t", "int32_t", "uint64_t"] lscpu_flags: ["avx2"] @@ -921,20 +1495,171 @@ definitions: return _mm256_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); } } + - target_extension: "avx2" + ctype: ["int64_t"] + lscpu_flags: ["avx2"] + implementation: | + if constexpr (PreserveSign) { + __m256i result = _mm256_srlv_epi64(data, shift); + __m256i sign_mask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), data); + sign_mask = _mm256_sllv_epi64(sign_mask, _mm256_sub_epi64(_mm256_set1_epi64x(64), shift)); + result = _mm256_or_si256(result, sign_mask); + return result; + } else { + return _mm256_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); + } - target_extension: "avx2" ctype: ["int64_t"] lscpu_flags: ["avx2", "avx512f", "avx512vl"] implementation: | if constexpr (PreserveSign) { + return _mm256_srav_epi64(data, shift); + } else { + return _mm256_srlv_epi64(data, shift); + } + - target_extension: "avx2" + ctype: ["uint16_t", "int16_t"] + lscpu_flags: ["avx2"] + specialization_comment: | + To realize an arithmetic right shift, we take the following steps: + 1. Get the lower 16 bits of the shift values into 32-bit values (shift0_16) + 2. Get the upper 16 bits of the shift values into 32-bit values (shift1_16) + 3. Execute an arithmetic 32-bit right shift on the upper 16 bits of every 32-bit element in the register and mask out the lower 16 bits (result1_16) + 4. Shift the lower 16 bits of every 32-bit element to the left by 16 bits. Now we can execute a 32-bit arithmetic right shift that preserves the sign bit and then shift the result back to the right by 16 (result0_16). 
+ implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + __m256i const shift0_16 = _mm256_and_si256(shift, _mm256_set1_epi32(0xFFFF)); + __m256i const shift1_16 = _mm256_srli_epi32(shift, 16); + __m256i const result1_16 = _mm256_and_si256(_mm256_srav_epi32(data, shift1_16), _mm256_set1_epi32(0xFFFF0000)); + __m256i const result0_16 = _mm256_srli_epi32(_mm256_srav_epi32(_mm256_slli_epi32(data, 16), shift0_16), 16); + return _mm256_or_si256(result0_16, result1_16); + } else { + __m256i const shift0_16 = _mm256_and_si256(shift, _mm256_set1_epi32(0xFFFF)); + __m256i const shift1_16 = _mm256_srli_epi32(shift, 16); + __m256i const result0_16 = _mm256_srlv_epi32(_mm256_and_si256(data, _mm256_set1_epi32(0xFFFF)), shift0_16); + __m256i const result1_16 = _mm256_and_si256(_mm256_srlv_epi32(data, shift1_16), _mm256_set1_epi32(0xFFFF0000)); + return _mm256_or_si256(result0_16, result1_16); + } + - target_extension: "avx2" + ctype: ["uint16_t", "int16_t"] + lscpu_flags: ["avx512bw", "avx512vl"] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { return _mm256_srav_epi{{ intrin_tp[ctype][1] }}(data, shift); } else { return _mm256_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); } + - target_extension: "avx2" + ctype: ["uint8_t", "int8_t"] + lscpu_flags: ["avx2"] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + __m256i const valid_bits = _mm256_set1_epi32(0xFF); + __m256i const shift0_8 = _mm256_and_si256(shift, valid_bits); + __m256i const shift1_8 = _mm256_and_si256(_mm256_srli_epi32(shift, 8), valid_bits); + __m256i const shift2_8 = _mm256_and_si256(_mm256_srli_epi32(shift, 16), valid_bits); + __m256i const shift3_8 = _mm256_and_si256(_mm256_srli_epi32(shift, 24), valid_bits); + __m256i const result3_8 = _mm256_and_si256(_mm256_srav_epi32(data, shift3_8), _mm256_slli_epi32(valid_bits, 24)); + __m256i const result2_8 = + _mm256_and_si256( + _mm256_srli_epi32( + _mm256_srav_epi32( + _mm256_slli_epi32(data, 8), + shift2_8 + ), + 8 + ), + _mm256_slli_epi32(valid_bits, 16) + ); + __m256i const result1_8 = + _mm256_and_si256( + _mm256_srli_epi32( + _mm256_srav_epi32( + _mm256_slli_epi32(data, 16), + shift1_8 + ), + 16 + ), + _mm256_slli_epi32(valid_bits, 8) + ); + __m256i const result0_8 = + _mm256_and_si256( + _mm256_srli_epi32( + _mm256_srav_epi32( + _mm256_slli_epi32(data, 24), + shift0_8 + ), + 24 + ), + valid_bits + ); + return _mm256_or_si256(_mm256_or_si256(result0_8, result1_8), _mm256_or_si256(result2_8, result3_8)); + } else { + __m256i const valid_bits = _mm256_set1_epi32(0xFF); + __m256i const shift0_8 = _mm256_and_si256(shift, valid_bits); + __m256i const shift1_8 = _mm256_and_si256(_mm256_srli_epi32(shift, 8), valid_bits); + __m256i const shift2_8 = _mm256_and_si256(_mm256_srli_epi32(shift, 16), valid_bits); + __m256i const shift3_8 = _mm256_and_si256(_mm256_srli_epi32(shift, 24), valid_bits); + __m256i const result0_8 = + _mm256_and_si256( + _mm256_srlv_epi32( + _mm256_and_si256(data, valid_bits), + shift0_8 + ), + valid_bits + ); + __m256i const result1_8 = + _mm256_and_si256( + _mm256_srlv_epi32( + _mm256_and_si256( + data, + _mm256_slli_epi32(valid_bits, 8) + ), + shift1_8 + ), + _mm256_slli_epi32(valid_bits, 8) + ); + __m256i const result2_8 = + _mm256_and_si256( + _mm256_srlv_epi32( + _mm256_and_si256( + data, + _mm256_slli_epi32(valid_bits, 16) + ), + shift2_8 + ), + _mm256_slli_epi32(valid_bits, 16) + ); + __m256i const result3_8 = + _mm256_and_si256( + _mm256_srlv_epi32( + _mm256_and_si256( + data, + 
_mm256_slli_epi32(valid_bits, 24) + ), + shift3_8 + ), + _mm256_slli_epi32(valid_bits, 24) + ); + return _mm256_or_si256(_mm256_or_si256(result0_8, result1_8), _mm256_or_si256(result2_8, result3_8)); + } - target_extension: "avx2" ctype: ["int64_t"] lscpu_flags: ["avx2"] includes: [""] implementation: | + if constexpr(PreserveSign) { + __m256i result = _mm256_srlv_epi64(data, shift); + // Create a mask for the sign bit + __m256i sign_mask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), data); // Get the sign bit + sign_mask = _mm256_sllv_epi64(sign_mask, _mm256_sub_epi64(_mm256_set1_epi64x(64),shift)); // Shift the sign bit mask + // Combine with the sign mask to achieve arithmetic shift + result = _mm256_or_si256(result, sign_mask); + return result; + } else { + return _mm256_srlv_epi64(data, shift); + } + alternative: | if constexpr (PreserveSign) { auto const shifted = _mm256_srlv_epi64(data, shift); auto const msb_as_lsb = _mm256_srli_epi64(data, sizeof(int64_t)*CHAR_BIT-1); @@ -945,7 +1670,7 @@ definitions: auto const result_msb = _mm256_sllv_epi64(result_msb_as_lsb, _mm256_sub_epi64(_mm256_set1_epi64x(sizeof(int64_t)*CHAR_BIT), shift)); return _mm256_or_si256(shifted, result_msb); } else { - return _mm256_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); + return _mm256_srlv_epi64(data, shift); } #Intel - SSE - target_extension: "sse" @@ -962,6 +1687,16 @@ definitions: lscpu_flags: ["sse2", "avx2"] includes: [""] implementation: | + if constexpr (PreserveSign) { + __m128i result = _mm_srlv_epi64(data, shift); + __m128i sign_mask = _mm_cmpgt_epi64(_mm_setzero_si128(), data); + sign_mask = _mm_sllv_epi64(sign_mask, _mm_sub_epi64(_mm_set1_epi64x(64), shift)); + result = _mm_or_si128(result, sign_mask); + return result; + } else { + return _mm_srlv_epi64(data, shift); + } + alternative: | if constexpr (PreserveSign) { auto const shifted = _mm_srlv_epi64(data, shift); auto const msb_as_lsb = _mm_srli_epi64(data, sizeof(int64_t)*CHAR_BIT-1); @@ -983,6 +1718,117 @@ definitions: } else { return _mm_srlv_epi{{ intrin_tp[ctype][1] }}(data, shift); } + - target_extension: "sse" + ctype: ["uint16_t", "int16_t"] + lscpu_flags: ["sse2", "avx2"] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + __m128i const shift0_16 = _mm_and_si128(shift, _mm_set1_epi32(0xFFFF)); + __m128i const shift1_16 = _mm_srli_epi32(shift, 16); + __m128i const result1_16 = _mm_and_si128(_mm_srav_epi32(data, shift1_16), _mm_set1_epi32(0xFFFF0000)); + __m128i const result0_16 = _mm_srli_epi32(_mm_srav_epi32(_mm_slli_epi32(data, 16), shift0_16), 16); + return _mm_or_si128(result0_16, result1_16); + } else { + __m128i const shift0_16 = _mm_and_si128(shift, _mm_set1_epi32(0xFFFF)); + __m128i const shift1_16 = _mm_srli_epi32(shift, 16); + __m128i const result0_16 = _mm_srlv_epi32(_mm_and_si128(data, _mm_set1_epi32(0xFFFF)), shift0_16); + __m128i const result1_16 = _mm_and_si128(_mm_srlv_epi32(data, shift1_16), _mm_set1_epi32(0xFFFF0000)); + return _mm_or_si128(result0_16, result1_16); + } + - target_extension: "sse" + ctype: ["uint8_t", "int8_t"] + lscpu_flags: ["sse2", "avx2"] + implementation: | + if constexpr ((std::is_signed_v) && (PreserveSign)) { + __m128i const valid_bits = _mm_set1_epi32(0xFF); + __m128i const shift0_8 = _mm_and_si128(shift, valid_bits); + __m128i const shift1_8 = _mm_and_si128(_mm_srli_epi32(shift, 8), valid_bits); + __m128i const shift2_8 = _mm_and_si128(_mm_srli_epi32(shift, 16), valid_bits); + __m128i const shift3_8 = _mm_and_si128(_mm_srli_epi32(shift, 24), valid_bits); + 
__m128i const result3_8 = _mm_and_si128(_mm_srav_epi32(data, shift3_8), _mm_slli_epi32(valid_bits, 24)); + __m128i const result2_8 = + _mm_and_si128( + _mm_srli_epi32( + _mm_srav_epi32( + _mm_slli_epi32(data, 8), + shift2_8 + ), + 8 + ), + _mm_slli_epi32(valid_bits, 16) + ); + __m128i const result1_8 = + _mm_and_si128( + _mm_srli_epi32( + _mm_srav_epi32( + _mm_slli_epi32(data, 16), + shift1_8 + ), + 16 + ), + _mm_slli_epi32(valid_bits, 8) + ); + __m128i const result0_8 = + _mm_and_si128( + _mm_srli_epi32( + _mm_srav_epi32( + _mm_slli_epi32(data, 24), + shift0_8 + ), + 24 + ), + valid_bits + ); + return _mm_or_si128(_mm_or_si128(result0_8, result1_8), _mm_or_si128(result2_8, result3_8)); + } else { + __m128i const valid_bits = _mm_set1_epi32(0xFF); + __m128i const shift0_8 = _mm_and_si128(shift, valid_bits); + __m128i const shift1_8 = _mm_and_si128(_mm_srli_epi32(shift, 8), valid_bits); + __m128i const shift2_8 = _mm_and_si128(_mm_srli_epi32(shift, 16), valid_bits); + __m128i const shift3_8 = _mm_and_si128(_mm_srli_epi32(shift, 24), valid_bits); + __m128i const result0_8 = + _mm_and_si128( + _mm_srlv_epi32( + _mm_and_si128(data, valid_bits), + shift0_8 + ), + valid_bits + ); + __m128i const result1_8 = + _mm_and_si128( + _mm_srlv_epi32( + _mm_and_si128( + data, + _mm_slli_epi32(valid_bits, 8) + ), + shift1_8 + ), + _mm_slli_epi32(valid_bits, 8) + ); + __m128i const result2_8 = + _mm_and_si128( + _mm_srlv_epi32( + _mm_and_si128( + data, + _mm_slli_epi32(valid_bits, 16) + ), + shift2_8 + ), + _mm_slli_epi32(valid_bits, 16) + ); + __m128i const result3_8 = + _mm_and_si128( + _mm_srlv_epi32( + _mm_and_si128( + data, + _mm_slli_epi32(valid_bits, 24) + ), + shift3_8 + ), + _mm_slli_epi32(valid_bits, 24) + ); + return _mm_or_si128(_mm_or_si128(result0_8, result1_8), _mm_or_si128(result2_8, result3_8)); + } #ARM - NEON - target_extension: "neon" ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"] diff --git a/primitive_data/primitives/mask_ls.yaml b/primitive_data/primitives/mask_ls.yaml index b1f7a737..b1e6b914 100644 --- a/primitive_data/primitives/mask_ls.yaml +++ b/primitive_data/primitives/mask_ls.yaml @@ -413,7 +413,7 @@ testing: for (auto rep = 0; rep < repetition_count; ++rep) { testing::rnd_init(&imask, 1); for (size_t i = 0; i < Vec::vector_element_count(); ++i) { - if ( (imask >> i) & 0b1 == 1 ) { + if (((imask >> i) & 0b1) == 1 ) { reference_result_ptr[i] = i+1; } else { reference_result_ptr[i] = Vec::vector_element_count()*2; @@ -504,7 +504,7 @@ definitions: implementation: | const auto imask = tsl::to_integral(mask); for ( size_t i = 0; i < Vec::vector_element_count(); ++i ) { - if ((imask >> i) & 0b1 == true) { + if (((imask >> i) & 0b1) == true) { memory[i] = data[i]; } }
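The AVX2 and SSE specializations for 16-bit right shifts above emulate the missing packed 16-bit shift instructions with 32-bit shifts, following the steps in the specialization_comment of the avx2 uint16_t/int16_t entry. The scalar sketch below models that trick for a single 32-bit lane holding two packed 16-bit elements. It is illustrative only and not part of the patch; the function and variable names (srav_epi16_lane, lane, shifts) are made up for the example, and it assumes both shift counts are below 16 and two's-complement (C++20) shift semantics.

#include <cstdint>

// One 32-bit lane holds two 16-bit elements; 'shifts' holds the two shift
// counts in the same layout. Mirrors steps 1-4 of the specialization_comment.
inline uint32_t srav_epi16_lane(uint32_t lane, uint32_t shifts) {
    uint32_t const shift0 = shifts & 0xFFFFu;  // step 1: shift count for the lower element
    uint32_t const shift1 = shifts >> 16;      // step 2: shift count for the upper element
    // step 3: a 32-bit arithmetic shift handles the upper element; mask off the lower bits
    uint32_t const result1 = static_cast<uint32_t>(static_cast<int32_t>(lane) >> shift1) & 0xFFFF0000u;
    // step 4: move the lower element to the top, shift arithmetically, move it back down
    uint32_t const result0 = static_cast<uint32_t>(static_cast<int32_t>(lane << 16) >> shift0) >> 16;
    return result0 | result1;
}

Comparing a scalar model like this against the intrinsic sequences (for example in the shift_right test cases) is a cheap way to cross-check the emulated variants.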