From 6bf0a5cb5034a7e684dcc3500e841785237ce2dd Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Sun, 7 Apr 2024 19:32:43 +0200
Subject: Adding upstream version 1:115.7.0.

Signed-off-by: Daniel Baumann
---
 third_party/highway/hwy/ops/scalar-inl.h | 1626 ++++++++++++++++++++++++++++++
 1 file changed, 1626 insertions(+)
 create mode 100644 third_party/highway/hwy/ops/scalar-inl.h

diff --git a/third_party/highway/hwy/ops/scalar-inl.h b/third_party/highway/hwy/ops/scalar-inl.h
new file mode 100644
index 0000000000..c28f7b510f
--- /dev/null
+++ b/third_party/highway/hwy/ops/scalar-inl.h
@@ -0,0 +1,1626 @@
+// Copyright 2019 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Single-element vectors and operations.
+// External include guard in highway.h - see comment there.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hwy/base.h"
+#include "hwy/ops/shared-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+// Single instruction, single data.
+template <typename T>
+using Sisd = Simd<T, 1, 0>;
+
+// (Wrapper class required for overloading comparison operators.)
+template <typename T>
+struct Vec1 {
+  using PrivateT = T;                     // only for DFromV
+  static constexpr size_t kPrivateN = 1;  // only for DFromV
+
+  HWY_INLINE Vec1() = default;
+  Vec1(const Vec1&) = default;
+  Vec1& operator=(const Vec1&) = default;
+  HWY_INLINE explicit Vec1(const T t) : raw(t) {}
+
+  HWY_INLINE Vec1& operator*=(const Vec1 other) {
+    return *this = (*this * other);
+  }
+  HWY_INLINE Vec1& operator/=(const Vec1 other) {
+    return *this = (*this / other);
+  }
+  HWY_INLINE Vec1& operator+=(const Vec1 other) {
+    return *this = (*this + other);
+  }
+  HWY_INLINE Vec1& operator-=(const Vec1 other) {
+    return *this = (*this - other);
+  }
+  HWY_INLINE Vec1& operator&=(const Vec1 other) {
+    return *this = (*this & other);
+  }
+  HWY_INLINE Vec1& operator|=(const Vec1 other) {
+    return *this = (*this | other);
+  }
+  HWY_INLINE Vec1& operator^=(const Vec1 other) {
+    return *this = (*this ^ other);
+  }
+
+  T raw;
+};
+
+// 0 or FF..FF, same size as Vec1.
+template <typename T>
+class Mask1 {
+  using Raw = hwy::MakeUnsigned<T>;
+
+ public:
+  static HWY_INLINE Mask1<T> FromBool(bool b) {
+    Mask1<T> mask;
+    mask.bits = b ?
static_cast(~Raw{0}) : 0; + return mask; + } + + Raw bits; +}; + +template +using DFromV = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +template +HWY_API Vec1 BitCast(Sisd /* tag */, Vec1 v) { + static_assert(sizeof(T) <= sizeof(FromT), "Promoting is undefined"); + T to; + CopyBytes(&v.raw, &to); // not same size - ok to shrink + return Vec1(to); +} + +// ------------------------------ Set + +template +HWY_API Vec1 Zero(Sisd /* tag */) { + return Vec1(T(0)); +} + +template +HWY_API Vec1 Set(Sisd /* tag */, const T2 t) { + return Vec1(static_cast(t)); +} + +template +HWY_API Vec1 Undefined(Sisd d) { + return Zero(d); +} + +template +HWY_API Vec1 Iota(const Sisd /* tag */, const T2 first) { + return Vec1(static_cast(first)); +} + +template +using VFromD = decltype(Zero(D())); + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec1 Not(const Vec1 v) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, v).raw))); +} + +// ------------------------------ And + +template +HWY_API Vec1 And(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw & BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator&(const Vec1 a, const Vec1 b) { + return And(a, b); +} + +// ------------------------------ AndNot + +template +HWY_API Vec1 AndNot(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, a).raw & + BitCast(du, b).raw))); +} + +// ------------------------------ Or + +template +HWY_API Vec1 Or(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw | BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator|(const Vec1 a, const Vec1 b) { + return Or(a, b); +} + +// ------------------------------ Xor + +template +HWY_API Vec1 Xor(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw ^ BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator^(const Vec1 a, const Vec1 b) { + return Xor(a, b); +} + +// ------------------------------ Xor3 + +template +HWY_API Vec1 Xor3(Vec1 x1, Vec1 x2, Vec1 x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template +HWY_API Vec1 Or3(Vec1 o1, Vec1 o2, Vec1 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec1 OrAnd(const Vec1 o, const Vec1 a1, const Vec1 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec1 IfVecThenElse(Vec1 mask, Vec1 yes, Vec1 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ CopySign + +template +HWY_API Vec1 CopySign(const Vec1 magn, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Sisd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec1 CopySignToAbs(const Vec1 abs, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Sisd()), sign)); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec1 BroadcastSignBit(const Vec1 v) { + // This is used inside ShiftRight, so we cannot implement in terms of it. + return v.raw < 0 ? 
Vec1(T(-1)) : Vec1(0); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +template +HWY_API Vec1 PopulationCount(Vec1 v) { + return Vec1(static_cast(PopCount(v.raw))); +} + +// ------------------------------ Mask + +template +HWY_API Mask1 RebindMask(Sisd /*tag*/, Mask1 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask1{m.bits}; +} + +// v must be 0 or FF..FF. +template +HWY_API Mask1 MaskFromVec(const Vec1 v) { + Mask1 mask; + CopySameSize(&v, &mask); + return mask; +} + +template +Vec1 VecFromMask(const Mask1 mask) { + Vec1 v; + CopySameSize(&mask, &v); + return v; +} + +template +Vec1 VecFromMask(Sisd /* tag */, const Mask1 mask) { + Vec1 v; + CopySameSize(&mask, &v); + return v; +} + +template +HWY_API Mask1 FirstN(Sisd /*tag*/, size_t n) { + return Mask1::FromBool(n != 0); +} + +// Returns mask ? yes : no. +template +HWY_API Vec1 IfThenElse(const Mask1 mask, const Vec1 yes, + const Vec1 no) { + return mask.bits ? yes : no; +} + +template +HWY_API Vec1 IfThenElseZero(const Mask1 mask, const Vec1 yes) { + return mask.bits ? yes : Vec1(0); +} + +template +HWY_API Vec1 IfThenZeroElse(const Mask1 mask, const Vec1 no) { + return mask.bits ? Vec1(0) : no; +} + +template +HWY_API Vec1 IfNegativeThenElse(Vec1 v, Vec1 yes, Vec1 no) { + return v.raw < 0 ? yes : no; +} + +template +HWY_API Vec1 ZeroIfNegative(const Vec1 v) { + return v.raw < 0 ? Vec1(0) : v; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask1 Not(const Mask1 m) { + return MaskFromVec(Not(VecFromMask(Sisd(), m))); +} + +template +HWY_API Mask1 And(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 AndNot(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Or(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Xor(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 ExclusiveNeither(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeft(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1( + static_cast(static_cast>(v.raw) << kBits)); +} + +template +HWY_API Vec1 ShiftRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1(static_cast(v.raw >> kBits)); +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. 
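+    // For example (illustrative values): with T=int8_t, kBits=1 and
+    // v.raw=-6 (bits 0xFA), the unsigned shift below yields 0x7D, the
+    // broadcast sign is 0xFF, and OR-ing in the shifted sign bits (0xC0)
+    // gives 0xFD, i.e. -3, matching an arithmetic shift of -6 by 1.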
+ using TU = hwy::MakeUnsigned; + const Sisd du; + const TU shifted = static_cast(BitCast(du, v).raw >> kBits); + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); + const TU upper = static_cast(sign << sign_shift); + return BitCast(Sisd(), Vec1(shifted | upper)); + } else { // T is unsigned + return Vec1(static_cast(v.raw >> kBits)); + } +#endif +} + +// ------------------------------ RotateRight (ShiftRight) + +namespace detail { + +// For partial specialization: kBits == 0 results in an invalid shift count +template +struct RotateRight { + template + HWY_INLINE Vec1 operator()(const Vec1 v) const { + return Or(ShiftRight(v), ShiftLeft(v)); + } +}; + +template <> +struct RotateRight<0> { + template + HWY_INLINE Vec1 operator()(const Vec1 v) const { + return v; + } +}; + +} // namespace detail + +template +HWY_API Vec1 RotateRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return detail::RotateRight()(v); +} + +// ------------------------------ ShiftLeftSame (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeftSame(const Vec1 v, int bits) { + return Vec1( + static_cast(static_cast>(v.raw) << bits)); +} + +template +HWY_API Vec1 ShiftRightSame(const Vec1 v, int bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1(static_cast(v.raw >> bits)); +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + const Sisd du; + const TU shifted = static_cast(BitCast(du, v).raw >> bits); + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); + const TU upper = static_cast(sign << sign_shift); + return BitCast(Sisd(), Vec1(shifted | upper)); + } else { // T is unsigned + return Vec1(static_cast(v.raw >> bits)); + } +#endif +} + +// ------------------------------ Shl + +// Single-lane => same as ShiftLeftSame except for the argument type. +template +HWY_API Vec1 operator<<(const Vec1 v, const Vec1 bits) { + return ShiftLeftSame(v, static_cast(bits.raw)); +} + +template +HWY_API Vec1 operator>>(const Vec1 v, const Vec1 bits) { + return ShiftRightSame(v, static_cast(bits.raw)); +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec1 operator+(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 + b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} + +template +HWY_API Vec1 operator-(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 - b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} + +// ------------------------------ SumsOf8 + +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. 
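+// For example, SaturatedAdd of uint8_t 200 and 100 yields 255 rather than
+// wrapping to 44, and SaturatedAdd of int8_t -100 and -100 yields -128.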
+ +// Unsigned +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 255))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedAdd(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw + b.raw), 127))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-32768, a.raw + b.raw), 32767))); +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 255))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedSub(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw - b.raw), 127))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-32768, a.raw - b.raw), 32767))); +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +HWY_API Vec1 AverageRound(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +} +HWY_API Vec1 AverageRound(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +} + +// ------------------------------ Absolute value + +template +HWY_API Vec1 Abs(const Vec1 a) { + const T i = a.raw; + if (i >= 0 || i == hwy::LimitsMin()) return a; + return Vec1(static_cast(-i & T{-1})); +} +HWY_API Vec1 Abs(Vec1 a) { + int32_t i; + CopyBytes(&a.raw, &i); + i &= 0x7FFFFFFF; + CopyBytes(&i, &a.raw); + return a; +} +HWY_API Vec1 Abs(Vec1 a) { + int64_t i; + CopyBytes(&a.raw, &i); + i &= 0x7FFFFFFFFFFFFFFFL; + CopyBytes(&i, &a.raw); + return a; +} + +// ------------------------------ Min/Max + +// may be unavailable, so implement our own. 
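+// The helpers below operate on the IEEE-754 bit pattern directly: e.g.
+// -1.5f has bits 0xBFC00000, and clearing the sign bit with 0x7FFFFFFF
+// leaves 0x3FC00000, which is +1.5f.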
+namespace detail { + +static inline float Abs(float f) { + uint32_t i; + CopyBytes<4>(&f, &i); + i &= 0x7FFFFFFFu; + CopyBytes<4>(&i, &f); + return f; +} +static inline double Abs(double f) { + uint64_t i; + CopyBytes<8>(&f, &i); + i &= 0x7FFFFFFFFFFFFFFFull; + CopyBytes<8>(&i, &f); + return f; +} + +static inline bool SignBit(float f) { + uint32_t i; + CopyBytes<4>(&f, &i); + return (i >> 31) != 0; +} +static inline bool SignBit(double f) { + uint64_t i; + CopyBytes<8>(&f, &i); + return (i >> 63) != 0; +} + +} // namespace detail + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + if (isnan(a.raw)) return b; + if (isnan(b.raw)) return a; + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + if (isnan(a.raw)) return b; + if (isnan(b.raw)) return a; + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +// ------------------------------ Floating-point negate + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Xor(v, SignBit(Sisd())); +} + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Zero(Sisd()) - v; +} + +// ------------------------------ mul/div + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(double{a.raw} * b.raw)); +} + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(static_cast(a.raw) * + static_cast(b.raw))); +} + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(static_cast(a.raw) * + static_cast(b.raw))); +} + +template +HWY_API Vec1 operator/(const Vec1 a, const Vec1 b) { + return Vec1(a.raw / b.raw); +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((a.raw * b.raw) >> 16)); +} +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + // Cast to uint32_t first to prevent overflow. Otherwise the result of + // uint16_t * uint16_t is in "int" which may overflow. In practice the result + // is the same but this way it is also defined. + return Vec1(static_cast( + (static_cast(a.raw) * static_cast(b.raw)) >> 16)); +} + +HWY_API Vec1 MulFixedPoint15(Vec1 a, Vec1 b) { + return Vec1(static_cast((2 * a.raw * b.raw + 32768) >> 16)); +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +HWY_API Vec1 MulEven(const Vec1 a, const Vec1 b) { + const int64_t a64 = a.raw; + return Vec1(a64 * b.raw); +} +HWY_API Vec1 MulEven(const Vec1 a, const Vec1 b) { + const uint64_t a64 = a.raw; + return Vec1(a64 * b.raw); +} + +// Approximate reciprocal +HWY_API Vec1 ApproximateReciprocal(const Vec1 v) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The return value is arbitrary. + if (v.raw == 0.0f) return Vec1(0.0f); + return Vec1(1.0f / v.raw); +} + +// Absolute value of difference. 
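+// For example, AbsDiff of 3.0f and 5.0f returns 2.0f.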
+HWY_API Vec1 AbsDiff(const Vec1 a, const Vec1 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec1 MulAdd(const Vec1 mul, const Vec1 x, const Vec1 add) { + return mul * x + add; +} + +template +HWY_API Vec1 NegMulAdd(const Vec1 mul, const Vec1 x, + const Vec1 add) { + return add - mul * x; +} + +template +HWY_API Vec1 MulSub(const Vec1 mul, const Vec1 x, const Vec1 sub) { + return mul * x - sub; +} + +template +HWY_API Vec1 NegMulSub(const Vec1 mul, const Vec1 x, + const Vec1 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Approximate reciprocal square root +HWY_API Vec1 ApproximateReciprocalSqrt(const Vec1 v) { + float f = v.raw; + const float half = f * 0.5f; + uint32_t bits; + CopySameSize(&f, &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &f); + // One Newton-Raphson iteration + return Vec1(f * (1.5f - (half * f * f))); +} + +// Square root +HWY_API Vec1 Sqrt(const Vec1 v) { +#if HWY_COMPILER_GCC && defined(HWY_NO_LIBCXX) + return Vec1(__builtin_sqrt(v.raw)); +#else + return Vec1(sqrtf(v.raw)); +#endif +} +HWY_API Vec1 Sqrt(const Vec1 v) { +#if HWY_COMPILER_GCC && defined(HWY_NO_LIBCXX) + return Vec1(__builtin_sqrt(v.raw)); +#else + return Vec1(sqrt(v.raw)); +#endif +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec1 Round(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw < MantissaEnd())) { // Huge or NaN + return v; + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw + bias); + if (rounded == 0) return CopySignToAbs(Vec1(0), v); + // Round to even + if ((rounded & 1) && detail::Abs(static_cast(rounded) - v.raw) == T(0.5)) { + return Vec1(static_cast(rounded - (v.raw < T(0) ? -1 : 1))); + } + return Vec1(static_cast(rounded)); +} + +// Round-to-nearest even. +HWY_API Vec1 NearestInt(const Vec1 v) { + using T = float; + using TI = int32_t; + + const T abs = Abs(v).raw; + const bool is_sign = detail::SignBit(v.raw); + + if (!(abs < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs <= static_cast(LimitsMax()))) { + return Vec1(is_sign ? LimitsMin() : LimitsMax()); + } + return Vec1(static_cast(v.raw)); + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw + bias); + if (rounded == 0) return Vec1(0); + // Round to even + if ((rounded & 1) && detail::Abs(static_cast(rounded) - v.raw) == T(0.5)) { + return Vec1(rounded - (is_sign ? -1 : 1)); + } + return Vec1(rounded); +} + +template +HWY_API Vec1 Trunc(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw <= MantissaEnd())) { // Huge or NaN + return v; + } + const TI truncated = static_cast(v.raw); + if (truncated == 0) return CopySignToAbs(Vec1(0), v); + return Vec1(static_cast(truncated)); +} + +template +V Ceiling(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool positive = f > Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => 0 or 1. + if (exponent < 0) return positive ? 
V(1) : V(-0.0); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +template +V Floor(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool negative = f < Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => -1 or 0. + if (exponent < 0) return V(negative ? Float(-1.0) : Float(0.0)); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +// Toward +infinity, aka ceiling +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} + +// Toward -infinity, aka floor +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} + +// ================================================== COMPARE + +template +HWY_API Mask1 operator==(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw == b.raw); +} + +template +HWY_API Mask1 operator!=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw != b.raw); +} + +template +HWY_API Mask1 TestBit(const Vec1 v, const Vec1 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask1 operator<(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw < b.raw); +} +template +HWY_API Mask1 operator>(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw > b.raw); +} + +template +HWY_API Mask1 operator<=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw <= b.raw); +} +template +HWY_API Mask1 operator>=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw >= b.raw); +} + +// ------------------------------ Floating-point classification (==) + +template +HWY_API Mask1 IsNaN(const Vec1 v) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + MakeUnsigned bits; + CopySameSize(&v, &bits); + bits += bits; + bits >>= 1; // clear sign bit + // NaN if all exponent bits are set and the mantissa is not zero. + return Mask1::FromBool(bits > ExponentMask()); +} + +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFF000000u)); +} +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFFE0000000000000ull)); +} + +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. 
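+  // After the shift, the 8 exponent bits occupy bits [31:24]; 0xFF000000u is
+  // the smallest shifted value whose exponent field is all ones (Inf or NaN),
+  // so every finite input compares strictly below it.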
+ return Mask1::FromBool((vu.raw << 1) < 0xFF000000u); +} +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1::FromBool((vu.raw << 1) < 0xFFE0000000000000ull); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec1 Load(Sisd /* tag */, const T* HWY_RESTRICT aligned) { + T t; + CopySameSize(aligned, &t); + return Vec1(t); +} + +template +HWY_API Vec1 MaskedLoad(Mask1 m, Sisd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template +HWY_API Vec1 LoadU(Sisd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. +template +HWY_API Vec1 LoadDup128(Sisd d, const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +// ------------------------------ Store + +template +HWY_API void Store(const Vec1 v, Sisd /* tag */, + T* HWY_RESTRICT aligned) { + CopySameSize(&v.raw, aligned); +} + +template +HWY_API void StoreU(const Vec1 v, Sisd d, T* HWY_RESTRICT p) { + return Store(v, d, p); +} + +template +HWY_API void BlendedStore(const Vec1 v, Mask1 m, Sisd d, + T* HWY_RESTRICT p) { + if (!m.bits) return; + StoreU(v, d, p); +} + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining StoreInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template +HWY_API void LoadInterleaved2(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +template +HWY_API void LoadInterleaved3(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1, Vec1& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +template +HWY_API void LoadInterleaved4(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1, Vec1& v2, + Vec1& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template +HWY_API void StoreInterleaved2(const Vec1 v0, const Vec1 v1, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); +} + +template +HWY_API void StoreInterleaved3(const Vec1 v0, const Vec1 v1, + const Vec1 v2, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +template +HWY_API void StoreInterleaved4(const Vec1 v0, const Vec1 v1, + const Vec1 v2, const Vec1 v3, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); + StoreU(v3, d, unaligned + 3); +} + +// ------------------------------ Stream + +template +HWY_API void Stream(const Vec1 v, Sisd d, T* HWY_RESTRICT aligned) { + return Store(v, d, aligned); +} + +// ------------------------------ Scatter + +template +HWY_API void ScatterOffset(Vec1 v, Sisd d, T* base, + const Vec1 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + uint8_t* const base8 = reinterpret_cast(base) + offset.raw; + return Store(v, d, reinterpret_cast(base8)); +} + +template +HWY_API void 
ScatterIndex(Vec1 v, Sisd d, T* HWY_RESTRICT base, + const Vec1 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Store(v, d, base + index.raw); +} + +// ------------------------------ Gather + +template +HWY_API Vec1 GatherOffset(Sisd d, const T* base, + const Vec1 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + const intptr_t addr = + reinterpret_cast(base) + static_cast(offset.raw); + return Load(d, reinterpret_cast(addr)); +} + +template +HWY_API Vec1 GatherIndex(Sisd d, const T* HWY_RESTRICT base, + const Vec1 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Load(d, base + index.raw); +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +template +HWY_API Vec1 PromoteTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); + // For bits Y > X, floatX->floatY and intX->intY are always representable. + return Vec1(static_cast(from.raw)); +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, +// so we overload for FromT=double and ToT={float,int32_t}. +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + // Prevent ubsan errors when converting float to narrower integer/float + if (IsInf(from).bits || + Abs(from).raw > static_cast(HighestValue())) { + return Vec1(detail::SignBit(from.raw) ? LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + if (IsInf(from).bits || + Abs(from).raw > static_cast(HighestValue())) { + return Vec1(detail::SignBit(from.raw) ? LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} + +template +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + static_assert(!IsFloat(), "FromT=double are handled above"); + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw = HWY_MIN(HWY_MAX(LimitsMin(), from.raw), LimitsMax()); + return Vec1(static_cast(from.raw)); +} + +HWY_API Vec1 PromoteTo(Sisd /* tag */, const Vec1 v) { + uint16_t bits16; + CopySameSize(&v.raw, &bits16); + const uint32_t sign = static_cast(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + return Vec1(sign ? -subnormal : subnormal); + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + float out; + CopySameSize(&bits32, &out); + return Vec1(out); +} + +HWY_API Vec1 PromoteTo(Sisd d, const Vec1 v) { + return Set(d, F32FromBF16(v.raw)); +} + +HWY_API Vec1 DemoteTo(Sisd /* tag */, + const Vec1 v) { + uint32_t bits32; + CopySameSize(&v.raw, &bits32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = HWY_MIN(static_cast(biased_exp32) - 127, 15); + + // Tiny or zero => zero. 
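+  // (2^-24 is the smallest positive binary16 subnormal, so inputs whose
+  // float32 exponent is below -24 are flushed to zero here.)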
+ Vec1 out; + if (exp < -24) { + const uint16_t zero = 0; + CopySameSize(&zero, &out.raw); + return out; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (exp < -14) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast(-14 - exp); + HWY_DASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = static_cast((1u << (10 - sub_exp)) + + (mantissa32 >> (13 + sub_exp))); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast(exp + 15); + HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + HWY_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast(bits16); // big-endian safe + CopySameSize(&narrowed, &out.raw); + return out; +} + +HWY_API Vec1 DemoteTo(Sisd d, const Vec1 v) { + return Set(d, BF16FromF32(v.raw)); +} + +template +HWY_API Vec1 ConvertTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // float## -> int##: return closest representable value. We cannot exactly + // represent LimitsMax in FromT, so use double. + const double f = static_cast(from.raw); + if (IsInf(from).bits || + Abs(Vec1(f)).raw > static_cast(LimitsMax())) { + return Vec1(detail::SignBit(from.raw) ? LimitsMin() + : LimitsMax()); + } + return Vec1(static_cast(from.raw)); +} + +template +HWY_API Vec1 ConvertTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // int## -> float##: no check needed + return Vec1(static_cast(from.raw)); +} + +HWY_API Vec1 U8FromU32(const Vec1 v) { + return DemoteTo(Sisd(), v); +} + +// ------------------------------ Truncations + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFFFFFFu)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +// ================================================== COMBINE +// UpperHalf, ZeroExtendVector, Combine, Concat* are unsupported. + +template +HWY_API Vec1 LowerHalf(Vec1 v) { + return v; +} + +template +HWY_API Vec1 LowerHalf(Sisd /* tag */, Vec1 v) { + return v; +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(const Vec1 v) { + return v.raw; +} + +template +HWY_API T ExtractLane(const Vec1 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return v.raw; +} + +template +HWY_API Vec1 InsertLane(Vec1 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + v.raw = t; + return v; +} + +template +HWY_API Vec1 DupEven(Vec1 v) { + return v; +} +// DupOdd is unsupported. + +template +HWY_API Vec1 OddEven(Vec1 /* odd */, Vec1 even) { + return even; +} + +template +HWY_API Vec1 OddEvenBlocks(Vec1 /* odd */, Vec1 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec1 SwapAdjacentBlocks(Vec1 v) { + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
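+// With a single lane, the only valid index is 0, so TableLookupLanes below is
+// the identity.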
+template +struct Indices1 { + MakeSigned raw; +}; + +template +HWY_API Indices1 IndicesFromVec(Sisd, Vec1 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + HWY_DASSERT(vec.raw == 0); + return Indices1{vec.raw}; +} + +template +HWY_API Indices1 SetTableIndices(Sisd d, const TI* idx) { + return IndicesFromVec(d, LoadU(Sisd(), idx)); +} + +template +HWY_API Vec1 TableLookupLanes(const Vec1 v, const Indices1 /* idx */) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec1 ReverseBlocks(Sisd /* tag */, const Vec1 v) { + return v; +} + +// ------------------------------ Reverse + +template +HWY_API Vec1 Reverse(Sisd /* tag */, const Vec1 v) { + return v; +} + +// Must not be called: +template +HWY_API Vec1 Reverse2(Sisd /* tag */, const Vec1 v) { + return v; +} + +template +HWY_API Vec1 Reverse4(Sisd /* tag */, const Vec1 v) { + return v; +} + +template +HWY_API Vec1 Reverse8(Sisd /* tag */, const Vec1 v) { + return v; +} + +// ================================================== BLOCKWISE +// Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle* are unsupported. + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec1 Broadcast(const Vec1 v) { + static_assert(kLane == 0, "Scalar only has one lane"); + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec1 TableLookupBytes(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +template +HWY_API Vec1 TableLookupBytesOr0(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = idx_bytes[i] & 0x80 ? 0 : in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +// ------------------------------ ZipLower + +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((uint32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((uint32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((uint64_t{b.raw} << 32) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((int32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1((int32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1((int64_t{b.raw} << 32) + a.raw); +} + +template , class VW = Vec1> +HWY_API VW ZipLower(Sisd /* tag */, Vec1 a, Vec1 b) { + return VW(static_cast((TW{b.raw} << (sizeof(T) * 8)) + a.raw)); +} + +// ================================================== MASK + +template +HWY_API bool AllFalse(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0; +} + +template +HWY_API bool AllTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits != 0; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
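+// Only bit 0 of bits[0] is consulted, since a scalar vector has one lane.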
+template +HWY_API Mask1 LoadMaskBits(Sisd /* tag */, + const uint8_t* HWY_RESTRICT bits) { + return Mask1::FromBool((bits[0] & 1) != 0); +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(Sisd d, const Mask1 mask, uint8_t* bits) { + *bits = AllTrue(d, mask); + return 1; +} + +template +HWY_API size_t CountTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0 ? 0 : 1; +} + +template +HWY_API intptr_t FindFirstTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0 ? -1 : 0; +} + +template +HWY_API size_t FindKnownFirstTrue(Sisd /* tag */, const Mask1 /* m */) { + return 0; // There is only one lane and we know it is true. +} + +// ------------------------------ Compress, CompressBits + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +template +HWY_API Vec1 Compress(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +template +HWY_API Vec1 CompressNot(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec1 v, const Mask1 mask, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec1 v, const Mask1 mask, Sisd d, + T* HWY_RESTRICT unaligned) { + if (!mask.bits) return 0; + StoreU(v, d, unaligned); + return 1; +} + +// ------------------------------ CompressBits +template +HWY_API Vec1 CompressBits(Vec1 v, const uint8_t* HWY_RESTRICT /*bits*/) { + return v; +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(Vec1 v, const uint8_t* HWY_RESTRICT bits, + Sisd d, T* HWY_RESTRICT unaligned) { + const Mask1 mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +HWY_API Vec1 ReorderWidenMulAccumulate(Sisd /* tag */, + Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return MulAdd(Vec1(F32FromBF16(a.raw)), + Vec1(F32FromBF16(b.raw)), sum0); +} + +HWY_API Vec1 ReorderWidenMulAccumulate(Sisd /* tag */, + Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return Vec1(a.raw * b.raw + sum0.raw); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec1 RearrangeToOddPlusEven(const Vec1 sum0, + Vec1 /* sum1 */) { + return sum0; // invariant already holds +} + +// ================================================== REDUCTIONS + +// Sum of all lanes, i.e. the only one. +template +HWY_API Vec1 SumOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} +template +HWY_API Vec1 MinOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} +template +HWY_API Vec1 MaxOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); -- cgit v1.2.3