Skip to content

Commit

Permalink
Made fixes to MultiRotateRight
Browse files Browse the repository at this point in the history
  • Loading branch information
johnplatts committed Feb 6, 2025
1 parent 69f2337 commit 56606cc
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 27 deletions.
8 changes: 7 additions & 1 deletion g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1073,8 +1073,14 @@ A compound shift on 64-bit values:
for(size_t i = 0; i < N; i++) {
uint64_t shift_result = 0;
for(int j = 0; j < 8; j++) {
uint64_t rot_result = (v[i] >> indices[i*8+j]) | (v[i] << (64 - indices[i*8+j]));
uint64_t rot_result =
(static_cast<uint64_t>(v[i]) >> indices[i*8+j]) |
(static_cast<uint64_t>(v[i]) << ((-indices[i*8+j]) & 63));
#if HWY_IS_LITTLE_ENDIAN
shift_result |= (rot_result & 0xff) << (j * 8);
#else
shift_result |= (rot_result & 0xff) << ((j ^ 7) * 8);
#endif
}
r[i] = shift_result;
}
Expand Down
3 changes: 1 addition & 2 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
return sv##OP##_##CHAR##BITS##_x(m, a, b); \
}
// User-specified mask. Mask=false value is zero.
#define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
#define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_z(m, a, b); \
Expand Down Expand Up @@ -2294,7 +2294,6 @@ HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple)

#undef HWY_SVE_COMPARE_Z


template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedGt(M m, V a, V b) {
// Swap args to reverse comparison
Expand Down
121 changes: 111 additions & 10 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7722,7 +7722,66 @@ HWY_API bool AllBits0(V a) {
#define HWY_NATIVE_MULTIROTATERIGHT
#endif

template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>)>
template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>),
class VI_2 = VFromD<Repartition<TFromV<VI>, DFromV<V>>>,
HWY_IF_LANES_D(DFromV<VI>, HWY_MAX_LANES_V(VI_2)),
HWY_IF_V_SIZE_V(V, 8)>
HWY_API V MultiRotateRight(V v, VI idx) {
const DFromV<V> d64;
const Twice<decltype(d64)> dt64;
const Repartition<uint8_t, decltype(d64)> du8;
const Repartition<uint8_t, decltype(dt64)> dt_u8;
const Repartition<uint16_t, decltype(dt64)> dt_u16;
const auto k7 = Set(du8, uint8_t{0x07});
const auto k63 = Set(du8, uint8_t{0x3F});

const auto masked_idx = And(k63, BitCast(du8, idx));

auto byte_idx = ShiftRight<3>(masked_idx);
#if HWY_IS_LITTLE_ENDIAN
const auto hi_byte_idx = Add(byte_idx, Set(du8, uint8_t{1}));
#else
byte_idx = Xor(byte_idx, k7);
const auto hi_byte_idx = Add(byte_idx, k7);
#endif

const auto idx_shift = And(k7, masked_idx);

// Calculate even lanes
const auto even_src = DupEven(ResizeBitCast(dt64, v));
// Expand indexes to pull out 16 bit segments of idx and idx + 1
#if HWY_IS_LITTLE_ENDIAN
const auto even_idx = InterleaveLower(ResizeBitCast(dt_u8, byte_idx),
ResizeBitCast(dt_u8, hi_byte_idx));
#else
const auto even_idx = InterleaveLower(ResizeBitCast(dt_u8, hi_byte_idx),
ResizeBitCast(dt_u8, byte_idx));
#endif
// TableLookupBytes indexes select from within a 16 byte block
const auto even_segments = TableLookupBytes(even_src, even_idx);
// Extract unaligned bytes from 16 bit segments
const auto even_idx_shift = PromoteTo(dt_u16, idx_shift);
const auto extracted_even_bytes =
Shr(BitCast(dt_u16, even_segments), even_idx_shift);

// Extract the even bytes of each 128 bit block and pack into lower 64 bits
#if HWY_IS_LITTLE_ENDIAN
const auto even_lanes = BitCast(
dt64,
ConcatEven(dt_u8, Zero(dt_u8), BitCast(dt_u8, extracted_even_bytes)));
#else
const auto even_lanes = BitCast(
dt64,
ConcatOdd(dt_u8, Zero(dt_u8), BitCast(dt_u8, extracted_even_bytes)));
#endif

return LowerHalf(d64, even_lanes);
}

template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>),
class VI_2 = VFromD<Repartition<TFromV<VI>, DFromV<V>>>,
HWY_IF_LANES_D(DFromV<VI>, HWY_MAX_LANES_V(VI_2)),
HWY_IF_V_SIZE_GT_V(V, 8)>
HWY_API V MultiRotateRight(V v, VI idx) {
const DFromV<V> d64;
const Repartition<uint8_t, decltype(d64)> du8;
Expand All @@ -7731,42 +7790,84 @@ HWY_API V MultiRotateRight(V v, VI idx) {
const auto k63 = Set(du8, uint8_t{0x3F});

const auto masked_idx = And(k63, BitCast(du8, idx));
const auto byte_idx = ShiftRight<3>(masked_idx);

auto byte_idx = ShiftRight<3>(masked_idx);
#if HWY_IS_LITTLE_ENDIAN
const auto hi_byte_idx = Add(byte_idx, Set(du8, uint8_t{1}));
#else
byte_idx = Xor(byte_idx, k7);
const auto hi_byte_idx = Add(byte_idx, k7);
#endif

const auto idx_shift = And(k7, masked_idx);

// Calculate even lanes
const auto even_src = DupEven(v);
// Expand indexes to pull out 16 bit segments of idx and idx + 1
const auto even_idx =
InterleaveLower(byte_idx, Add(byte_idx, Set(du8, uint8_t{1})));
#if HWY_IS_LITTLE_ENDIAN
const auto even_idx = InterleaveLower(byte_idx, hi_byte_idx);
#else
const auto even_idx = InterleaveLower(hi_byte_idx, byte_idx);
#endif
// TableLookupBytes indexes select from within a 16 byte block
const auto even_segments = TableLookupBytes(even_src, even_idx);
// Extract unaligned bytes from 16 bit segments
#if HWY_IS_LITTLE_ENDIAN
const auto even_idx_shift = ZipLower(idx_shift, Zero(du8));
#else
const auto even_idx_shift = ZipLower(Zero(du8), idx_shift);
#endif
const auto extracted_even_bytes =
Shr(BitCast(du16, even_segments), even_idx_shift);

// Calculate odd lanes
const auto odd_src = DupOdd(v);
// Expand indexes to pull out 16 bit segments of idx and idx + 1
const auto odd_idx =
InterleaveUpper(du8, byte_idx, Add(byte_idx, Set(du8, uint8_t{1})));
#if HWY_IS_LITTLE_ENDIAN
const auto odd_idx = InterleaveUpper(du8, byte_idx, hi_byte_idx);
#else
const auto odd_idx = InterleaveUpper(du8, hi_byte_idx, byte_idx);
#endif
// TableLookupBytes indexes select from within a 16 byte block
const auto odd_segments = TableLookupBytes(odd_src, odd_idx);
// Extract unaligned bytes from 16 bit segments
#if HWY_IS_LITTLE_ENDIAN
const auto odd_idx_shift = ZipUpper(du16, idx_shift, Zero(du8));
#else
const auto odd_idx_shift = ZipUpper(du16, Zero(du8), idx_shift);
#endif
const auto extracted_odd_bytes =
Shr(BitCast(du16, odd_segments), odd_idx_shift);

// Extract the even bytes of each 128 bit block and pack into lower 64 bits
const auto even_lanes =
BitCast(d64, ConcatEven(du8, Zero(du8), BitCast(du8, extracted_even_bytes)));
const auto odd_lanes =
BitCast(d64, ConcatEven(du8, Zero(du8), BitCast(du8, extracted_odd_bytes)));
#if HWY_IS_LITTLE_ENDIAN
const auto even_lanes = BitCast(
d64, ConcatEven(du8, Zero(du8), BitCast(du8, extracted_even_bytes)));
const auto odd_lanes = BitCast(
d64, ConcatEven(du8, Zero(du8), BitCast(du8, extracted_odd_bytes)));
#else
const auto even_lanes = BitCast(
d64, ConcatOdd(du8, Zero(du8), BitCast(du8, extracted_even_bytes)));
const auto odd_lanes = BitCast(
d64, ConcatOdd(du8, Zero(du8), BitCast(du8, extracted_odd_bytes)));
#endif
// Interleave at 64 bit level
return InterleaveWholeLower(even_lanes, odd_lanes);
}

#if HWY_TARGET == HWY_RVV

// MultiRotateRight for LMUL=1/2 case on RVV
template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>),
class VI_2 = VFromD<Repartition<TFromV<VI>, DFromV<V>>>,
HWY_IF_POW2_LE_D(DFromV<V>, 0),
HWY_IF_LANES_D(DFromV<VI>, HWY_MAX_LANES_V(VI_2) / 2)>
HWY_API V MultiRotateRight(V v, VI idx) {
return MultiRotateRight(v, ResizeBitCast(Twice<DFromV<VI>>(), idx));
}

#endif

#endif

// ================================================== Operator wrapper
Expand Down
19 changes: 18 additions & 1 deletion hwy/ops/x86_128-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13691,7 +13691,24 @@ HWY_API V BitShuffle(V v, VI idx) {
}
#endif // HWY_TARGET <= HWY_AVX3_DL

// TODO: Implement MultiRotateRight using _mm_multishift_epi64_epi8
// ------------------------------ MultiShiftRight

#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_MULTIROTATERIGHT
#undef HWY_NATIVE_MULTIROTATERIGHT
#else
#define HWY_NATIVE_MULTIROTATERIGHT
#endif

template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>),
HWY_IF_V_SIZE_LE_V(V, 16),
HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)>
HWY_API V MultiRotateRight(V v, VI idx) {
return V{_mm_multishift_epi64_epi8(idx.raw, v.raw)};
}

#endif

// ------------------------------ Lt128

Expand Down
18 changes: 17 additions & 1 deletion hwy/ops/x86_256-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8795,7 +8795,23 @@ HWY_API V BitShuffle(V v, VI idx) {
}
#endif // HWY_TARGET <= HWY_AVX3_DL

// TODO: Implement MultiRotateRight using _mm256_multishift_epi64_epi8
// ------------------------------ MultiShiftRight

#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_MULTIROTATERIGHT
#undef HWY_NATIVE_MULTIROTATERIGHT
#else
#define HWY_NATIVE_MULTIROTATERIGHT
#endif

template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>),
HWY_IF_V_SIZE_V(V, 32), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)>
HWY_API V MultiRotateRight(V v, VI idx) {
return V{_mm256_multishift_epi64_epi8(idx.raw, v.raw)};
}

#endif

// ------------------------------ LeadingZeroCount

Expand Down
18 changes: 17 additions & 1 deletion hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7556,7 +7556,23 @@ HWY_API V BitShuffle(V v, VI idx) {
}
#endif // HWY_TARGET <= HWY_AVX3_DL

// TODO: Implement MultiRotateRight using _mm512_multishift_epi64_epi8
// ------------------------------ MultiShiftRight

#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_MULTIROTATERIGHT
#undef HWY_NATIVE_MULTIROTATERIGHT
#else
#define HWY_NATIVE_MULTIROTATERIGHT
#endif

template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>),
HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)>
HWY_API V MultiRotateRight(V v, VI idx) {
return V{_mm512_multishift_epi64_epi8(idx.raw, v.raw)};
}

#endif

// -------------------- LeadingZeroCount

Expand Down
Loading

0 comments on commit 56606cc

Please sign in to comment.