// Copyright 2022 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Single-element vectors and operations.
// External include guard in highway.h - see comment there.

#include <stddef.h>
#include <stdint.h>

#include <cmath>  // std::abs, std::isnan

#include "hwy/base.h"
#include "hwy/ops/shared-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <typename T>
using Full128 = Simd<T, 16 / sizeof(T), 0>;

// (Wrapper class required for overloading comparison operators.)
template <typename T, size_t N = 16 / sizeof(T)>
struct Vec128 {
  using PrivateT = T;                     // only for DFromV
  static constexpr size_t kPrivateN = N;  // only for DFromV

  HWY_INLINE Vec128() = default;
  Vec128(const Vec128&) = default;
  Vec128& operator=(const Vec128&) = default;

  HWY_INLINE Vec128& operator*=(const Vec128 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec128& operator/=(const Vec128 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec128& operator+=(const Vec128 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec128& operator-=(const Vec128 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec128& operator&=(const Vec128 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec128& operator|=(const Vec128 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec128& operator^=(const Vec128 other) {
    return *this = (*this ^ other);
  }

  // Behave like wasm128 (vectors can always hold 128 bits). generic_ops-inl.h
  // relies on this for LoadInterleaved*. CAVEAT: this method of padding
  // prevents using range for, especially in SumOfLanes, where it would be
  // incorrect. Moving padding to another field would require handling the case
  // where N = 16 / sizeof(T) (i.e. there is no padding), which is also
  // awkward.
  T raw[16 / sizeof(T)] = {};
};
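// Illustrative sketch (comment only, not part of the API surface): the
// descriptor Simd<T, N, 0> selects the lane type and count, and the ops
// defined below operate lane-by-lane on `raw`. For example, assuming a
// 4-lane float descriptor named `d` (placeholder name):
//   const Simd<float, 4, 0> d;
//   Vec128<float, 4> v = Zero(d);  // all lanes (and padding bytes) are zero
//   v = Set(d, 1.5f);              // broadcasts 1.5f into the 4 lanes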
// 0 or FF..FF, same size as Vec128.
template <typename T, size_t N = 16 / sizeof(T)>
struct Mask128 {
  using Raw = hwy::MakeUnsigned<T>;
  static HWY_INLINE Raw FromBool(bool b) {
    return b ? static_cast<Raw>(~Raw{0}) : 0;
  }

  // Must match the size of Vec128.
  Raw bits[16 / sizeof(T)] = {};
};

template <class V>
using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>;

template <class V>
using TFromV = typename V::PrivateT;

// ------------------------------ BitCast

template <typename T, size_t N, typename FromT>
HWY_API Vec128<T, N> BitCast(Simd<T, N, 0> /* tag */,
                             Vec128<FromT, N * sizeof(T) / sizeof(FromT)> v) {
  Vec128<T, N> to;
  CopySameSize(&v, &to);
  return to;
}

// ------------------------------ Set

template <typename T, size_t N>
HWY_API Vec128<T, N> Zero(Simd<T, N, 0> /* tag */) {
  Vec128<T, N> v;
  ZeroBytes<sizeof(T) * N>(v.raw);
  return v;
}

template <class D>
using VFromD = decltype(Zero(D()));

template <typename T, size_t N, typename T2>
HWY_API Vec128<T, N> Set(Simd<T, N, 0> /* tag */, const T2 t) {
  Vec128<T, N> v;
  for (size_t i = 0; i < N; ++i) {
    v.raw[i] = static_cast<T>(t);
  }
  return v;
}

template <typename T, size_t N>
HWY_API Vec128<T, N> Undefined(Simd<T, N, 0> d) {
  return Zero(d);
}

template <typename T, size_t N, typename T2>
HWY_API Vec128<T, N> Iota(const Simd<T, N, 0> /* tag */, T2 first) {
  Vec128<T, N> v;
  for (size_t i = 0; i < N; ++i) {
    v.raw[i] =
        AddWithWraparound(hwy::IsFloatTag<T>(), static_cast<T>(first), i);
  }
  return v;
}

// ================================================== LOGICAL

// ------------------------------ Not

template <typename T, size_t N>
HWY_API Vec128<T, N> Not(const Vec128<T, N> v) {
  const Simd<T, N, 0> d;
  const RebindToUnsigned<decltype(d)> du;
  using TU = TFromD<decltype(du)>;
  VFromD<decltype(du)> vu = BitCast(du, v);
  for (size_t i = 0; i < N; ++i) {
    vu.raw[i] = static_cast<TU>(~vu.raw[i]);
  }
  return BitCast(d, vu);
}

// ------------------------------ And

template <typename T, size_t N>
HWY_API Vec128<T, N> And(const Vec128<T, N> a, const Vec128<T, N> b) {
  const Simd<T, N, 0> d;
  const RebindToUnsigned<decltype(d)> du;
  auto au = BitCast(du, a);
  auto bu = BitCast(du, b);
  for (size_t i = 0; i < N; ++i) {
    au.raw[i] &= bu.raw[i];
  }
  return BitCast(d, au);
}
template <typename T, size_t N>
HWY_API Vec128<T, N> operator&(const Vec128<T, N> a, const Vec128<T, N> b) {
  return And(a, b);
}

// ------------------------------ AndNot

template <typename T, size_t N>
HWY_API Vec128<T, N> AndNot(const Vec128<T, N> a, const Vec128<T, N> b) {
  return And(Not(a), b);
}

// ------------------------------ Or

template <typename T, size_t N>
HWY_API Vec128<T, N> Or(const Vec128<T, N> a, const Vec128<T, N> b) {
  const Simd<T, N, 0> d;
  const RebindToUnsigned<decltype(d)> du;
  auto au = BitCast(du, a);
  auto bu = BitCast(du, b);
  for (size_t i = 0; i < N; ++i) {
    au.raw[i] |= bu.raw[i];
  }
  return BitCast(d, au);
}
template <typename T, size_t N>
HWY_API Vec128<T, N> operator|(const Vec128<T, N> a, const Vec128<T, N> b) {
  return Or(a, b);
}

// ------------------------------ Xor

template <typename T, size_t N>
HWY_API Vec128<T, N> Xor(const Vec128<T, N> a, const Vec128<T, N> b) {
  const Simd<T, N, 0> d;
  const RebindToUnsigned<decltype(d)> du;
  auto au = BitCast(du, a);
  auto bu = BitCast(du, b);
  for (size_t i = 0; i < N; ++i) {
    au.raw[i] ^= bu.raw[i];
  }
  return BitCast(d, au);
}
template <typename T, size_t N>
HWY_API Vec128<T, N> operator^(const Vec128<T, N> a, const Vec128<T, N> b) {
  return Xor(a, b);
}

// ------------------------------ Xor3

template <typename T, size_t N>
HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) {
  return Xor(x1, Xor(x2, x3));
}

// ------------------------------ Or3

template <typename T, size_t N>
HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) {
  return Or(o1, Or(o2, o3));
}

// ------------------------------ OrAnd

template <typename T, size_t N>
HWY_API Vec128<T, N> OrAnd(const Vec128<T, N> o, const Vec128<T, N> a1,
                           const Vec128<T, N> a2) {
  return Or(o, And(a1, a2));
}

// ------------------------------ IfVecThenElse

template <typename T, size_t N>
HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes,
                                   Vec128<T, N> no) {
  return Or(And(mask, yes), AndNot(mask, no));
}

// ------------------------------ CopySign

template <typename T, size_t N>
HWY_API Vec128<T, N> CopySign(const Vec128<T, N> magn,
                              const Vec128<T, N> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");
  const auto msb = SignBit(Simd<T, N, 0>());
  return Or(AndNot(msb, magn), And(msb, sign));
}

template <typename T, size_t N>
HWY_API Vec128<T, N> CopySignToAbs(const Vec128<T, N> abs,
                                   const Vec128<T, N> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");
  return Or(abs, And(SignBit(Simd<T, N, 0>()), sign));
}

// ------------------------------ BroadcastSignBit

template <typename T, size_t N>
HWY_API Vec128<T, N> BroadcastSignBit(Vec128<T, N> v) {
  // This is used inside ShiftRight,
so we cannot implement in terms of it. for (size_t i = 0; i < N; ++i) { v.raw[i] = v.raw[i] < 0 ? T(-1) : T(0); } return v; } // ------------------------------ Mask template HWY_API Mask128 RebindMask(Simd /*tag*/, Mask128 mask) { Mask128 to; CopySameSize(&mask, &to); return to; } // v must be 0 or FF..FF. template HWY_API Mask128 MaskFromVec(const Vec128 v) { Mask128 mask; CopySameSize(&v, &mask); return mask; } template Vec128 VecFromMask(const Mask128 mask) { Vec128 v; CopySameSize(&mask, &v); return v; } template Vec128 VecFromMask(Simd /* tag */, const Mask128 mask) { return VecFromMask(mask); } template HWY_API Mask128 FirstN(Simd /*tag*/, size_t n) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(i < n); } return m; } // Returns mask ? yes : no. template HWY_API Vec128 IfThenElse(const Mask128 mask, const Vec128 yes, const Vec128 no) { return IfVecThenElse(VecFromMask(mask), yes, no); } template HWY_API Vec128 IfThenElseZero(const Mask128 mask, const Vec128 yes) { return IfVecThenElse(VecFromMask(mask), yes, Zero(Simd())); } template HWY_API Vec128 IfThenZeroElse(const Mask128 mask, const Vec128 no) { return IfVecThenElse(VecFromMask(mask), Zero(Simd()), no); } template HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, Vec128 no) { for (size_t i = 0; i < N; ++i) { v.raw[i] = v.raw[i] < 0 ? yes.raw[i] : no.raw[i]; } return v; } template HWY_API Vec128 ZeroIfNegative(const Vec128 v) { return IfNegativeThenElse(v, Zero(Simd()), v); } // ------------------------------ Mask logical template HWY_API Mask128 Not(const Mask128 m) { return MaskFromVec(Not(VecFromMask(Simd(), m))); } template HWY_API Mask128 And(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Or(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); } // ================================================== SHIFTS // ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) template HWY_API Vec128 ShiftLeft(Vec128 v) { static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); for (size_t i = 0; i < N; ++i) { const auto shifted = static_cast>(v.raw[i]) << kBits; v.raw[i] = static_cast(shifted); } return v; } template HWY_API Vec128 ShiftRight(Vec128 v) { static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); #if __cplusplus >= 202002L // Signed right shift is now guaranteed to be arithmetic (rounding toward // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> kBits); } #else if (IsSigned()) { // Emulate arithmetic shift using only logical (unsigned) shifts, because // signed shifts are still implementation-defined. using TU = hwy::MakeUnsigned; for (size_t i = 0; i < N; ++i) { const TU shifted = static_cast(static_cast(v.raw[i]) >> kBits); const TU sign = v.raw[i] < 0 ? 
static_cast(~TU{0}) : 0; const size_t sign_shift = static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); const TU upper = static_cast(sign << sign_shift); v.raw[i] = static_cast(shifted | upper); } } else { // T is unsigned for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> kBits); } } #endif return v; } // ------------------------------ RotateRight (ShiftRight) namespace detail { // For partial specialization: kBits == 0 results in an invalid shift count template struct RotateRight { template HWY_INLINE Vec128 operator()(const Vec128 v) const { return Or(ShiftRight(v), ShiftLeft(v)); } }; template <> struct RotateRight<0> { template HWY_INLINE Vec128 operator()(const Vec128 v) const { return v; } }; } // namespace detail template HWY_API Vec128 RotateRight(const Vec128 v) { static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); return detail::RotateRight()(v); } // ------------------------------ ShiftLeftSame template HWY_API Vec128 ShiftLeftSame(Vec128 v, int bits) { for (size_t i = 0; i < N; ++i) { const auto shifted = static_cast>(v.raw[i]) << bits; v.raw[i] = static_cast(shifted); } return v; } template HWY_API Vec128 ShiftRightSame(Vec128 v, int bits) { #if __cplusplus >= 202002L // Signed right shift is now guaranteed to be arithmetic (rounding toward // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> bits); } #else if (IsSigned()) { // Emulate arithmetic shift using only logical (unsigned) shifts, because // signed shifts are still implementation-defined. using TU = hwy::MakeUnsigned; for (size_t i = 0; i < N; ++i) { const TU shifted = static_cast(static_cast(v.raw[i]) >> bits); const TU sign = v.raw[i] < 0 ? static_cast(~TU{0}) : 0; const size_t sign_shift = static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); const TU upper = static_cast(sign << sign_shift); v.raw[i] = static_cast(shifted | upper); } } else { for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> bits); // unsigned, logical shift } } #endif return v; } // ------------------------------ Shl template HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { for (size_t i = 0; i < N; ++i) { const auto shifted = static_cast>(v.raw[i]) << bits.raw[i]; v.raw[i] = static_cast(shifted); } return v; } template HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { #if __cplusplus >= 202002L // Signed right shift is now guaranteed to be arithmetic (rounding toward // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); } #else if (IsSigned()) { // Emulate arithmetic shift using only logical (unsigned) shifts, because // signed shifts are still implementation-defined. using TU = hwy::MakeUnsigned; for (size_t i = 0; i < N; ++i) { const TU shifted = static_cast(static_cast(v.raw[i]) >> bits.raw[i]); const TU sign = v.raw[i] < 0 ? 
static_cast(~TU{0}) : 0; const size_t sign_shift = static_cast( static_cast(sizeof(TU)) * 8 - 1 - bits.raw[i]); const TU upper = static_cast(sign << sign_shift); v.raw[i] = static_cast(shifted | upper); } } else { // T is unsigned for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); } } #endif return v; } // ================================================== ARITHMETIC // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_INLINE Vec128 Add(hwy::NonFloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { const uint64_t a64 = static_cast(a.raw[i]); const uint64_t b64 = static_cast(b.raw[i]); a.raw[i] = static_cast((a64 + b64) & static_cast(~T(0))); } return a; } template HWY_INLINE Vec128 Sub(hwy::NonFloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { const uint64_t a64 = static_cast(a.raw[i]); const uint64_t b64 = static_cast(b.raw[i]); a.raw[i] = static_cast((a64 - b64) & static_cast(~T(0))); } return a; } template HWY_INLINE Vec128 Add(hwy::FloatTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] += b.raw[i]; } return a; } template HWY_INLINE Vec128 Sub(hwy::FloatTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] -= b.raw[i]; } return a; } } // namespace detail template HWY_API Vec128 operator-(Vec128 a, const Vec128 b) { return detail::Sub(hwy::IsFloatTag(), a, b); } template HWY_API Vec128 operator+(Vec128 a, const Vec128 b) { return detail::Add(hwy::IsFloatTag(), a, b); } // ------------------------------ SumsOf8 template HWY_API Vec128 SumsOf8(const Vec128 v) { Vec128 sums; for (size_t i = 0; i < N; ++i) { sums.raw[i / 8] += v.raw[i]; } return sums; } // ------------------------------ SaturatedAdd template HWY_API Vec128 SaturatedAdd(Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast( HWY_MIN(HWY_MAX(hwy::LowestValue(), a.raw[i] + b.raw[i]), hwy::HighestValue())); } return a; } // ------------------------------ SaturatedSub template HWY_API Vec128 SaturatedSub(Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast( HWY_MIN(HWY_MAX(hwy::LowestValue(), a.raw[i] - b.raw[i]), hwy::HighestValue())); } return a; } // ------------------------------ AverageRound template HWY_API Vec128 AverageRound(Vec128 a, const Vec128 b) { static_assert(!IsSigned(), "Only for unsigned"); for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast((a.raw[i] + b.raw[i] + 1) / 2); } return a; } // ------------------------------ Abs // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_INLINE Vec128 Abs(SignedTag /*tag*/, Vec128 a) { for (size_t i = 0; i < N; ++i) { const T s = a.raw[i]; const T min = hwy::LimitsMin(); a.raw[i] = static_cast((s >= 0 || s == min) ? 
a.raw[i] : -s); } return a; } template HWY_INLINE Vec128 Abs(hwy::FloatTag /*tag*/, Vec128 v) { for (size_t i = 0; i < N; ++i) { v.raw[i] = std::abs(v.raw[i]); } return v; } } // namespace detail template HWY_API Vec128 Abs(Vec128 a) { return detail::Abs(hwy::TypeTag(), a); } // ------------------------------ Min/Max // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_INLINE Vec128 Min(hwy::NonFloatTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); } return a; } template HWY_INLINE Vec128 Max(hwy::NonFloatTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); } return a; } template HWY_INLINE Vec128 Min(hwy::FloatTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { if (std::isnan(a.raw[i])) { a.raw[i] = b.raw[i]; } else if (std::isnan(b.raw[i])) { // no change } else { a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); } } return a; } template HWY_INLINE Vec128 Max(hwy::FloatTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { if (std::isnan(a.raw[i])) { a.raw[i] = b.raw[i]; } else if (std::isnan(b.raw[i])) { // no change } else { a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); } } return a; } } // namespace detail template HWY_API Vec128 Min(Vec128 a, const Vec128 b) { return detail::Min(hwy::IsFloatTag(), a, b); } template HWY_API Vec128 Max(Vec128 a, const Vec128 b) { return detail::Max(hwy::IsFloatTag(), a, b); } // ------------------------------ Neg // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_API Vec128 Neg(hwy::NonFloatTag /*tag*/, Vec128 v) { return Zero(Simd()) - v; } template HWY_API Vec128 Neg(hwy::FloatTag /*tag*/, Vec128 v) { return Xor(v, SignBit(Simd())); } } // namespace detail template HWY_API Vec128 Neg(Vec128 v) { return detail::Neg(hwy::IsFloatTag(), v); } // ------------------------------ Mul/Div // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_INLINE Vec128 Mul(hwy::FloatTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] *= b.raw[i]; } return a; } template HWY_INLINE Vec128 Mul(SignedTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast(static_cast(a.raw[i]) * static_cast(b.raw[i])); } return a; } template HWY_INLINE Vec128 Mul(UnsignedTag /*tag*/, Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast(static_cast(a.raw[i]) * static_cast(b.raw[i])); } return a; } } // namespace detail template HWY_API Vec128 operator*(Vec128 a, const Vec128 b) { return detail::Mul(hwy::TypeTag(), a, b); } template HWY_API Vec128 operator/(Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] /= b.raw[i]; } return a; } // Returns the upper 16 bits of a * b in each lane. template HWY_API Vec128 MulHigh(Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast((int32_t{a.raw[i]} * b.raw[i]) >> 16); } return a; } template HWY_API Vec128 MulHigh(Vec128 a, const Vec128 b) { for (size_t i = 0; i < N; ++i) { // Cast to uint32_t first to prevent overflow. Otherwise the result of // uint16_t * uint16_t is in "int" which may overflow. In practice the // result is the same but this way it is also defined. 
a.raw[i] = static_cast( (static_cast(a.raw[i]) * static_cast(b.raw[i])) >> 16); } return a; } template HWY_API Vec128 MulFixedPoint15(Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast((2 * a.raw[i] * b.raw[i] + 32768) >> 16); } return a; } // Multiplies even lanes (0, 2 ..) and returns the double-wide result. template HWY_API Vec128 MulEven(const Vec128 a, const Vec128 b) { Vec128 mul; for (size_t i = 0; i < N; i += 2) { const int64_t a64 = a.raw[i]; mul.raw[i / 2] = a64 * b.raw[i]; } return mul; } template HWY_API Vec128 MulEven(Vec128 a, const Vec128 b) { Vec128 mul; for (size_t i = 0; i < N; i += 2) { const uint64_t a64 = a.raw[i]; mul.raw[i / 2] = a64 * b.raw[i]; } return mul; } template HWY_API Vec128 MulOdd(const Vec128 a, const Vec128 b) { Vec128 mul; for (size_t i = 0; i < N; i += 2) { const int64_t a64 = a.raw[i + 1]; mul.raw[i / 2] = a64 * b.raw[i + 1]; } return mul; } template HWY_API Vec128 MulOdd(Vec128 a, const Vec128 b) { Vec128 mul; for (size_t i = 0; i < N; i += 2) { const uint64_t a64 = a.raw[i + 1]; mul.raw[i / 2] = a64 * b.raw[i + 1]; } return mul; } template HWY_API Vec128 ApproximateReciprocal(Vec128 v) { for (size_t i = 0; i < N; ++i) { // Zero inputs are allowed, but callers are responsible for replacing the // return value with something else (typically using IfThenElse). This check // avoids a ubsan error. The result is arbitrary. v.raw[i] = (std::abs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; } return v; } template HWY_API Vec128 AbsDiff(Vec128 a, const Vec128 b) { return Abs(a - b); } // ------------------------------ Floating-point multiply-add variants template HWY_API Vec128 MulAdd(Vec128 mul, const Vec128 x, const Vec128 add) { return mul * x + add; } template HWY_API Vec128 NegMulAdd(Vec128 mul, const Vec128 x, const Vec128 add) { return add - mul * x; } template HWY_API Vec128 MulSub(Vec128 mul, const Vec128 x, const Vec128 sub) { return mul * x - sub; } template HWY_API Vec128 NegMulSub(Vec128 mul, const Vec128 x, const Vec128 sub) { return Neg(mul) * x - sub; } // ------------------------------ Floating-point square root template HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { for (size_t i = 0; i < N; ++i) { const float half = v.raw[i] * 0.5f; uint32_t bits; CopySameSize(&v.raw[i], &bits); // Initial guess based on log2(f) bits = 0x5F3759DF - (bits >> 1); CopySameSize(&bits, &v.raw[i]); // One Newton-Raphson iteration v.raw[i] = v.raw[i] * (1.5f - (half * v.raw[i] * v.raw[i])); } return v; } template HWY_API Vec128 Sqrt(Vec128 v) { for (size_t i = 0; i < N; ++i) { v.raw[i] = std::sqrt(v.raw[i]); } return v; } // ------------------------------ Floating-point rounding template HWY_API Vec128 Round(Vec128 v) { using TI = MakeSigned; const Vec128 a = Abs(v); for (size_t i = 0; i < N; ++i) { if (!(a.raw[i] < MantissaEnd())) { // Huge or NaN continue; } const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); const TI rounded = static_cast(v.raw[i] + bias); if (rounded == 0) { v.raw[i] = v.raw[i] < 0 ? T{-0} : T{0}; continue; } const T rounded_f = static_cast(rounded); // Round to even if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { v.raw[i] = static_cast(rounded - (v.raw[i] < T(0) ? -1 : 1)); continue; } v.raw[i] = rounded_f; } return v; } // Round-to-nearest even. 
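// NearestInt converts float lanes to int32_t with ties-to-even, using the
// same bias-and-correct scheme as Round above; inputs too large for int32_t
// (and NaN) saturate to LimitsMin()/LimitsMax() according to the sign bit.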
template HWY_API Vec128 NearestInt(const Vec128 v) { using T = float; using TI = int32_t; const Vec128 abs = Abs(v); Vec128 ret; for (size_t i = 0; i < N; ++i) { const bool signbit = std::signbit(v.raw[i]); if (!(abs.raw[i] < MantissaEnd())) { // Huge or NaN // Check if too large to cast or NaN if (!(abs.raw[i] <= static_cast(LimitsMax()))) { ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); continue; } ret.raw[i] = static_cast(v.raw[i]); continue; } const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); const TI rounded = static_cast(v.raw[i] + bias); if (rounded == 0) { ret.raw[i] = 0; continue; } const T rounded_f = static_cast(rounded); // Round to even if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { ret.raw[i] = rounded - (signbit ? -1 : 1); continue; } ret.raw[i] = rounded; } return ret; } template HWY_API Vec128 Trunc(Vec128 v) { using TI = MakeSigned; const Vec128 abs = Abs(v); for (size_t i = 0; i < N; ++i) { if (!(abs.raw[i] <= MantissaEnd())) { // Huge or NaN continue; } const TI truncated = static_cast(v.raw[i]); if (truncated == 0) { v.raw[i] = v.raw[i] < 0 ? -T{0} : T{0}; continue; } v.raw[i] = static_cast(truncated); } return v; } // Toward +infinity, aka ceiling template Vec128 Ceil(Vec128 v) { constexpr int kMantissaBits = MantissaBits(); using Bits = MakeUnsigned; const Bits kExponentMask = MaxExponentField(); const Bits kMantissaMask = MantissaMask(); const Bits kBias = kExponentMask / 2; for (size_t i = 0; i < N; ++i) { const bool positive = v.raw[i] > Float(0.0); Bits bits; CopySameSize(&v.raw[i], &bits); const int exponent = static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); // Already an integer. if (exponent >= kMantissaBits) continue; // |v| <= 1 => 0 or 1. if (exponent < 0) { v.raw[i] = positive ? Float{1} : Float{-0.0}; continue; } const Bits mantissa_mask = kMantissaMask >> exponent; // Already an integer if ((bits & mantissa_mask) == 0) continue; // Clear fractional bits and round up if (positive) bits += (kMantissaMask + 1) >> exponent; bits &= ~mantissa_mask; CopySameSize(&bits, &v.raw[i]); } return v; } // Toward -infinity, aka floor template Vec128 Floor(Vec128 v) { constexpr int kMantissaBits = MantissaBits(); using Bits = MakeUnsigned; const Bits kExponentMask = MaxExponentField(); const Bits kMantissaMask = MantissaMask(); const Bits kBias = kExponentMask / 2; for (size_t i = 0; i < N; ++i) { const bool negative = v.raw[i] < Float(0.0); Bits bits; CopySameSize(&v.raw[i], &bits); const int exponent = static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); // Already an integer. if (exponent >= kMantissaBits) continue; // |v| <= 1 => -1 or 0. if (exponent < 0) { v.raw[i] = negative ? Float(-1.0) : Float(0.0); continue; } const Bits mantissa_mask = kMantissaMask >> exponent; // Already an integer if ((bits & mantissa_mask) == 0) continue; // Clear fractional bits and round down if (negative) bits += (kMantissaMask + 1) >> exponent; bits &= ~mantissa_mask; CopySameSize(&bits, &v.raw[i]); } return v; } // ------------------------------ Floating-point classification template HWY_API Mask128 IsNaN(const Vec128 v) { Mask128 ret; for (size_t i = 0; i < N; ++i) { // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. MakeUnsigned bits; CopySameSize(&v.raw[i], &bits); bits += bits; bits >>= 1; // clear sign bit // NaN if all exponent bits are set and the mantissa is not zero. 
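    // With the sign cleared, any value strictly greater than the all-ones
    // exponent field (ExponentMask) must also have mantissa bits set, i.e.
    // is NaN.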
ret.bits[i] = Mask128::FromBool(bits > ExponentMask()); } return ret; } template HWY_API Mask128 IsInf(const Vec128 v) { static_assert(IsFloat(), "Only for float"); const Simd d; const RebindToSigned di; const VFromD vi = BitCast(di, v); // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); } // Returns whether normal/subnormal/zero. template HWY_API Mask128 IsFinite(const Vec128 v) { static_assert(IsFloat(), "Only for float"); const Simd d; const RebindToUnsigned du; const RebindToSigned di; // cheaper than unsigned comparison using VI = VFromD; using VU = VFromD; const VU vu = BitCast(du, v); // 'Shift left' to clear the sign bit, then right so we can compare with the // max exponent (cannot compare with MaxExponentTimes2 directly because it is // negative and non-negative floats would be greater). const VI exp = BitCast(di, ShiftRight() + 1>(Add(vu, vu))); return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); } // ================================================== COMPARE template HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] == b.raw[i]); } return m; } template HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] != b.raw[i]); } return m; } template HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return (v & bit) == bit; } template HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] < b.raw[i]); } return m; } template HWY_API Mask128 operator>(const Vec128 a, const Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] > b.raw[i]); } return m; } template HWY_API Mask128 operator<=(const Vec128 a, const Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] <= b.raw[i]); } return m; } template HWY_API Mask128 operator>=(const Vec128 a, const Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] >= b.raw[i]); } return m; } // ------------------------------ Lt128 // Only makes sense for full vectors of u64. HWY_API Mask128 Lt128(Simd /* tag */, Vec128 a, const Vec128 b) { const bool lt = (a.raw[1] < b.raw[1]) || (a.raw[1] == b.raw[1] && a.raw[0] < b.raw[0]); Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); return ret; } HWY_API Mask128 Lt128Upper(Simd /* tag */, Vec128 a, const Vec128 b) { const bool lt = a.raw[1] < b.raw[1]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); return ret; } // ------------------------------ Eq128 // Only makes sense for full vectors of u64. 
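// As with Lt128, lane 1 holds the upper 64 bits of the 128-bit value; the
// scalar comparison result is broadcast to both mask lanes.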
HWY_API Mask128 Eq128(Simd /* tag */, Vec128 a, const Vec128 b) { const bool eq = a.raw[1] == b.raw[1] && a.raw[0] == b.raw[0]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); return ret; } HWY_API Mask128 Ne128(Simd /* tag */, Vec128 a, const Vec128 b) { const bool ne = a.raw[1] != b.raw[1] || a.raw[0] != b.raw[0]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); return ret; } HWY_API Mask128 Eq128Upper(Simd /* tag */, Vec128 a, const Vec128 b) { const bool eq = a.raw[1] == b.raw[1]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); return ret; } HWY_API Mask128 Ne128Upper(Simd /* tag */, Vec128 a, const Vec128 b) { const bool ne = a.raw[1] != b.raw[1]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); return ret; } // ------------------------------ Min128, Max128 (Lt128) template > HWY_API V Min128(D d, const V a, const V b) { return IfThenElse(Lt128(d, a, b), a, b); } template > HWY_API V Max128(D d, const V a, const V b) { return IfThenElse(Lt128(d, b, a), a, b); } template > HWY_API V Min128Upper(D d, const V a, const V b) { return IfThenElse(Lt128Upper(d, a, b), a, b); } template > HWY_API V Max128Upper(D d, const V a, const V b) { return IfThenElse(Lt128Upper(d, b, a), a, b); } // ================================================== MEMORY // ------------------------------ Load template HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT aligned) { Vec128 v; CopyBytes(aligned, v.raw); // copy from array return v; } template HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } template HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { return Load(d, p); } // In some use cases, "load single lane" is sufficient; otherwise avoid this. template HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT aligned) { return Load(d, aligned); } // ------------------------------ Store template HWY_API void Store(const Vec128 v, Simd /* tag */, T* HWY_RESTRICT aligned) { CopyBytes(v.raw, aligned); // copy to array } template HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { Store(v, d, p); } template HWY_API void BlendedStore(const Vec128 v, Mask128 m, Simd /* tag */, T* HWY_RESTRICT p) { for (size_t i = 0; i < N; ++i) { if (m.bits[i]) p[i] = v.raw[i]; } } // ------------------------------ LoadInterleaved2/3/4 // Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. // We implement those here because scalar code is likely faster than emulation // via shuffles. 
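// (The #ifdef/#undef/#else/#define pair below toggles the flag on every
// re-inclusion, matching the toggle convention Highway uses for per-target
// capability macros; its presence signals that this target provides its own
// LoadInterleaved*/StoreInterleaved* rather than the generic fallback.)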
#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED #undef HWY_NATIVE_LOAD_STORE_INTERLEAVED #else #define HWY_NATIVE_LOAD_STORE_INTERLEAVED #endif template HWY_API void LoadInterleaved2(Simd d, const T* HWY_RESTRICT unaligned, Vec128& v0, Vec128& v1) { alignas(16) T buf0[N]; alignas(16) T buf1[N]; for (size_t i = 0; i < N; ++i) { buf0[i] = *unaligned++; buf1[i] = *unaligned++; } v0 = Load(d, buf0); v1 = Load(d, buf1); } template HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, Vec128& v0, Vec128& v1, Vec128& v2) { alignas(16) T buf0[N]; alignas(16) T buf1[N]; alignas(16) T buf2[N]; for (size_t i = 0; i < N; ++i) { buf0[i] = *unaligned++; buf1[i] = *unaligned++; buf2[i] = *unaligned++; } v0 = Load(d, buf0); v1 = Load(d, buf1); v2 = Load(d, buf2); } template HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, Vec128& v0, Vec128& v1, Vec128& v2, Vec128& v3) { alignas(16) T buf0[N]; alignas(16) T buf1[N]; alignas(16) T buf2[N]; alignas(16) T buf3[N]; for (size_t i = 0; i < N; ++i) { buf0[i] = *unaligned++; buf1[i] = *unaligned++; buf2[i] = *unaligned++; buf3[i] = *unaligned++; } v0 = Load(d, buf0); v1 = Load(d, buf1); v2 = Load(d, buf2); v3 = Load(d, buf3); } // ------------------------------ StoreInterleaved2/3/4 template HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, Simd /* tag */, T* HWY_RESTRICT unaligned) { for (size_t i = 0; i < N; ++i) { *unaligned++ = v0.raw[i]; *unaligned++ = v1.raw[i]; } } template HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, const Vec128 v2, Simd /* tag */, T* HWY_RESTRICT unaligned) { for (size_t i = 0; i < N; ++i) { *unaligned++ = v0.raw[i]; *unaligned++ = v1.raw[i]; *unaligned++ = v2.raw[i]; } } template HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, const Vec128 v2, const Vec128 v3, Simd /* tag */, T* HWY_RESTRICT unaligned) { for (size_t i = 0; i < N; ++i) { *unaligned++ = v0.raw[i]; *unaligned++ = v1.raw[i]; *unaligned++ = v2.raw[i]; *unaligned++ = v3.raw[i]; } } // ------------------------------ Stream template HWY_API void Stream(const Vec128 v, Simd d, T* HWY_RESTRICT aligned) { Store(v, d, aligned); } // ------------------------------ Scatter template HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, T* base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); for (size_t i = 0; i < N; ++i) { uint8_t* const base8 = reinterpret_cast(base) + offset.raw[i]; CopyBytes(&v.raw[i], base8); // copy to bytes } } template HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); for (size_t i = 0; i < N; ++i) { base[index.raw[i]] = v.raw[i]; } } // ------------------------------ Gather template HWY_API Vec128 GatherOffset(Simd /* tag */, const T* base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); Vec128 v; for (size_t i = 0; i < N; ++i) { const uint8_t* base8 = reinterpret_cast(base) + offset.raw[i]; CopyBytes(base8, &v.raw[i]); // copy from bytes } return v; } template HWY_API Vec128 GatherIndex(Simd /* tag */, const T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); Vec128 v; for (size_t i = 0; i < N; ++i) { v.raw[i] = base[index.raw[i]]; } return v; } // ================================================== CONVERT // ConvertTo and DemoteTo with floating-point input and integer output truncate // (rounding 
toward zero). template HWY_API Vec128 PromoteTo(Simd /* tag */, Vec128 from) { static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); Vec128 ret; for (size_t i = 0; i < N; ++i) { // For bits Y > X, floatX->floatY and intX->intY are always representable. ret.raw[i] = static_cast(from.raw[i]); } return ret; } // MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, // so we overload for FromT=double and ToT={float,int32_t}. template HWY_API Vec128 DemoteTo(Simd /* tag */, Vec128 from) { Vec128 ret; for (size_t i = 0; i < N; ++i) { // Prevent ubsan errors when converting float to narrower integer/float if (std::isinf(from.raw[i]) || std::fabs(from.raw[i]) > static_cast(HighestValue())) { ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue() : HighestValue(); continue; } ret.raw[i] = static_cast(from.raw[i]); } return ret; } template HWY_API Vec128 DemoteTo(Simd /* tag */, Vec128 from) { Vec128 ret; for (size_t i = 0; i < N; ++i) { // Prevent ubsan errors when converting int32_t to narrower integer/int32_t if (std::isinf(from.raw[i]) || std::fabs(from.raw[i]) > static_cast(HighestValue())) { ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue() : HighestValue(); continue; } ret.raw[i] = static_cast(from.raw[i]); } return ret; } template HWY_API Vec128 DemoteTo(Simd /* tag */, Vec128 from) { static_assert(!IsFloat(), "FromT=double are handled above"); static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); Vec128 ret; for (size_t i = 0; i < N; ++i) { // Int to int: choose closest value in ToT to `from` (avoids UB) from.raw[i] = HWY_MIN(HWY_MAX(LimitsMin(), from.raw[i]), LimitsMax()); ret.raw[i] = static_cast(from.raw[i]); } return ret; } template HWY_API Vec128 ReorderDemote2To( Simd dbf16, Vec128 a, Vec128 b) { const Repartition du32; const Vec128 b_in_lower = ShiftRight<16>(BitCast(du32, b)); // Avoid OddEven - we want the upper half of `a` even on big-endian systems. const Vec128 a_mask = Set(du32, 0xFFFF0000); return BitCast(dbf16, IfVecThenElse(a_mask, BitCast(du32, a), b_in_lower)); } template HWY_API Vec128 ReorderDemote2To(Simd /*d16*/, Vec128 a, Vec128 b) { const int16_t min = LimitsMin(); const int16_t max = LimitsMax(); Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(HWY_MIN(HWY_MAX(min, a.raw[i]), max)); } for (size_t i = 0; i < N; ++i) { ret.raw[N + i] = static_cast(HWY_MIN(HWY_MAX(min, b.raw[i]), max)); } return ret; } namespace detail { HWY_INLINE void StoreU16ToF16(const uint16_t val, hwy::float16_t* HWY_RESTRICT to) { CopySameSize(&val, to); } HWY_INLINE uint16_t U16FromF16(const hwy::float16_t* HWY_RESTRICT from) { uint16_t bits16; CopySameSize(from, &bits16); return bits16; } } // namespace detail template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { const uint16_t bits16 = detail::U16FromF16(&v.raw[i]); const uint32_t sign = static_cast(bits16 >> 15); const uint32_t biased_exp = (bits16 >> 10) & 0x1F; const uint32_t mantissa = bits16 & 0x3FF; // Subnormal or zero if (biased_exp == 0) { const float subnormal = (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); ret.raw[i] = sign ? -subnormal : subnormal; continue; } // Normalized: convert the representation directly (faster than // ldexp/tables). 
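    // Re-bias the exponent from f16 (bias 15) to f32 (bias 127) and shift the
    // 10 mantissa bits into the upper end of the 23-bit f32 mantissa field.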
const uint32_t biased_exp32 = biased_exp + (127 - 15); const uint32_t mantissa32 = mantissa << (23 - 10); const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; CopySameSize(&bits32, &ret.raw[i]); } return ret; } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = F32FromBF16(v.raw[i]); } return ret; } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { uint32_t bits32; CopySameSize(&v.raw[i], &bits32); const uint32_t sign = bits32 >> 31; const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; const uint32_t mantissa32 = bits32 & 0x7FFFFF; const int32_t exp = HWY_MIN(static_cast(biased_exp32) - 127, 15); // Tiny or zero => zero. if (exp < -24) { ZeroBytes(&ret.raw[i]); continue; } uint32_t biased_exp16, mantissa16; // exp = [-24, -15] => subnormal if (exp < -14) { biased_exp16 = 0; const uint32_t sub_exp = static_cast(-14 - exp); HWY_DASSERT(1 <= sub_exp && sub_exp < 11); mantissa16 = static_cast((1u << (10 - sub_exp)) + (mantissa32 >> (13 + sub_exp))); } else { // exp = [-14, 15] biased_exp16 = static_cast(exp + 15); HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); mantissa16 = mantissa32 >> 13; } HWY_DASSERT(mantissa16 < 1024); const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; HWY_DASSERT(bits16 < 0x10000); const uint16_t narrowed = static_cast(bits16); // big-endian safe detail::StoreU16ToF16(narrowed, &ret.raw[i]); } return ret; } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = BF16FromF32(v.raw[i]); } return ret; } // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_API Vec128 ConvertTo(hwy::FloatTag /*tag*/, Simd /* tag */, Vec128 from) { static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); Vec128 ret; for (size_t i = 0; i < N; ++i) { // float## -> int##: return closest representable value. We cannot exactly // represent LimitsMax in FromT, so use double. const double f = static_cast(from.raw[i]); if (std::isinf(from.raw[i]) || std::fabs(f) > static_cast(LimitsMax())) { ret.raw[i] = std::signbit(from.raw[i]) ? 
LimitsMin() : LimitsMax(); continue; } ret.raw[i] = static_cast(from.raw[i]); } return ret; } template HWY_API Vec128 ConvertTo(hwy::NonFloatTag /*tag*/, Simd /* tag */, Vec128 from) { static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); Vec128 ret; for (size_t i = 0; i < N; ++i) { // int## -> float##: no check needed ret.raw[i] = static_cast(from.raw[i]); } return ret; } } // namespace detail template HWY_API Vec128 ConvertTo(Simd d, Vec128 from) { return detail::ConvertTo(hwy::IsFloatTag(), d, from); } template HWY_API Vec128 U8FromU32(const Vec128 v) { return DemoteTo(Simd(), v); } // ------------------------------ Truncations template HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFF); } return ret; } template HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); } return ret; } template HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFFFFFFFFu); } return ret; } template HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFF); } return ret; } template HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); } return ret; } template HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFF); } return ret; } // ================================================== COMBINE template HWY_API Vec128 LowerHalf(Vec128 v) { Vec128 ret; CopyBytes(v.raw, ret.raw); return ret; } template HWY_API Vec128 LowerHalf(Simd /* tag */, Vec128 v) { return LowerHalf(v); } template HWY_API Vec128 UpperHalf(Simd /* tag */, Vec128 v) { Vec128 ret; CopyBytes(&v.raw[N / 2], ret.raw); return ret; } template HWY_API Vec128 ZeroExtendVector(Simd /* tag */, Vec128 v) { Vec128 ret; CopyBytes(v.raw, ret.raw); return ret; } template HWY_API Vec128 Combine(Simd /* tag */, Vec128 hi_half, Vec128 lo_half) { Vec128 ret; CopyBytes(lo_half.raw, &ret.raw[0]); CopyBytes(hi_half.raw, &ret.raw[N / 2]); return ret; } template HWY_API Vec128 ConcatLowerLower(Simd /* tag */, Vec128 hi, Vec128 lo) { Vec128 ret; CopyBytes(lo.raw, &ret.raw[0]); CopyBytes(hi.raw, &ret.raw[N / 2]); return ret; } template HWY_API Vec128 ConcatUpperUpper(Simd /* tag */, Vec128 hi, Vec128 lo) { Vec128 ret; CopyBytes(&lo.raw[N / 2], &ret.raw[0]); CopyBytes(&hi.raw[N / 2], &ret.raw[N / 2]); return ret; } template HWY_API Vec128 ConcatLowerUpper(Simd /* tag */, const Vec128 hi, const Vec128 lo) { Vec128 ret; CopyBytes(&lo.raw[N / 2], &ret.raw[0]); CopyBytes(hi.raw, &ret.raw[N / 2]); return ret; } template HWY_API Vec128 ConcatUpperLower(Simd /* tag */, Vec128 hi, Vec128 lo) { Vec128 ret; CopyBytes(lo.raw, &ret.raw[0]); CopyBytes(&hi.raw[N / 2], &ret.raw[N / 2]); return ret; } template HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, Vec128 lo) { Vec128 ret; for (size_t i = 0; i < N / 2; ++i) { ret.raw[i] = lo.raw[2 * i]; } for (size_t i = 0; i < N / 2; ++i) { ret.raw[N / 2 + i] = hi.raw[2 * i]; } return ret; } template HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, Vec128 lo) { Vec128 ret; for (size_t i = 0; i < N / 2; ++i) { ret.raw[i] = lo.raw[2 * i + 1]; } for (size_t i = 0; i < N / 2; 
++i) { ret.raw[N / 2 + i] = hi.raw[2 * i + 1]; } return ret; } // ------------------------------ CombineShiftRightBytes template > HWY_API V CombineShiftRightBytes(Simd /* tag */, V hi, V lo) { V ret; const uint8_t* HWY_RESTRICT lo8 = reinterpret_cast(lo.raw); uint8_t* HWY_RESTRICT ret8 = reinterpret_cast(ret.raw); CopyBytes(lo8 + kBytes, ret8); CopyBytes(hi.raw, ret8 + sizeof(T) * N - kBytes); return ret; } // ------------------------------ ShiftLeftBytes template HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); Vec128 ret; uint8_t* HWY_RESTRICT ret8 = reinterpret_cast(ret.raw); ZeroBytes(ret8); CopyBytes(v.raw, ret8 + kBytes); return ret; } template HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { return ShiftLeftBytes(DFromV(), v); } // ------------------------------ ShiftLeftLanes template HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { return ShiftLeftLanes(DFromV(), v); } // ------------------------------ ShiftRightBytes template HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); Vec128 ret; const uint8_t* HWY_RESTRICT v8 = reinterpret_cast(v.raw); uint8_t* HWY_RESTRICT ret8 = reinterpret_cast(ret.raw); CopyBytes(v8 + kBytes, ret8); ZeroBytes(ret8 + sizeof(T) * N - kBytes); return ret; } // ------------------------------ ShiftRightLanes template HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { const Repartition d8; return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ================================================== SWIZZLE template HWY_API T GetLane(const Vec128 v) { return v.raw[0]; } template HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { v.raw[i] = t; return v; } template HWY_API T ExtractLane(const Vec128 v, size_t i) { return v.raw[i]; } template HWY_API Vec128 DupEven(Vec128 v) { for (size_t i = 0; i < N; i += 2) { v.raw[i + 1] = v.raw[i]; } return v; } template HWY_API Vec128 DupOdd(Vec128 v) { for (size_t i = 0; i < N; i += 2) { v.raw[i] = v.raw[i + 1]; } return v; } template HWY_API Vec128 OddEven(Vec128 odd, Vec128 even) { for (size_t i = 0; i < N; i += 2) { odd.raw[i] = even.raw[i]; } return odd; } template HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { return even; } // ------------------------------ SwapAdjacentBlocks template HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { return v; } // ------------------------------ TableLookupLanes // Returned by SetTableIndices for use by TableLookupLanes. 
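// Each index is stored as a signed integer of the same width as the lane
// type and must be in [0, N); TableLookupLanes below indexes raw[] directly,
// so an out-of-range index would read out of bounds.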
template struct Indices128 { MakeSigned raw[N]; }; template HWY_API Indices128 IndicesFromVec(Simd, Vec128 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); Indices128 ret; CopyBytes(vec.raw, ret.raw); return ret; } template HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { return IndicesFromVec(d, LoadU(Simd(), idx)); } template HWY_API Vec128 TableLookupLanes(const Vec128 v, const Indices128 idx) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = v.raw[idx.raw[i]]; } return ret; } // ------------------------------ ReverseBlocks // Single block: no change template HWY_API Vec128 ReverseBlocks(Simd /* tag */, const Vec128 v) { return v; } // ------------------------------ Reverse template HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = v.raw[N - 1 - i]; } return ret; } template HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; i += 2) { ret.raw[i + 0] = v.raw[i + 1]; ret.raw[i + 1] = v.raw[i + 0]; } return ret; } template HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; i += 4) { ret.raw[i + 0] = v.raw[i + 3]; ret.raw[i + 1] = v.raw[i + 2]; ret.raw[i + 2] = v.raw[i + 1]; ret.raw[i + 3] = v.raw[i + 0]; } return ret; } template HWY_API Vec128 Reverse8(Simd /* tag */, const Vec128 v) { Vec128 ret; for (size_t i = 0; i < N; i += 8) { ret.raw[i + 0] = v.raw[i + 7]; ret.raw[i + 1] = v.raw[i + 6]; ret.raw[i + 2] = v.raw[i + 5]; ret.raw[i + 3] = v.raw[i + 4]; ret.raw[i + 4] = v.raw[i + 3]; ret.raw[i + 5] = v.raw[i + 2]; ret.raw[i + 6] = v.raw[i + 1]; ret.raw[i + 7] = v.raw[i + 0]; } return ret; } // ================================================== BLOCKWISE // ------------------------------ Shuffle* // Swap 32-bit halves in 64-bit halves. 
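// E.g. with four 32-bit lanes [0 1 2 3], the result is [1 0 3 2]; the digits
// in the name give the source lane of each output lane, most-significant
// first.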
template HWY_API Vec128 Shuffle2301(const Vec128 v) { static_assert(sizeof(T) == 4, "Only for 32-bit"); static_assert(N == 2 || N == 4, "Does not make sense for N=1"); return Reverse2(DFromV(), v); } // Swap 64-bit halves template HWY_API Vec128 Shuffle1032(const Vec128 v) { static_assert(sizeof(T) == 4, "Only for 32-bit"); Vec128 ret; ret.raw[3] = v.raw[1]; ret.raw[2] = v.raw[0]; ret.raw[1] = v.raw[3]; ret.raw[0] = v.raw[2]; return ret; } template HWY_API Vec128 Shuffle01(const Vec128 v) { static_assert(sizeof(T) == 8, "Only for 64-bit"); return Reverse2(DFromV(), v); } // Rotate right 32 bits template HWY_API Vec128 Shuffle0321(const Vec128 v) { Vec128 ret; ret.raw[3] = v.raw[0]; ret.raw[2] = v.raw[3]; ret.raw[1] = v.raw[2]; ret.raw[0] = v.raw[1]; return ret; } // Rotate left 32 bits template HWY_API Vec128 Shuffle2103(const Vec128 v) { Vec128 ret; ret.raw[3] = v.raw[2]; ret.raw[2] = v.raw[1]; ret.raw[1] = v.raw[0]; ret.raw[0] = v.raw[3]; return ret; } template HWY_API Vec128 Shuffle0123(const Vec128 v) { return Reverse4(DFromV(), v); } // ------------------------------ Broadcast/splat any lane template HWY_API Vec128 Broadcast(Vec128 v) { for (size_t i = 0; i < N; ++i) { v.raw[i] = v.raw[kLane]; } return v; } // ------------------------------ TableLookupBytes, TableLookupBytesOr0 template HWY_API Vec128 TableLookupBytes(const Vec128 v, const Vec128 indices) { const uint8_t* HWY_RESTRICT v_bytes = reinterpret_cast(v.raw); const uint8_t* HWY_RESTRICT idx_bytes = reinterpret_cast(indices.raw); Vec128 ret; uint8_t* HWY_RESTRICT ret_bytes = reinterpret_cast(ret.raw); for (size_t i = 0; i < NI * sizeof(TI); ++i) { const size_t idx = idx_bytes[i]; // Avoid out of bounds reads. ret_bytes[i] = idx < sizeof(T) * N ? v_bytes[idx] : 0; } return ret; } template HWY_API Vec128 TableLookupBytesOr0(const Vec128 v, const Vec128 indices) { // Same as TableLookupBytes, which already returns 0 if out of bounds. return TableLookupBytes(v, indices); } // ------------------------------ InterleaveLower/InterleaveUpper template HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { Vec128 ret; for (size_t i = 0; i < N / 2; ++i) { ret.raw[2 * i + 0] = a.raw[i]; ret.raw[2 * i + 1] = b.raw[i]; } return ret; } // Additional overload for the optional tag (also for 256/512). template HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { return InterleaveLower(a, b); } template HWY_API Vec128 InterleaveUpper(Simd /* tag */, const Vec128 a, const Vec128 b) { Vec128 ret; for (size_t i = 0; i < N / 2; ++i) { ret.raw[2 * i + 0] = a.raw[N / 2 + i]; ret.raw[2 * i + 1] = b.raw[N / 2 + i]; } return ret; } // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. 
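// E.g. zipping two u32 vectors a and b yields u64 lanes in which lane i packs
// (a[i], b[i]) from the lower (ZipLower) or upper (ZipUpper) halves, with a's
// element occupying the lower-addressed bytes of each wide lane.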
template >> HWY_API VFromD ZipLower(V a, V b) { return BitCast(DW(), InterleaveLower(a, b)); } template , class DW = RepartitionToWide> HWY_API VFromD ZipLower(DW dw, V a, V b) { return BitCast(dw, InterleaveLower(D(), a, b)); } template , class DW = RepartitionToWide> HWY_API VFromD ZipUpper(DW dw, V a, V b) { return BitCast(dw, InterleaveUpper(D(), a, b)); } // ================================================== MASK template HWY_API bool AllFalse(Simd /* tag */, const Mask128 mask) { typename Mask128::Raw or_sum = 0; for (size_t i = 0; i < N; ++i) { or_sum |= mask.bits[i]; } return or_sum == 0; } template HWY_API bool AllTrue(Simd /* tag */, const Mask128 mask) { constexpr uint64_t kAll = LimitsMax::Raw>(); uint64_t and_sum = kAll; for (size_t i = 0; i < N; ++i) { and_sum &= mask.bits[i]; } return and_sum == kAll; } // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API Mask128 LoadMaskBits(Simd /* tag */, const uint8_t* HWY_RESTRICT bits) { Mask128 m; for (size_t i = 0; i < N; ++i) { const size_t bit = size_t{1} << (i & 7); const size_t idx_byte = i >> 3; m.bits[i] = Mask128::FromBool((bits[idx_byte] & bit) != 0); } return m; } // `p` points to at least 8 writable bytes. template HWY_API size_t StoreMaskBits(Simd /* tag */, const Mask128 mask, uint8_t* bits) { bits[0] = 0; if (N > 8) bits[1] = 0; // N <= 16, so max two bytes for (size_t i = 0; i < N; ++i) { const size_t bit = size_t{1} << (i & 7); const size_t idx_byte = i >> 3; if (mask.bits[i]) { bits[idx_byte] = static_cast(bits[idx_byte] | bit); } } return N > 8 ? 2 : 1; } template HWY_API size_t CountTrue(Simd /* tag */, const Mask128 mask) { size_t count = 0; for (size_t i = 0; i < N; ++i) { count += mask.bits[i] != 0; } return count; } template HWY_API size_t FindKnownFirstTrue(Simd /* tag */, const Mask128 mask) { for (size_t i = 0; i < N; ++i) { if (mask.bits[i] != 0) return i; } HWY_DASSERT(false); return 0; } template HWY_API intptr_t FindFirstTrue(Simd /* tag */, const Mask128 mask) { for (size_t i = 0; i < N; ++i) { if (mask.bits[i] != 0) return static_cast(i); } return intptr_t{-1}; } // ------------------------------ Compress template struct CompressIsPartition { enum { value = (sizeof(T) != 1) }; }; template HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { size_t count = 0; Vec128 ret; for (size_t i = 0; i < N; ++i) { if (mask.bits[i]) { ret.raw[count++] = v.raw[i]; } } for (size_t i = 0; i < N; ++i) { if (!mask.bits[i]) { ret.raw[count++] = v.raw[i]; } } HWY_DASSERT(count == N); return ret; } // ------------------------------ CompressNot template HWY_API Vec128 CompressNot(Vec128 v, const Mask128 mask) { size_t count = 0; Vec128 ret; for (size_t i = 0; i < N; ++i) { if (!mask.bits[i]) { ret.raw[count++] = v.raw[i]; } } for (size_t i = 0; i < N; ++i) { if (mask.bits[i]) { ret.raw[count++] = v.raw[i]; } } HWY_DASSERT(count == N); return ret; } // ------------------------------ CompressBlocksNot HWY_API Vec128 CompressBlocksNot(Vec128 v, Mask128 /* m */) { return v; } // ------------------------------ CompressBits template HWY_API Vec128 CompressBits(Vec128 v, const uint8_t* HWY_RESTRICT bits) { return Compress(v, LoadMaskBits(Simd(), bits)); } // ------------------------------ CompressStore template HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, Simd /* tag */, T* HWY_RESTRICT unaligned) { size_t count = 0; for (size_t i = 0; i < N; ++i) { if (mask.bits[i]) { unaligned[count++] = v.raw[i]; } } return count; } // ------------------------------ CompressBlendedStore 
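// In this emulation, CompressBlendedStore simply forwards to CompressStore:
// the loop there writes exactly CountTrue(mask) lanes, so no blending with
// the destination's existing contents is needed.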
template HWY_API size_t CompressBlendedStore(Vec128 v, const Mask128 mask, Simd d, T* HWY_RESTRICT unaligned) { return CompressStore(v, mask, d, unaligned); } // ------------------------------ CompressBitsStore template HWY_API size_t CompressBitsStore(Vec128 v, const uint8_t* HWY_RESTRICT bits, Simd d, T* HWY_RESTRICT unaligned) { const Mask128 mask = LoadMaskBits(d, bits); StoreU(Compress(v, mask), d, unaligned); return CountTrue(d, mask); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) template HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, Vec128 a, Vec128 b, const Vec128 sum0, Vec128& sum1) { const Rebind du32; using VU32 = VFromD; const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 // Avoid ZipLower/Upper so this also works on big-endian systems. const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); const VU32 ao = And(BitCast(du32, a), odd); const VU32 be = ShiftLeft<16>(BitCast(du32, b)); const VU32 bo = And(BitCast(du32, b), odd); sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); } template HWY_API Vec128 ReorderWidenMulAccumulate( Simd d32, Vec128 a, Vec128 b, const Vec128 sum0, Vec128& sum1) { using VI32 = VFromD; // Manual sign extension requires two shifts for even lanes. const VI32 ae = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, a))); const VI32 be = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, b))); const VI32 ao = ShiftRight<16>(BitCast(d32, a)); const VI32 bo = ShiftRight<16>(BitCast(d32, b)); sum1 = Add(Mul(ao, bo), sum1); return Add(Mul(ae, be), sum0); } // ------------------------------ RearrangeToOddPlusEven template HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { return Add(sum0, sum1); } // ================================================== REDUCTIONS template HWY_API Vec128 SumOfLanes(Simd d, const Vec128 v) { T sum = T{0}; for (size_t i = 0; i < N; ++i) { sum += v.raw[i]; } return Set(d, sum); } template HWY_API Vec128 MinOfLanes(Simd d, const Vec128 v) { T min = HighestValue(); for (size_t i = 0; i < N; ++i) { min = HWY_MIN(min, v.raw[i]); } return Set(d, min); } template HWY_API Vec128 MaxOfLanes(Simd d, const Vec128 v) { T max = LowestValue(); for (size_t i = 0; i < N; ++i) { max = HWY_MAX(max, v.raw[i]); } return Set(d, max); } // ================================================== OPS WITH DEPENDENCIES // ------------------------------ MulEven/Odd 64x64 (UpperHalf) HWY_INLINE Vec128 MulEven(const Vec128 a, const Vec128 b) { alignas(16) uint64_t mul[2]; mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); return Load(Full128(), mul); } HWY_INLINE Vec128 MulOdd(const Vec128 a, const Vec128 b) { alignas(16) uint64_t mul[2]; const Half> d2; mul[0] = Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); return Load(Full128(), mul); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE();
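// Illustrative usage sketch (comment only, not part of this header). Callers
// normally include hwy/highway.h and let the dispatch machinery select the
// target; when the 128-bit emulation above is active, code such as the
// following (names are placeholders) exercises the ops defined in this file:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::Full128<float> d;        // 4 float lanes
//   const auto va = hn::LoadU(d, a);   // a, b, out: float* with >= 4 elements
//   const auto vb = hn::LoadU(d, b);
//   hn::StoreU(hn::MulAdd(va, vb, hn::LoadU(d, out)), d, out);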