Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AddLower, PairwiseAdd/Sub and MaskedAbsOr operations #2405

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,9 @@ from left to right, of the arguments passed to `Create{2-4}`.
than `OddEven(Add(a, b), Sub(a, b))` or `Add(a, OddEven(b, Neg(b)))` on some
targets.

* <code>V **AddLower**(V a, V b)</code>: returns `a[0] + b[0]`
and `a[i]` in all other lanes.

* `V`: `{i,f}` \
<code>V **Neg**(V a)</code>: returns `-a[i]`.

Expand All @@ -548,6 +551,24 @@ from left to right, of the arguments passed to `Create{2-4}`.

* <code>V **AbsDiff**(V a, V b)</code>: returns `|a[i] - b[i]|` in each lane.

* <code>V **PairwiseAdd**(D d, V a, V b)</code>: Add consecutive pairs of elements.
Return the results of a and b interleaved, such that `r[i] = a[i] + a[i+1]` for
even lanes and `r[i] = b[i-1] + b[i]` for odd lanes.

* <code>V **PairwiseSub**(D d, V a, V b)</code>: Subtract consecutive pairs of elements.
Return the results of a and b interleaved, such that `r[i] = a[i+1] - a[i]` for
even lanes and `r[i] = b[i] - b[i-1]` for odd lanes.

* <code>V **PairwiseAdd128**(D d, V a, V b)</code>: Add consecutive pairs of
* elements in a and b, and pack results in 128 bit blocks, such that
`r[i] = a[i] + a[i+1]` for 64 bits, followed by `b[i] + b[i+1]` for next 64
bits and repeated.

* <code>V **PairwiseSub128**(D d, V a, V b)</code>: Subtract consecutive pairs
of elements in a and b, and pack results in 128 bit blocks, such that
`r[i] = a[i] + a[i+1]` for 64 bits, followed by `b[i] + b[i+1]` for next 64
bits and repeated.

* `V`: `{i,u}{8,16,32},f{16,32}`, `VW`: `Vec<RepartitionToWide<DFromV<V>>>` \
<code>VW **SumsOf2**(V v)</code>
returns the sums of 2 consecutive lanes, promoting each sum into a lane of
Expand Down Expand Up @@ -886,6 +907,18 @@ not a concern, these are equivalent to, and potentially more efficient than,
<code>V **MaskedSatSubOr**(V no, M m, V a, V b)</code>: returns `a[i] +
b[i]` saturated to the minimum/maximum representable value, or `no[i]` if
`m[i]` is false.
* `V`: `{i,f}` \
<code>V **MaskedAbsOr**(M m, V a, V b)</code>: returns the absolute value of
`a[i]` where m is active and returns `b[i]` otherwise.

#### Zero masked arithmetic

Ops in this section return `0` for `mask=false` lanes. These are equivalent
to, and potentially more efficient than, `IfThenElseZero(m, Abs(a, b));` etc.

* `V`: `{i,f}` \
<code>V **MaskedAbsOrZero**(M m, V a)</code>: returns the absolute value of
`a[i]` where m is active and returns zero otherwise.

#### Shifts

Expand Down
2 changes: 2 additions & 0 deletions hwy/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,8 @@ using RemovePtr = typename RemovePtrT<T>::type;
hwy::EnableIf<kN * sizeof(T) <= bytes>* = nullptr
#define HWY_IF_V_SIZE_GT(T, kN, bytes) \
hwy::EnableIf<(kN * sizeof(T) > bytes)>* = nullptr
#define HWY_IF_V_SIZE_GE(T, kN, bytes) \
hwy::EnableIf<(kN * sizeof(T) >= bytes)>* = nullptr

#define HWY_IF_LANES(kN, lanes) hwy::EnableIf<(kN == lanes)>* = nullptr
#define HWY_IF_LANES_LE(kN, lanes) hwy::EnableIf<(kN <= lanes)>* = nullptr
Expand Down
63 changes: 63 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_m(b, m, a); \
}
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
return sv##OP##_##CHAR##BITS##_z(m, a); \
}

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \
Expand Down Expand Up @@ -862,6 +871,12 @@ HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)
HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs)
#endif // HWY_SVE_HAVE_2

// ------------------------------ MaskedAbsOr
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs)

// ------------------------------ MaskedAbsOrZero
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbsOrZero, abs)

// ================================================== ARITHMETIC

// Per-target flags to prevent generic_ops-inl.h defining Add etc.
Expand Down Expand Up @@ -4756,6 +4771,21 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) {
return IfThenElse(IsNegative(v), yes, no);
}

// ------------------------------ AddLower

#ifdef HWY_NATIVE_ADD_LOWER
#undef HWY_NATIVE_ADD_LOWER
#endif

#define HWY_NATIVE_ADD_LOWER(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_m(svptrue_pat_b##BITS(SV_VL1), a, b); \
}

HWY_SVE_FOREACH(HWY_NATIVE_ADD_LOWER, AddLower, add)
#undef HWY_NATIVE_ADD_LOWER

// ------------------------------ AverageRound (ShiftRight)

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
Expand Down Expand Up @@ -5813,6 +5843,38 @@ HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) {
return detail::InterleaveOdd(lo, hi);
}

// ------------------------------ PairwiseAdd/PairwiseSub
#if HWY_TARGET != HWY_SCALAR
#if HWY_SVE_HAVE_2 || HWY_IDE

#ifdef HWY_NATIVE_PAIRWISE_ADD
#undef HWY_NATIVE_PAIRWISE_ADD
#else
#define HWY_NATIVE_PAIRWISE_ADD
#endif

namespace detail {
#define HWY_SVE_SV_PAIRWISE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, HWY_SVE_V(BASE, BITS) a, \
HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_m(HWY_SVE_PTRUE(BITS), a, b); \
}

HWY_SVE_FOREACH(HWY_SVE_SV_PAIRWISE_ADD, PairwiseAdd, addp)
#undef HWY_SVE_SV_PAIRWISE_ADD
} // namespace detail

// Pairwise add returning interleaved output of a and b
template <class D, class V, HWY_IF_LANES_GT_D(D, 1)>
HWY_API V PairwiseAdd(D d, V a, V b) {
return detail::PairwiseAdd(d, a, b);
}

#endif // HWY_SVE_HAVE_2
#endif // HWY_TARGET != HWY_SCALAR

// ------------------------------ WidenMulPairwiseAdd

template <size_t N, int kPow2>
Expand Down Expand Up @@ -6296,6 +6358,7 @@ HWY_API V HighestSetBitIndex(V v) {
#undef HWY_SVE_RETV_ARGPVV
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGMV_M
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_RETV_ARGVVV
#undef HWY_SVE_T
Expand Down
105 changes: 105 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,18 @@ HWY_API V SaturatedAbs(V v) {

#endif

// ------------------------------ MaskedAbsOr
template <class V, HWY_IF_SIGNED_V(V), class M>
HWY_API V MaskedAbsOr(M m, V v, V no) {
return IfThenElse(m, Abs(v), no);
}

// ------------------------------ MaskedAbsOrZero
template <class V, HWY_IF_SIGNED_V(V), class M>
HWY_API V MaskedAbsOrZero(M m, V v) {
return IfThenElseZero(m, Abs(v));
}

// ------------------------------ Reductions

// Targets follow one of two strategies. If HWY_NATIVE_REDUCE_SCALAR is toggled,
Expand Down Expand Up @@ -970,6 +982,22 @@ HWY_API VFromD<RebindToSigned<DFromV<V>>> FloorInt(V v) {

#endif // HWY_NATIVE_CEIL_FLOOR_INT

#if (defined(HWY_NATIVE_ADD_LOWER) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ Addlower
#ifdef HWY_NATIVE_ADD_LOWER
#undef HWY_NATIVE_ADD_LOWER
#else
#define HWY_NATIVE_ADD_LOWER
#endif
template <class V>
HWY_API V AddLower(V a, V b) {
const DFromV<V> d;
const MFromD<DFromV<V>> LowerMask = FirstN(d, 1);
return IfThenElse(LowerMask, Add(a, b), a);
}
#endif

// ------------------------------ MulByPow2/MulByFloorPow2

#if (defined(HWY_NATIVE_MUL_BY_POW2) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -1991,6 +2019,83 @@ HWY_API void StoreInterleaved4(VFromD<D> part0, VFromD<D> part1,

#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED

// ------------------------------ PairwiseAdd/PairwiseSub
#if (defined(HWY_NATIVE_PAIRWISE_ADD) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_PAIRWISE_ADD
#undef HWY_NATIVE_PAIRWISE_ADD
#else
#define HWY_NATIVE_PAIRWISE_ADD
#endif

template <class D, class V = VFromD<D>(), HWY_IF_LANES_GT_D(D, 1)>
HWY_API V PairwiseAdd(D d, V a, V b) {
return Add(InterleaveEven(d, a, b), InterleaveOdd(d, a, b));
}

#endif

#if (defined(HWY_NATIVE_PAIRWISE_SUB) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_PAIRWISE_SUB
#undef HWY_NATIVE_PAIRWISE_SUB
#else
#define HWY_NATIVE_PAIRWISE_SUB
#endif

template <class D, class V = VFromD<D>(), HWY_IF_LANES_GT_D(D, 1)>
HWY_API V PairwiseSub(D d, V a, V b) {
return Sub(InterleaveOdd(d, a, b), InterleaveEven(d, a, b));
}

#endif

#if (defined(HWY_NATIVE_PAIRWISE_ADD_128) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_PAIRWISE_ADD_128
#undef HWY_NATIVE_PAIRWISE_ADD_128
#else
#define HWY_NATIVE_PAIRWISE_ADD_128
#endif

template <class D>
using IndicesFromD = decltype(IndicesFromVec(D(), Zero(RebindToUnsigned<D>())));

// Generate indices to convert
// a[0]+a[1], b[0]+b[1], a[2]+a[3], b[2]+b[3] and so on
// to
// Interleaved 64 bits of a[0]+a[1], a[2]+a[3], ...
// and 64 bits of b[0]+b[1], b[2]+b[3], ... and so on
template <typename V, typename D = DFromV<V>, typename T = TFromD<D>,
const size_t N = HWY_LANES(T)>
constexpr IndicesFromD<D> Pairwise128Indices(D d) {
const size_t block_len = 8 / sizeof(T);
const size_t n_blocks = N / block_len;
TFromD<RebindToUnsigned<D>> indices[N] = {0};

TFromD<RebindToUnsigned<D>> even = 0, odd = 1;
for (size_t block = 0; block < n_blocks; block += 2) {
for (size_t index = 0; index < block_len; ++index, even += 2) {
indices[block * block_len + index] = even;
}
for (size_t index = 0; index < block_len; ++index, odd += 2) {
indices[(block + 1) * block_len + index] = odd;
}
}
return SetTableIndices(d, indices);
}

// Pairwise add with output in 128 bit blocks of a and b.
template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_GE_D(D, 16)>
HWY_API V PairwiseAdd128(D d, V a, V b) {
return TableLookupLanes(PairwiseAdd(d, a, b), Pairwise128Indices<V>(d));
}

// Pairwise sub with output in 128 bit blocks of a and b.
template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_GE_D(D, 16)>
HWY_API V PairwiseSub128(D d, V a, V b) {
return TableLookupLanes(PairwiseSub(d, a, b), Pairwise128Indices<V>(d));
}

#endif

// Load/StoreInterleaved for special floats. Requires HWY_GENERIC_IF_EMULATED_D
// is defined such that it is true only for types that actually require these
// generic implementations.
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/shared-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,8 @@ HWY_API bool IsAligned(D d, T* ptr) {
HWY_IF_V_SIZE_LE(hwy::HWY_NAMESPACE::TFromD<D>, HWY_MAX_LANES_D(D), bytes)
#define HWY_IF_V_SIZE_GT_D(D, bytes) \
HWY_IF_V_SIZE_GT(hwy::HWY_NAMESPACE::TFromD<D>, HWY_MAX_LANES_D(D), bytes)
#define HWY_IF_V_SIZE_GE_D(D, bytes) \
HWY_IF_V_SIZE_GE(hwy::HWY_NAMESPACE::TFromD<D>, HWY_MAX_LANES_D(D), bytes)

// Same, but with a vector argument. ops/*-inl.h define their own TFromV.
#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(hwy::HWY_NAMESPACE::TFromV<V>)
Expand Down
Loading
Loading