Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added SatWidenMulPairwiseAccumulate and SatWidenMulAccumFixedPoint ops #2055

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,23 @@ All other ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
`a_u[2*i+1]*b_i[2*i+1] + a_u[2*i+0]*b_i[2*i+0]`, saturated to the range of
`TFromD<D>`.

* `DW`: `i32`, `D`: `Rebind<MakeNarrow<TFromD<DW>>, DW>`,
`VW`: `Vec<DW>`, `V`: `Vec<D>` \
<code>Vec&lt;D&gt; **SatWidenMulPairwiseAccumulate**(DW, V a, V b, VW sum)
</code>: widens `a[i]` and `b[i]` to `TFromD<DI>` and computes
`a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] + sum[i]`, saturated to the range of
`TFromD<DW>`.

* `DW`: `i32`, `D`: `Rebind<MakeNarrow<TFromD<DW>>, DW>`,
`VW`: `Vec<DW>`, `V`: `Vec<D>` \
<code>VW **SatWidenMulAccumFixedPoint**(DW, V a, V b, VW sum)**</code>:
First, widens `a` and `b` to `TFromD<DW>`, then adds `a[i] * b[i] * 2` to
`sum[i]`, saturated to the range of `TFromD<DW>`.

If `a[i] == LimitsMin<TFromD<D>>() && b[i] == LimitsMin<TFromD<D>>()`,
it is implementation-defined whether `a[i] * b[i] * 2` is first saturated
to `TFromD<DW>` prior to the addition of `a[i] * b[i] * 2` to `sum[i]`.

* `V`: `{bf,u,i}16`, `D`: `RepartitionToWide<DFromV<V>>`, `VW`: `Vec<D>` \
<code>VW **ReorderWidenMulAccumulate**(D d, V a, V b, VW sum0, VW&
sum1)</code>: widens `a` and `b` to `TFromD<D>`, then adds `a[i] * b[i]` to
Expand Down
29 changes: 29 additions & 0 deletions hwy/ops/arm_neon-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6461,6 +6461,35 @@ HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
return detail::SlideDownLanes(v, amt);
}

// ------------------------------ SatWidenMulAccumFixedPoint

#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#else
#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#endif

template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 16)>
HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 /*di32*/,
VFromD<Rebind<int16_t, DI32>> a,
VFromD<Rebind<int16_t, DI32>> b,
VFromD<DI32> sum) {
return VFromD<DI32>(vqdmlal_s16(sum.raw, a.raw, b.raw));
}

template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_LE_D(DI32, 8)>
HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 di32,
VFromD<Rebind<int16_t, DI32>> a,
VFromD<Rebind<int16_t, DI32>> b,
VFromD<DI32> sum) {
const Full128<TFromD<DI32>> di32_full;
const Rebind<int16_t, decltype(di32_full)> di16_full64;
return ResizeBitCast(
di32, SatWidenMulAccumFixedPoint(di32_full, ResizeBitCast(di16_full64, a),
ResizeBitCast(di16_full64, b),
ResizeBitCast(di32_full, sum)));
}

// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)

#if HWY_NEON_HAVE_F32_TO_BF16C
Expand Down
21 changes: 21 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5695,6 +5695,27 @@ HWY_API svuint32_t WidenMulPairwiseAdd(Simd<uint32_t, N, kPow2> d32,
#endif
}

// ------------------------------ SatWidenMulAccumFixedPoint

#if HWY_SVE_HAVE_2

#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#else
#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 /*di32*/,
VFromD<Rebind<int16_t, DI32>> a,
VFromD<Rebind<int16_t, DI32>> b,
VFromD<DI32> sum) {
return svqdmlalb_s32(sum, detail::ZipLowerSame(a, a),
detail::ZipLowerSame(b, b));
}

#endif // HWY_SVE_HAVE_2

// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)

template <size_t N, int kPow2>
Expand Down
60 changes: 60 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4969,6 +4969,66 @@ HWY_API Vec<DI16> SatWidenMulPairwiseAdd(DI16 di16, VU8 a, VI8 b) {

#endif

// ------------------------------ SatWidenMulPairwiseAccumulate

#if (defined(HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM) == \
defined(HWY_TARGET_TOGGLE))

#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#else
#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SatWidenMulPairwiseAccumulate(
DI32 di32, VFromD<Repartition<int16_t, DI32>> a,
VFromD<Repartition<int16_t, DI32>> b, VFromD<DI32> sum) {
// WidenMulPairwiseAdd(di32, a, b) is okay here as
// a[0]*b[0]+a[1]*b[1] is between -2147418112 and 2147483648 and as
// a[0]*b[0]+a[1]*b[1] can only overflow an int32_t if
// a[0], b[0], a[1], and b[1] are all equal to -32768.

const auto product = WidenMulPairwiseAdd(di32, a, b);

const auto mul_overflow =
VecFromMask(di32, Eq(product, Set(di32, LimitsMin<int32_t>())));

return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)),
Add(product, mul_overflow));
}

#endif // HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM

// ------------------------------ SatWidenMulAccumFixedPoint

#if (defined(HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT) == \
defined(HWY_TARGET_TOGGLE))

#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#else
#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 di32,
VFromD<Rebind<int16_t, DI32>> a,
VFromD<Rebind<int16_t, DI32>> b,
VFromD<DI32> sum) {
const Repartition<int16_t, DI32> dt_i16;

const auto vt_a = ResizeBitCast(dt_i16, a);
const auto vt_b = ResizeBitCast(dt_i16, b);

const auto dup_a = InterleaveWholeLower(dt_i16, vt_a, vt_a);
const auto dup_b = InterleaveWholeLower(dt_i16, vt_b, vt_b);

return SatWidenMulPairwiseAccumulate(di32, dup_a, dup_b, sum);
}

#endif // HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT

// ------------------------------ SumOfMulQuadAccumulate

#if (defined(HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE) == \
Expand Down
18 changes: 18 additions & 0 deletions hwy/ops/ppc_vsx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3518,6 +3518,24 @@ HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) {
return Add(sum0, sum1);
}

// ------------------------------ SatWidenMulPairwiseAccumulate
#if !HWY_S390X_HAVE_Z14

#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#else
#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#endif

template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_LE_D(DI32, 16)>
HWY_API VFromD<DI32> SatWidenMulPairwiseAccumulate(
DI32 /* tag */, VFromD<Repartition<int16_t, DI32>> a,
VFromD<Repartition<int16_t, DI32>> b, VFromD<DI32> sum) {
return VFromD<DI32>{vec_msums(a.raw, b.raw, sum.raw)};
}

#endif // !HWY_S390X_HAVE_Z14

// ------------------------------ SumOfMulQuadAccumulate
#if !HWY_S390X_HAVE_Z14

Expand Down
29 changes: 29 additions & 0 deletions hwy/ops/scalar-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1931,6 +1931,35 @@ HWY_API Vec1<int32_t> WidenMulPairwiseAdd(D32 /* tag */, Vec1<int16_t> a,
return Vec1<int32_t>(a.raw * b.raw);
}

// ------------------------------ SatWidenMulAccumFixedPoint
#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#else
#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 di32,
VFromD<Rebind<int16_t, DI32>> a,
VFromD<Rebind<int16_t, DI32>> b,
VFromD<DI32> sum) {
// Multiplying static_cast<int32_t>(a.raw) by static_cast<int32_t>(b.raw)
// followed by an addition of the product is okay as
// (a.raw * b.raw * 2) is between -2147418112 and 2147483648 and as
// a.raw * b.raw * 2 can only overflow an int32_t if both a.raw and b.raw are
// equal to -32768.

const VFromD<DI32> product(static_cast<int32_t>(a.raw) *
static_cast<int32_t>(b.raw));
const VFromD<DI32> product2 = Add(product, product);

const auto mul_overflow =
VecFromMask(di32, Eq(product2, Set(di32, LimitsMin<int32_t>())));

return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)),
Add(product2, mul_overflow));
}

// ------------------------------ SatWidenMulPairwiseAdd

#ifdef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD
Expand Down
21 changes: 21 additions & 0 deletions hwy/ops/x86_128-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9203,6 +9203,27 @@ HWY_API VFromD<DI16> SatWidenMulPairwiseAdd(

#endif

// ------------------------------ SatWidenMulPairwiseAccumulate

#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#else
#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM
#endif

// Even if N=1, the I16 vectors have at least 2 lanes, hence _mm_dpwssds_epi32
// is safe.
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_LE_D(DI32, 16)>
HWY_API VFromD<DI32> SatWidenMulPairwiseAccumulate(
DI32 /* tag */, VFromD<Repartition<int16_t, DI32>> a,
VFromD<Repartition<int16_t, DI32>> b, VFromD<DI32> sum) {
return VFromD<DI32>{_mm_dpwssds_epi32(sum.raw, a.raw, b.raw)};
}

#endif // HWY_TARGET <= HWY_AVX3_DL

// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ShiftLeft)

// Generic for all vector lengths.
Expand Down
11 changes: 11 additions & 0 deletions hwy/ops/x86_256-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6146,6 +6146,17 @@ HWY_API VFromD<DI16> SatWidenMulPairwiseAdd(
return VFromD<DI16>{_mm256_maddubs_epi16(a.raw, b.raw)};
}

// ------------------------------ SatWidenMulPairwiseAccumulate

#if HWY_TARGET <= HWY_AVX3_DL
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 32)>
HWY_API VFromD<DI32> SatWidenMulPairwiseAccumulate(
DI32 /* tag */, VFromD<Repartition<int16_t, DI32>> a,
VFromD<Repartition<int16_t, DI32>> b, VFromD<DI32> sum) {
return VFromD<DI32>{_mm256_dpwssds_epi32(sum.raw, a.raw, b.raw)};
}
#endif // HWY_TARGET <= HWY_AVX3_DL

// ------------------------------ ReorderWidenMulAccumulate
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> ReorderWidenMulAccumulate(D d, Vec256<int16_t> a,
Expand Down
10 changes: 10 additions & 0 deletions hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7361,6 +7361,16 @@ HWY_API VFromD<DI16> SatWidenMulPairwiseAdd(
return VFromD<DI16>{_mm512_maddubs_epi16(a.raw, b.raw)};
}

// ------------------------------ SatWidenMulPairwiseAccumulate
#if HWY_TARGET <= HWY_AVX3_DL
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 64)>
HWY_API VFromD<DI32> SatWidenMulPairwiseAccumulate(
DI32 /* tag */, VFromD<Repartition<int16_t, DI32>> a,
VFromD<Repartition<int16_t, DI32>> b, VFromD<DI32> sum) {
return VFromD<DI32>{_mm512_dpwssds_epi32(sum.raw, a.raw, b.raw)};
}
#endif // HWY_TARGET <= HWY_AVX3_DL

// ------------------------------ ReorderWidenMulAccumulate
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)>
HWY_API VFromD<D> ReorderWidenMulAccumulate(D d, Vec512<int16_t> a,
Expand Down
Loading
Loading