Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AddLower, PairwiseAdd/Sub and MaskedAbsOr operations #2405

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
26 changes: 26 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,14 @@ from left to right, of the arguments passed to `Create{2-4}`.

* <code>V **AbsDiff**(V a, V b)</code>: returns `|a[i] - b[i]|` in each lane.

* <code>V **PairwiseAdd**(D d, V a, V b)</code>: Add consecutive pairs of elements.
Return the results of a and b interleaved, such that `r[i] = a[i] + a[i+1]` for
even lanes and `r[i] = b[i-1] + b[i]` for odd lanes.

* <code>V **PairwiseSub**(D d, V a, V b)</code>: Subtract consecutive pairs of elements.
Return the results of a and b interleaved, such that `r[i] = a[i+1] - a[i]` for
even lanes and `r[i] = b[i] - b[i-1]` for odd lanes.

* `V`: `{i,u}{8,16,32},f{16,32}`, `VW`: `Vec<RepartitionToWide<DFromV<V>>>` \
<code>VW **SumsOf2**(V v)</code>
returns the sums of 2 consecutive lanes, promoting each sum into a lane of
Expand Down Expand Up @@ -930,11 +938,19 @@ not a concern, these are equivalent to, and potentially more efficient than,
* <code>V **MaskedMulAddOr**(V no, M m, V mul, V x, V add)</code>: returns
`mul[i] * x[i] + add[i]` or `no[i]` if `m[i]` is false.

* `V`: `{i,f}` \
<code>V **MaskedAbsOr**(V no, M m, V a)</code>: returns the absolute value of
`a[i]` where m is active and returns `no[i]` otherwise.

#### Zero masked arithmetic

All ops in this section return `0` for `mask=false` lanes. These are equivalent
to, and potentially more efficient than, `IfThenElseZero(m, Add(a, b));` etc.

* `V`: `{i,f}` \
<code>V **MaskedAbs**(M m, V a)</code>: returns the absolute value of
`a[i]` where m is active and returns zero otherwise.

* <code>V **MaskedMax**(M m, V a, V b)</code>: returns `Max(a, b)[i]`
or `zero` if `m[i]` is false.
* <code>V **MaskedAdd**(M m, V a, V b)</code>: returns `a[i] + b[i]`
Expand Down Expand Up @@ -2196,6 +2212,16 @@ All other ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
}
```

* <code>V **PairwiseAdd128**(D d, V a, V b)</code>: Add consecutive pairs of
elements in a and b, and pack results in 128 bit blocks, such that
`r[i] = a[i] + a[i+1]` for 64 bits, followed by `b[i] + b[i+1]` for next 64
bits and repeated.

* <code>V **PairwiseSub128**(D d, V a, V b)</code>: Subtract consecutive pairs
  of elements in a and b, and pack results in 128 bit blocks, such that
  `r[i] = a[i+1] - a[i]` for 64 bits, followed by `b[i+1] - b[i]` for next 64
  bits and repeated.

#### Interleave

Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
Expand Down
44 changes: 44 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
// Expands to a merge-predicated unary op: NAME(no, m, a) returns
// sv<OP>_<type>_m(no, m, a), i.e. OP(a[i]) where m[i] is active and no[i]
// otherwise ("_m" = merging predication in the SVE ACLE).
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP)              \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
    return sv##OP##_##CHAR##BITS##_m(no, m, a);                             \
  }
#define HWY_SVE_RETV_ARGMV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_x(m, v); \
Expand Down Expand Up @@ -912,6 +917,12 @@ HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)
HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs)
#endif // HWY_SVE_HAVE_2

// ------------------------------ MaskedAbsOr
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs)

// ------------------------------ MaskedAbs
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbs, abs)

// ================================================== ARITHMETIC

// Per-target flags to prevent generic_ops-inl.h defining Add etc.
Expand Down Expand Up @@ -6201,6 +6212,38 @@ HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) {
return detail::InterleaveOdd(lo, hi);
}

// ------------------------------ PairwiseAdd/PairwiseSub
#if HWY_TARGET != HWY_SCALAR
#if HWY_SVE_HAVE_2 || HWY_IDE

#ifdef HWY_NATIVE_PAIRWISE_ADD
#undef HWY_NATIVE_PAIRWISE_ADD
#else
#define HWY_NATIVE_PAIRWISE_ADD
#endif

namespace detail {
// Expands to one detail::PairwiseAdd overload per lane type. Applies the
// predicated sv<OP> intrinsic (svaddp) with an all-true predicate; the d tag
// parameter is unused and only selects the overload.
#define HWY_SVE_SV_PAIRWISE_ADD(BASE, CHAR, BITS, HALF, NAME, OP)            \
  template <size_t N, int kPow2>                                             \
  HWY_API HWY_SVE_V(BASE, BITS)                                              \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, HWY_SVE_V(BASE, BITS) a,   \
           HWY_SVE_V(BASE, BITS) b) {                                        \
    return sv##OP##_##CHAR##BITS##_m(HWY_SVE_PTRUE(BITS), a, b);             \
  }

HWY_SVE_FOREACH(HWY_SVE_SV_PAIRWISE_ADD, PairwiseAdd, addp)
#undef HWY_SVE_SV_PAIRWISE_ADD
}  // namespace detail

// Pairwise add returning interleaved output of a and b:
// r[i] = a[i] + a[i+1] for even i and r[i] = b[i-1] + b[i] for odd i.
// Forwards to the per-type svaddp wrapper generated above.
template <class D, class V, HWY_IF_LANES_GT_D(D, 1)>
HWY_API V PairwiseAdd(D d, V a, V b) {
  return detail::PairwiseAdd(d, a, b);
}

#endif // HWY_SVE_HAVE_2
#endif // HWY_TARGET != HWY_SCALAR

// ------------------------------ WidenMulPairwiseAdd

template <size_t N, int kPow2>
Expand Down Expand Up @@ -6727,6 +6770,7 @@ HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount,
#undef HWY_SVE_RETV_ARGPVV
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGMV_M
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_RETV_ARGVVV
#undef HWY_SVE_RETV_ARGMVVV_Z
Expand Down
141 changes: 141 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,18 @@ HWY_API V SaturatedAbs(V v) {

#endif

// ------------------------------ MaskedAbsOr
// Returns |v[i]| where m[i] is active, else no[i].
template <class V, HWY_IF_SIGNED_V(V), class M>
HWY_API V MaskedAbsOr(V no, M m, V v) {
  const V abs_v = Abs(v);
  return IfThenElse(m, abs_v, no);
}

// ------------------------------ MaskedAbs
// Returns |v[i]| where m[i] is active, else zero.
template <class V, HWY_IF_SIGNED_V(V), class M>
HWY_API V MaskedAbs(M m, V v) {
  const V abs_v = Abs(v);
  return IfThenElseZero(m, abs_v);
}

// ------------------------------ Reductions

// Targets follow one of two strategies. If HWY_NATIVE_REDUCE_SCALAR is toggled,
Expand Down Expand Up @@ -2260,6 +2272,35 @@ HWY_API void StoreInterleaved4(VFromD<D> part0, VFromD<D> part1,

#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED

// ------------------------------ PairwiseAdd/PairwiseSub
#if (defined(HWY_NATIVE_PAIRWISE_ADD) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_PAIRWISE_ADD
#undef HWY_NATIVE_PAIRWISE_ADD
#else
#define HWY_NATIVE_PAIRWISE_ADD
#endif

// Generic fallback: returns the interleaved pairwise sums of a and b, i.e.
// r[i] = a[i] + a[i+1] for even i and r[i] = b[i-1] + b[i] for odd i.
// NOTE: the default template argument was `VFromD<D>()`, which names a
// function type (no-arg function returning VFromD<D>), not the vector type;
// it must be `VFromD<D>`. Harmless only because V is always deduced from the
// arguments, but wrong if the default were ever relied upon.
template <class D, class V = VFromD<D>, HWY_IF_LANES_GT_D(D, 1)>
HWY_API V PairwiseAdd(D d, V a, V b) {
  return Add(InterleaveEven(d, a, b), InterleaveOdd(d, a, b));
}

#endif

#if (defined(HWY_NATIVE_PAIRWISE_SUB) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_PAIRWISE_SUB
#undef HWY_NATIVE_PAIRWISE_SUB
#else
#define HWY_NATIVE_PAIRWISE_SUB
#endif

// Generic fallback: returns the interleaved pairwise differences of a and b,
// i.e. r[i] = a[i+1] - a[i] for even i and r[i] = b[i] - b[i-1] for odd i
// (odd lane minus even lane of each consecutive pair).
// NOTE: as with PairwiseAdd, the default template argument `VFromD<D>()`
// named a function type; the correct default is the vector type `VFromD<D>`.
template <class D, class V = VFromD<D>, HWY_IF_LANES_GT_D(D, 1)>
HWY_API V PairwiseSub(D d, V a, V b) {
  return Sub(InterleaveOdd(d, a, b), InterleaveEven(d, a, b));
}

#endif

// Load/StoreInterleaved for special floats. Requires HWY_GENERIC_IF_EMULATED_D
// is defined such that it is true only for types that actually require these
// generic implementations.
Expand Down Expand Up @@ -7259,6 +7300,106 @@ HWY_API V Per4LaneBlockShuffle(V v) {
}
#endif

// ------------------------------ PairwiseAdd128/PairwiseSub128
// (Per4LaneBlockShuffle)
#if (defined(HWY_NATIVE_PAIRWISE_ADD_128) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_PAIRWISE_ADD_128
#undef HWY_NATIVE_PAIRWISE_ADD_128
#else
#define HWY_NATIVE_PAIRWISE_ADD_128
#endif

namespace detail {

// detail::BlockwiseConcatOddEven(d, v) returns the even lanes of each block of
// v followed by the odd lanes of v
#if HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV
// 1/2/4-byte lanes on NEON/SVE/RVV: build "evens then odds" per 128-bit block
// by concatenating even and odd lanes of the whole vector, then re-interleaving
// 64-bit halves so each block receives its own evens followed by its own odds.
template <class D, HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2) | (1 << 4)),
          HWY_IF_V_SIZE_GT_D(D, 8)>
static HWY_INLINE HWY_MAYBE_UNUSED Vec<D> BlockwiseConcatOddEven(D d,
                                                                 Vec<D> v) {
#if HWY_TARGET == HWY_RVV
  // RVV: clamp the LMUL exponent to >= 0 so the u64 tag is valid.
  const ScalableTag<uint64_t, HWY_MAX(HWY_POW2_D(D), 0)> du64;
#else
  const Repartition<uint64_t, DFromV<decltype(v)>> du64;
#endif

  // Same lane type as d but sized to match du64 (NOTE(review): presumably to
  // widen fractional-LMUL RVV vectors before ConcatEven/Odd — confirm).
  const Repartition<TFromD<decltype(d)>, decltype(du64)> d_concat;
  const auto v_to_concat = ResizeBitCast(d_concat, v);

  const auto evens = ConcatEven(d, v_to_concat, v_to_concat);
  const auto odds = ConcatOdd(d, v_to_concat, v_to_concat);
  // InterleaveWholeLower on u64 lanes places 64 bits of evens then 64 bits of
  // odds, i.e. per-block evens followed by per-block odds.
  return ResizeBitCast(
      d, InterleaveWholeLower(BitCast(du64, evens), BitCast(du64, odds)));
}

#else // !(HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV)

// 8-bit lanes: within each 128-bit block, gather the even-indexed bytes into
// the lower half and the odd-indexed bytes into the upper half.
template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_GT_D(D, 8)>
static HWY_INLINE HWY_MAYBE_UNUSED Vec<D> BlockwiseConcatOddEven(D d,
                                                                 Vec<D> v) {
#if HWY_TARGET == HWY_SSE2
  // SSE2 lacks pshufb (TableLookupBytes); emulate via widen + ordered demote.
  const RebindToUnsigned<decltype(d)> du;
  const RebindToSigned<RepartitionToWide<decltype(du)>> dw;

  const auto vu = BitCast(du, v);
  return BitCast(
      d, OrderedDemote2To(du, PromoteEvenTo(dw, vu), PromoteOddTo(dw, vu)));
#else
  // Byte shuffle indices: even positions 0,2,...,14 then odd 1,3,...,15.
  const Repartition<uint8_t, decltype(d)> du8;
  const auto idx =
      BitCast(d, Dup128VecFromValues(du8, 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7,
                                     9, 11, 13, 15));
  return TableLookupBytes(v, idx);
#endif
}

// 16-bit lanes: within each 128-bit block, gather even-indexed lanes into the
// lower half and odd-indexed lanes into the upper half.
template <class D, HWY_IF_T_SIZE_D(D, 2), HWY_IF_V_SIZE_GT_D(D, 8)>
static HWY_INLINE HWY_MAYBE_UNUSED Vec<D> BlockwiseConcatOddEven(D d,
                                                                 Vec<D> v) {
#if HWY_TARGET == HWY_SSE2
  // SSE2 lacks pshufb (TableLookupBytes); emulate via widen + ordered demote.
  const RebindToSigned<decltype(d)> di;
  const RepartitionToWide<decltype(di)> dw;
  const auto vi = BitCast(di, v);
  return BitCast(
      d, OrderedDemote2To(di, PromoteEvenTo(dw, vi), PromoteOddTo(dw, vi)));
#else
  // Byte pairs of even lanes (0,1),(4,5),... then odd lanes (2,3),(6,7),...
  const Repartition<uint8_t, decltype(d)> du8;
  const auto idx = BitCast(d, Dup128VecFromValues(du8, 0, 1, 4, 5, 8, 9, 12, 13,
                                                  2, 3, 6, 7, 10, 11, 14, 15));
  return TableLookupBytes(v, idx);
#endif
}

// 32-bit lanes: per 128-bit block, reorder lanes (0,1,2,3) -> (0,2,1,3),
// i.e. both even lanes followed by both odd lanes.
template <class D, HWY_IF_T_SIZE_D(D, 4), HWY_IF_V_SIZE_GT_D(D, 8)>
static HWY_INLINE HWY_MAYBE_UNUSED Vec<D> BlockwiseConcatOddEven(D /*d*/,
                                                                 Vec<D> v) {
  return Per4LaneBlockShuffle<3, 1, 2, 0>(v);
}
#endif // HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV

// 64-bit lanes: each 128-bit block holds exactly one even and one odd lane,
// so "evens then odds" is already the identity ordering.
template <class D, HWY_IF_T_SIZE_D(D, 8), HWY_IF_V_SIZE_GT_D(D, 8)>
static HWY_INLINE HWY_MAYBE_UNUSED Vec<D> BlockwiseConcatOddEven(D /*d*/,
                                                                 Vec<D> v) {
  return v;
}

} // namespace detail

// Pairwise add with output in 128 bit blocks of a and b: PairwiseAdd puts the
// sums of a in even lanes and of b in odd lanes; BlockwiseConcatOddEven then
// moves the a-sums to the lower 64 bits and the b-sums to the upper 64 bits of
// each 128-bit block.
template <class D, HWY_IF_PAIRWISE_ADD_128_D(D)>
HWY_API Vec<D> PairwiseAdd128(D d, Vec<D> a, Vec<D> b) {
  return detail::BlockwiseConcatOddEven(d, PairwiseAdd(d, a, b));
}

// Pairwise sub with output in 128 bit blocks of a and b: the differences of a
// land in the lower 64 bits and those of b in the upper 64 bits of each
// 128-bit block (see PairwiseAdd128).
template <class D, HWY_IF_PAIRWISE_SUB_128_D(D)>
HWY_API Vec<D> PairwiseSub128(D d, Vec<D> a, Vec<D> b) {
  return detail::BlockwiseConcatOddEven(d, PairwiseSub(d, a, b));
}

#endif

// ------------------------------ Blocks

template <class D>
Expand Down
6 changes: 6 additions & 0 deletions hwy/ops/shared-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,12 @@ HWY_API bool IsAligned(D d, T* ptr) {
#define HWY_IF_MULADDSUB_V(V) \
HWY_IF_LANES_GT_D(hwy::HWY_NAMESPACE::DFromV<V>, 1)

#undef HWY_IF_PAIRWISE_ADD_128_D
#define HWY_IF_PAIRWISE_ADD_128_D(D) HWY_IF_V_SIZE_GT_D(D, 8)

#undef HWY_IF_PAIRWISE_SUB_128_D
#define HWY_IF_PAIRWISE_SUB_128_D(D) HWY_IF_V_SIZE_GT_D(D, 8)

// HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V is used to disable the default
// implementation of unsigned to signed DemoteTo/ReorderDemote2To in
// generic_ops-inl.h for at least some of the unsigned to signed demotions on
Expand Down
58 changes: 58 additions & 0 deletions hwy/ops/x86_128-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3909,6 +3909,64 @@ HWY_API Vec128<double> AddSub(Vec128<double> a, Vec128<double> b) {
}
#endif // HWY_TARGET <= HWY_SSSE3

// ------------------------------ PairwiseAdd128/PairwiseSub128

// Need to use the default implementation of PairwiseAdd128/PairwiseSub128 in
// generic_ops-inl.h for U8/I8/F16/I64/U64 vectors and 64-byte vectors

#if HWY_TARGET <= HWY_SSSE3

#undef HWY_IF_PAIRWISE_ADD_128_D
#undef HWY_IF_PAIRWISE_SUB_128_D
#define HWY_IF_PAIRWISE_ADD_128_D(D) \
hwy::EnableIf<( \
HWY_MAX_LANES_D(D) > (32 / sizeof(hwy::HWY_NAMESPACE::TFromD<D>)) || \
(HWY_MAX_LANES_D(D) > (8 / sizeof(hwy::HWY_NAMESPACE::TFromD<D>)) && \
!(hwy::IsSameEither<hwy::HWY_NAMESPACE::TFromD<D>, int16_t, \
uint16_t>() || \
sizeof(hwy::HWY_NAMESPACE::TFromD<D>) == 4 || \
hwy::IsSame<hwy::HWY_NAMESPACE::TFromD<D>, double>())))>* = nullptr
#define HWY_IF_PAIRWISE_SUB_128_D(D) HWY_IF_PAIRWISE_ADD_128_D(D)

// 16-bit lanes: _mm_hadd_epi16 yields {a0+a1,...,a6+a7, b0+b1,...,b6+b7},
// exactly the PairwiseAdd128 block layout.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_UI16_D(D)>
HWY_API VFromD<D> PairwiseAdd128(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm_hadd_epi16(a.raw, b.raw)};
}
// 16-bit lanes: _mm_hsub_epi16 computes even - odd per pair; negate (in the
// signed domain) to obtain the odd - even convention of PairwiseSub128.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_UI16_D(D)>
HWY_API VFromD<D> PairwiseSub128(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToSigned<D> di;
  const VFromD<D> hsub{_mm_hsub_epi16(a.raw, b.raw)};
  return BitCast(d, Neg(BitCast(di, hsub)));
}
// 32-bit lanes: _mm_hadd_epi32 yields {a0+a1, a2+a3, b0+b1, b2+b3}.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> PairwiseAdd128(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm_hadd_epi32(a.raw, b.raw)};
}
// 32-bit lanes: _mm_hsub_epi32 computes even - odd per pair; negate (in the
// signed domain) to obtain the odd - even convention of PairwiseSub128.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> PairwiseSub128(D d, VFromD<D> a, VFromD<D> b) {
  const RebindToSigned<D> di;
  const VFromD<D> hsub{_mm_hsub_epi32(a.raw, b.raw)};
  return BitCast(d, Neg(BitCast(di, hsub)));
}
// f32 lanes: _mm_hadd_ps yields {a0+a1, a2+a3, b0+b1, b2+b3}.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F32_D(D)>
HWY_API VFromD<D> PairwiseAdd128(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm_hadd_ps(a.raw, b.raw)};
}
// f32 lanes: _mm_hsub_ps computes even - odd per pair; negate to obtain the
// odd - even convention of PairwiseSub128.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F32_D(D)>
HWY_API VFromD<D> PairwiseSub128(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return Neg(VFromD<D>{_mm_hsub_ps(a.raw, b.raw)});
}
// f64 lanes: _mm_hadd_pd yields {a0+a1, b0+b1}.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F64_D(D)>
HWY_API VFromD<D> PairwiseAdd128(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return VFromD<D>{_mm_hadd_pd(a.raw, b.raw)};
}
// f64 lanes: _mm_hsub_pd computes even - odd; negate to obtain odd - even.
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F64_D(D)>
HWY_API VFromD<D> PairwiseSub128(D /*d*/, VFromD<D> a, VFromD<D> b) {
  return Neg(VFromD<D>{_mm_hsub_pd(a.raw, b.raw)});
}

#endif // HWY_TARGET <= HWY_SSSE3

// ------------------------------ SumsOf8
template <size_t N>
HWY_API Vec128<uint64_t, N / 8> SumsOf8(const Vec128<uint8_t, N> v) {
Expand Down
Loading
Loading