// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 256-bit WASM vectors and operations. Experimental.
// External include guard in highway.h - see comment there.

// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/wasm_128-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <typename T>
class Vec256 {
 public:
  using PrivateT = T;                                  // only for DFromV
  static constexpr size_t kPrivateN = 32 / sizeof(T);  // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec256& operator*=(const Vec256 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec256& operator/=(const Vec256 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec256& operator+=(const Vec256 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec256& operator-=(const Vec256 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec256& operator&=(const Vec256 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec256& operator|=(const Vec256 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec256& operator^=(const Vec256 other) {
    return *this = (*this ^ other);
  }

  Vec128<T> v0;
  Vec128<T> v1;
};

template <typename T>
struct Mask256 {
  Mask128<T> m0;
  Mask128<T> m1;
};

// ------------------------------ Zero

// Avoid VFromD here because it is defined in terms of Zero.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API Vec256<TFromD<D>> Zero(D d) {
  const Half<decltype(d)> dh;
  Vec256<TFromD<D>> ret;
  ret.v0 = ret.v1 = Zero(dh);
  return ret;
}

// ------------------------------ BitCast

template <class D, HWY_IF_V_SIZE_D(D, 32), typename TFrom>
HWY_API VFromD<D> BitCast(D d, Vec256<TFrom> v) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = BitCast(dh, v.v0);
  ret.v1 = BitCast(dh, v.v1);
  return ret;
}

// ------------------------------ ResizeBitCast

// 32-byte vector to 32-byte vector: Same as BitCast
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 32),
          HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, v);
}

// <= 16-byte vector to 32-byte vector
template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ResizeBitCast(dh, v);
  ret.v1 = Zero(dh);
  return ret;
}

// 32-byte vector to <= 16-byte vector
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 32),
          HWY_IF_V_SIZE_LE_D(D, 16)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return ResizeBitCast(d, v.v0);
}

// ------------------------------ Set

template <class D, HWY_IF_V_SIZE_D(D, 32), typename T2>
HWY_API VFromD<D> Set(D d, const T2 t) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ret.v1 = Set(dh, static_cast<TFromD<D>>(t));
  return ret;
}

// Undefined, Iota defined in wasm_128.
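// Example (illustrative comment only, not part of the API): this target
// emulates one 256-bit vector as the two 128-bit halves v0/v1, so most ops
// below simply recurse into each half. A minimal usage sketch, assuming the
// usual Highway entry points:
//
//   const Full256<int32_t> d;               // 8 x i32
//   const Vec256<int32_t> v = Set(d, 7);    // v.v0 and v.v1 each hold 4x 7
//   const int32_t first = GetLane(v);       // == 7 (GetLane defined below)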
// ================================================== ARITHMETIC

template <typename T>
HWY_API Vec256<T> operator+(Vec256<T> a, const Vec256<T> b) {
  a.v0 += b.v0;
  a.v1 += b.v1;
  return a;
}

template <typename T>
HWY_API Vec256<T> operator-(Vec256<T> a, const Vec256<T> b) {
  a.v0 -= b.v0;
  a.v1 -= b.v1;
  return a;
}

// ------------------------------ SumsOf8
HWY_API Vec256<uint64_t> SumsOf8(const Vec256<uint8_t> v) {
  Vec256<uint64_t> ret;
  ret.v0 = SumsOf8(v.v0);
  ret.v1 = SumsOf8(v.v1);
  return ret;
}

template <typename T>
HWY_API Vec256<T> SaturatedAdd(Vec256<T> a, const Vec256<T> b) {
  a.v0 = SaturatedAdd(a.v0, b.v0);
  a.v1 = SaturatedAdd(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> SaturatedSub(Vec256<T> a, const Vec256<T> b) {
  a.v0 = SaturatedSub(a.v0, b.v0);
  a.v1 = SaturatedSub(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> AverageRound(Vec256<T> a, const Vec256<T> b) {
  a.v0 = AverageRound(a.v0, b.v0);
  a.v1 = AverageRound(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> Abs(Vec256<T> v) {
  v.v0 = Abs(v.v0);
  v.v1 = Abs(v.v1);
  return v;
}

// ------------------------------ Shift lanes by constant #bits

template <int kBits, typename T>
HWY_API Vec256<T> ShiftLeft(Vec256<T> v) {
  v.v0 = ShiftLeft<kBits>(v.v0);
  v.v1 = ShiftLeft<kBits>(v.v1);
  return v;
}

template <int kBits, typename T>
HWY_API Vec256<T> ShiftRight(Vec256<T> v) {
  v.v0 = ShiftRight<kBits>(v.v0);
  v.v1 = ShiftRight<kBits>(v.v1);
  return v;
}

// ------------------------------ RotateRight (ShiftRight, Or)
template <int kBits, typename T>
HWY_API Vec256<T> RotateRight(const Vec256<T> v) {
  constexpr size_t kSizeInBits = sizeof(T) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;
  return Or(ShiftRight<kBits>(v),
            ShiftLeft<HWY_MIN(kSizeInBits - 1, kSizeInBits - kBits)>(v));
}

// ------------------------------ Shift lanes by same variable #bits

template <typename T>
HWY_API Vec256<T> ShiftLeftSame(Vec256<T> v, const int bits) {
  v.v0 = ShiftLeftSame(v.v0, bits);
  v.v1 = ShiftLeftSame(v.v1, bits);
  return v;
}

template <typename T>
HWY_API Vec256<T> ShiftRightSame(Vec256<T> v, const int bits) {
  v.v0 = ShiftRightSame(v.v0, bits);
  v.v1 = ShiftRightSame(v.v1, bits);
  return v;
}

// ------------------------------ Min, Max

template <typename T>
HWY_API Vec256<T> Min(Vec256<T> a, const Vec256<T> b) {
  a.v0 = Min(a.v0, b.v0);
  a.v1 = Min(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> Max(Vec256<T> a, const Vec256<T> b) {
  a.v0 = Max(a.v0, b.v0);
  a.v1 = Max(a.v1, b.v1);
  return a;
}

// ------------------------------ Integer multiplication

template <typename T>
HWY_API Vec256<T> operator*(Vec256<T> a, const Vec256<T> b) {
  a.v0 *= b.v0;
  a.v1 *= b.v1;
  return a;
}

template <typename T>
HWY_API Vec256<T> MulHigh(Vec256<T> a, const Vec256<T> b) {
  a.v0 = MulHigh(a.v0, b.v0);
  a.v1 = MulHigh(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> MulFixedPoint15(Vec256<T> a, const Vec256<T> b) {
  a.v0 = MulFixedPoint15(a.v0, b.v0);
  a.v1 = MulFixedPoint15(a.v1, b.v1);
  return a;
}
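// Example (illustrative comment only): shift and rotate counts are template
// arguments, e.g. rotating each u32 lane right by one byte:
//
//   const Full256<uint32_t> d;
//   const Vec256<uint32_t> r = RotateRight<8>(Set(d, 0x11223344u));
//   // Every lane of r is now 0x44112233.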
// Cannot use MakeWide because that returns uint128_t for uint64_t, but we want
// uint64_t.
HWY_API Vec256<int64_t> MulEven(Vec256<int32_t> a, const Vec256<int32_t> b) {
  Vec256<int64_t> ret;
  ret.v0 = MulEven(a.v0, b.v0);
  ret.v1 = MulEven(a.v1, b.v1);
  return ret;
}
HWY_API Vec256<uint64_t> MulEven(Vec256<uint32_t> a,
                                 const Vec256<uint32_t> b) {
  Vec256<uint64_t> ret;
  ret.v0 = MulEven(a.v0, b.v0);
  ret.v1 = MulEven(a.v1, b.v1);
  return ret;
}
HWY_API Vec256<uint64_t> MulEven(Vec256<uint64_t> a,
                                 const Vec256<uint64_t> b) {
  Vec256<uint64_t> ret;
  ret.v0 = MulEven(a.v0, b.v0);
  ret.v1 = MulEven(a.v1, b.v1);
  return ret;
}
HWY_API Vec256<uint64_t> MulOdd(Vec256<uint64_t> a, const Vec256<uint64_t> b) {
  Vec256<uint64_t> ret;
  ret.v0 = MulOdd(a.v0, b.v0);
  ret.v1 = MulOdd(a.v1, b.v1);
  return ret;
}

// ------------------------------ Negate

template <typename T>
HWY_API Vec256<T> Neg(Vec256<T> v) {
  v.v0 = Neg(v.v0);
  v.v1 = Neg(v.v1);
  return v;
}

// ------------------------------ Floating-point division

template <typename T>
HWY_API Vec256<T> operator/(Vec256<T> a, const Vec256<T> b) {
  a.v0 /= b.v0;
  a.v1 /= b.v1;
  return a;
}

// Approximate reciprocal
HWY_API Vec256<float> ApproximateReciprocal(const Vec256<float> v) {
  const Vec256<float> one = Set(Full256<float>(), 1.0f);
  return one / v;
}

// Absolute value of difference.
HWY_API Vec256<float> AbsDiff(const Vec256<float> a, const Vec256<float> b) {
  return Abs(a - b);
}

// ------------------------------ Floating-point multiply-add variants

HWY_API Vec256<float> MulAdd(Vec256<float> mul, Vec256<float> x,
                             Vec256<float> add) {
  mul.v0 = MulAdd(mul.v0, x.v0, add.v0);
  mul.v1 = MulAdd(mul.v1, x.v1, add.v1);
  return mul;
}

HWY_API Vec256<float> NegMulAdd(Vec256<float> mul, Vec256<float> x,
                                Vec256<float> add) {
  mul.v0 = NegMulAdd(mul.v0, x.v0, add.v0);
  mul.v1 = NegMulAdd(mul.v1, x.v1, add.v1);
  return mul;
}

HWY_API Vec256<float> MulSub(Vec256<float> mul, Vec256<float> x,
                             Vec256<float> sub) {
  mul.v0 = MulSub(mul.v0, x.v0, sub.v0);
  mul.v1 = MulSub(mul.v1, x.v1, sub.v1);
  return mul;
}

HWY_API Vec256<float> NegMulSub(Vec256<float> mul, Vec256<float> x,
                                Vec256<float> sub) {
  mul.v0 = NegMulSub(mul.v0, x.v0, sub.v0);
  mul.v1 = NegMulSub(mul.v1, x.v1, sub.v1);
  return mul;
}

// ------------------------------ Floating-point square root

template <typename T>
HWY_API Vec256<T> Sqrt(Vec256<T> v) {
  v.v0 = Sqrt(v.v0);
  v.v1 = Sqrt(v.v1);
  return v;
}

// Approximate reciprocal square root
HWY_API Vec256<float> ApproximateReciprocalSqrt(const Vec256<float> v) {
  // TODO(eustas): find a cheaper way to calculate this.
  const Vec256<float> one = Set(Full256<float>(), 1.0f);
  return one / Sqrt(v);
}

// ------------------------------ Floating-point rounding

// Toward nearest integer, ties to even
HWY_API Vec256<float> Round(Vec256<float> v) {
  v.v0 = Round(v.v0);
  v.v1 = Round(v.v1);
  return v;
}

// Toward zero, aka truncate
HWY_API Vec256<float> Trunc(Vec256<float> v) {
  v.v0 = Trunc(v.v0);
  v.v1 = Trunc(v.v1);
  return v;
}

// Toward +infinity, aka ceiling
HWY_API Vec256<float> Ceil(Vec256<float> v) {
  v.v0 = Ceil(v.v0);
  v.v1 = Ceil(v.v1);
  return v;
}

// Toward -infinity, aka floor
HWY_API Vec256<float> Floor(Vec256<float> v) {
  v.v0 = Floor(v.v0);
  v.v1 = Floor(v.v1);
  return v;
}

// ------------------------------ Floating-point classification

template <typename T>
HWY_API Mask256<T> IsNaN(const Vec256<T> v) {
  return v != v;
}

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Mask256<T> IsInf(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  const VFromD<decltype(di)> vi = BitCast(di, v);
  // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0.
  return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2<T>())));
}

// Returns whether normal/subnormal/zero.
template <typename T, HWY_IF_FLOAT(T)>
HWY_API Mask256<T> IsFinite(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const RebindToSigned<decltype(d)> di;  // cheaper than unsigned comparison
  const VFromD<decltype(du)> vu = BitCast(du, v);
  // 'Shift left' to clear the sign bit, then right so we can compare with the
  // max exponent (cannot compare with MaxExponentTimes2 directly because it is
  // negative and non-negative floats would be greater).
  const VFromD<decltype(di)> exp =
      BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu)));
  return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField<T>())));
}
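// Example (illustrative comment only): here ApproximateReciprocal is already
// an exact 1/x, but on targets with a true estimate the usual Newton-Raphson
// refinement applies; one step, using only the ops defined above:
//
//   Vec256<float> RefinedReciprocal(Vec256<float> v) {
//     const Full256<float> d;
//     const Vec256<float> x0 = ApproximateReciprocal(v);
//     return x0 * (Set(d, 2.0f) - v * x0);  // x1 = x0 * (2 - v*x0)
//   }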
// ================================================== COMPARE

// Comparisons fill a lane with 1-bits if the condition is true, else 0.

template <class DTo, typename TFrom, typename TTo = TFromD<DTo>>
HWY_API MFromD<DTo> RebindMask(DTo /*tag*/, Mask256<TFrom> m) {
  static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size");
  return MFromD<DTo>{Mask128<TTo>{m.m0.raw}, Mask128<TTo>{m.m1.raw}};
}

template <typename T>
HWY_API Mask256<T> TestBit(Vec256<T> v, Vec256<T> bit) {
  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
  return (v & bit) == bit;
}

template <typename T>
HWY_API Mask256<T> operator==(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator==(a.v0, b.v0);
  m.m1 = operator==(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator!=(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator!=(a.v0, b.v0);
  m.m1 = operator!=(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator<(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator<(a.v0, b.v0);
  m.m1 = operator<(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator>(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator>(a.v0, b.v0);
  m.m1 = operator>(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator<=(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator<=(a.v0, b.v0);
  m.m1 = operator<=(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator>=(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator>=(a.v0, b.v0);
  m.m1 = operator>=(a.v1, b.v1);
  return m;
}

// ------------------------------ FirstN (Iota, Lt)

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API MFromD<D> FirstN(const D d, size_t num) {
  const RebindToSigned<decltype(d)> di;  // Signed comparisons may be cheaper.
  using TI = TFromD<decltype(di)>;
  return RebindMask(d, Iota(di, 0) < Set(di, static_cast<TI>(num)));
}

// ================================================== LOGICAL

template <typename T>
HWY_API Vec256<T> Not(Vec256<T> v) {
  v.v0 = Not(v.v0);
  v.v1 = Not(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> And(Vec256<T> a, Vec256<T> b) {
  a.v0 = And(a.v0, b.v0);
  a.v1 = And(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> AndNot(Vec256<T> not_mask, Vec256<T> mask) {
  not_mask.v0 = AndNot(not_mask.v0, mask.v0);
  not_mask.v1 = AndNot(not_mask.v1, mask.v1);
  return not_mask;
}

template <typename T>
HWY_API Vec256<T> Or(Vec256<T> a, Vec256<T> b) {
  a.v0 = Or(a.v0, b.v0);
  a.v1 = Or(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> Xor(Vec256<T> a, Vec256<T> b) {
  a.v0 = Xor(a.v0, b.v0);
  a.v1 = Xor(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> Xor3(Vec256<T> x1, Vec256<T> x2, Vec256<T> x3) {
  return Xor(x1, Xor(x2, x3));
}

template <typename T>
HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) {
  return Or(o1, Or(o2, o3));
}

template <typename T>
HWY_API Vec256<T> OrAnd(Vec256<T> o, Vec256<T> a1, Vec256<T> a2) {
  return Or(o, And(a1, a2));
}

template <typename T>
HWY_API Vec256<T> IfVecThenElse(Vec256<T> mask, Vec256<T> yes, Vec256<T> no) {
  return IfThenElse(MaskFromVec(mask), yes, no);
}

// ------------------------------ Operator overloads (internal-only if float)

template <typename T>
HWY_API Vec256<T> operator&(const Vec256<T> a, const Vec256<T> b) {
  return And(a, b);
}

template <typename T>
HWY_API Vec256<T> operator|(const Vec256<T> a, const Vec256<T> b) {
  return Or(a, b);
}

template <typename T>
HWY_API Vec256<T> operator^(const Vec256<T> a, const Vec256<T> b) {
  return Xor(a, b);
}

// ------------------------------ CopySign

template <typename T>
HWY_API Vec256<T> CopySign(const Vec256<T> magn, const Vec256<T> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");
  const auto msb = SignBit(DFromV<decltype(magn)>());
  return Or(AndNot(msb, magn), And(msb, sign));
}

template <typename T>
HWY_API Vec256<T> CopySignToAbs(const Vec256<T> abs, const Vec256<T> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");
  return Or(abs, And(SignBit(DFromV<decltype(abs)>()), sign));
}
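// Example (illustrative comment only): FirstN builds the mask typically used
// for loop remainders:
//
//   const Full256<float> d;                // Lanes(d) == 8
//   const auto m = FirstN(d, 5);           // lanes 0..4 true, 5..7 false
//   // Combine with IfThenElseZero, or MaskedLoad/BlendedStore (below).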
// ------------------------------ Mask

// Mask and Vec are the same (true = FF..FF).
template <typename T>
HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) {
  Mask256<T> m;
  m.m0 = MaskFromVec(v.v0);
  m.m1 = MaskFromVec(v.v1);
  return m;
}

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> VecFromMask(D d, Mask256<T> m) {
  const Half<decltype(d)> dh;
  Vec256<T> v;
  v.v0 = VecFromMask(dh, m.m0);
  v.v1 = VecFromMask(dh, m.m1);
  return v;
}

// mask ? yes : no
template <typename T>
HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) {
  yes.v0 = IfThenElse(mask.m0, yes.v0, no.v0);
  yes.v1 = IfThenElse(mask.m1, yes.v1, no.v1);
  return yes;
}

// mask ? yes : 0
template <typename T>
HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) {
  return yes & VecFromMask(DFromV<decltype(yes)>(), mask);
}

// mask ? 0 : no
template <typename T>
HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) {
  return AndNot(VecFromMask(DFromV<decltype(no)>(), mask), no);
}

template <typename T>
HWY_API Vec256<T> IfNegativeThenElse(Vec256<T> v, Vec256<T> yes,
                                     Vec256<T> no) {
  v.v0 = IfNegativeThenElse(v.v0, yes.v0, no.v0);
  v.v1 = IfNegativeThenElse(v.v1, yes.v1, no.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> ZeroIfNegative(Vec256<T> v) {
  return IfThenZeroElse(v < Zero(DFromV<decltype(v)>()), v);
}

// ------------------------------ Mask logical

template <typename T>
HWY_API Mask256<T> Not(const Mask256<T> m) {
  return MaskFromVec(Not(VecFromMask(Full256<T>(), m)));
}

template <typename T>
HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b))));
}

// ------------------------------ Shl (BroadcastSignBit, IfThenElse)

template <typename T>
HWY_API Vec256<T> operator<<(Vec256<T> v, const Vec256<T> bits) {
  v.v0 = operator<<(v.v0, bits.v0);
  v.v1 = operator<<(v.v1, bits.v1);
  return v;
}

// ------------------------------ Shr (BroadcastSignBit, IfThenElse)

template <typename T>
HWY_API Vec256<T> operator>>(Vec256<T> v, const Vec256<T> bits) {
  v.v0 = operator>>(v.v0, bits.v0);
  v.v1 = operator>>(v.v1, bits.v1);
  return v;
}

// ------------------------------ BroadcastSignBit (compare, VecFromMask)

template <typename T>
HWY_API Vec256<T> BroadcastSignBit(const Vec256<T> v) {
  return ShiftRight<sizeof(T) * 8 - 1>(v);
}

HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) {
  const DFromV<decltype(v)> d;
  return VecFromMask(d, v < Zero(d));
}

// ================================================== MEMORY

// ------------------------------ Load

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Load(D d, const TFromD<D>* HWY_RESTRICT aligned) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = Load(dh, aligned);
  ret.v1 = Load(dh, aligned + Lanes(dh));
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> MaskedLoad(Mask256<T> m, D d,
                             const T* HWY_RESTRICT aligned) {
  return IfThenElseZero(m, Load(d, aligned));
}

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> MaskedLoadOr(Vec256<T> v, Mask256<T> m, D d,
                               const T* HWY_RESTRICT aligned) {
  return IfThenElse(m, Load(d, aligned), v);
}
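// Example (illustrative comment only): masked loads zero (or blend) the
// lanes whose mask bit is false, e.g. reading only a 3-element tail:
//
//   const Full256<int32_t> d;
//   alignas(32) const int32_t data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
//   const auto v = MaskedLoad(FirstN(d, 3), d, data);  // 1,2,3,0,0,0,0,0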
// LoadU == Load.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) {
  return Load(d, p);
}

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ret.v1 = Load(dh, p);
  return ret;
}

// ------------------------------ Store

template <class D, typename T = TFromD<D>>
HWY_API void Store(Vec256<T> v, D d, T* HWY_RESTRICT aligned) {
  const Half<decltype(d)> dh;
  Store(v.v0, dh, aligned);
  Store(v.v1, dh, aligned + Lanes(dh));
}

// StoreU == Store.
template <class D, typename T = TFromD<D>>
HWY_API void StoreU(Vec256<T> v, D d, T* HWY_RESTRICT p) {
  Store(v, d, p);
}

template <class D, typename T = TFromD<D>>
HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, D d, T* HWY_RESTRICT p) {
  StoreU(IfThenElse(m, v, LoadU(d, p)), d, p);
}

// ------------------------------ Stream

template <class D, typename T = TFromD<D>>
HWY_API void Stream(Vec256<T> v, D d, T* HWY_RESTRICT aligned) {
  // Same as aligned stores.
  Store(v, d, aligned);
}

// ------------------------------ Scatter, Gather defined in wasm_128

// ================================================== SWIZZLE

// ------------------------------ ExtractLane
template <typename T>
HWY_API T ExtractLane(const Vec256<T> v, size_t i) {
  alignas(32) T lanes[32 / sizeof(T)];
  Store(v, DFromV<decltype(v)>(), lanes);
  return lanes[i];
}

// ------------------------------ InsertLane
template <typename T>
HWY_API Vec256<T> InsertLane(const Vec256<T> v, size_t i, T t) {
  DFromV<decltype(v)> d;
  alignas(32) T lanes[32 / sizeof(T)];
  Store(v, d, lanes);
  lanes[i] = t;
  return Load(d, lanes);
}

// ------------------------------ LowerHalf

template <class D, typename T = TFromD<D>>
HWY_API Vec128<T> LowerHalf(D /* tag */, Vec256<T> v) {
  return v.v0;
}

template <typename T>
HWY_API Vec128<T> LowerHalf(Vec256<T> v) {
  return v.v0;
}

// ------------------------------ GetLane (LowerHalf)
template <typename T>
HWY_API T GetLane(const Vec256<T> v) {
  return GetLane(LowerHalf(v));
}

// ------------------------------ ShiftLeftBytes

template <int kBytes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> ShiftLeftBytes(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = ShiftLeftBytes<kBytes>(dh, v.v0);
  v.v1 = ShiftLeftBytes<kBytes>(dh, v.v1);
  return v;
}

template <int kBytes, typename T>
HWY_API Vec256<T> ShiftLeftBytes(Vec256<T> v) {
  return ShiftLeftBytes<kBytes>(DFromV<decltype(v)>(), v);
}

// ------------------------------ ShiftLeftLanes

template <int kLanes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> ShiftLeftLanes(D d, const Vec256<T> v) {
  const Repartition<uint8_t, decltype(d)> d8;
  return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v)));
}

template <int kLanes, typename T>
HWY_API Vec256<T> ShiftLeftLanes(const Vec256<T> v) {
  return ShiftLeftLanes<kLanes>(DFromV<decltype(v)>(), v);
}

// ------------------------------ ShiftRightBytes

template <int kBytes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> ShiftRightBytes(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = ShiftRightBytes<kBytes>(dh, v.v0);
  v.v1 = ShiftRightBytes<kBytes>(dh, v.v1);
  return v;
}

// ------------------------------ ShiftRightLanes
template <int kLanes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> ShiftRightLanes(D d, const Vec256<T> v) {
  const Repartition<uint8_t, decltype(d)> d8;
  return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v)));
}

// ------------------------------ UpperHalf (ShiftRightBytes)

template <class D, typename T = TFromD<D>>
HWY_API Vec128<T> UpperHalf(D /* tag */, const Vec256<T> v) {
  return v.v1;
}

// ------------------------------ CombineShiftRightBytes

template <int kBytes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> CombineShiftRightBytes(D d, Vec256<T> hi, Vec256<T> lo) {
  const Half<decltype(d)> dh;
  hi.v0 = CombineShiftRightBytes<kBytes>(dh, hi.v0, lo.v0);
  hi.v1 = CombineShiftRightBytes<kBytes>(dh, hi.v1, lo.v1);
  return hi;
}

// ------------------------------ Broadcast/splat any lane

template <int kLane, typename T>
HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
  Vec256<T> ret;
  ret.v0 = Broadcast<kLane>(v.v0);
  ret.v1 = Broadcast<kLane>(v.v1);
  return ret;
}

// ------------------------------ TableLookupBytes

// Both full
template <typename T, typename TI>
HWY_API Vec256<TI> TableLookupBytes(const Vec256<T> bytes, Vec256<TI> from) {
  from.v0 = TableLookupBytes(bytes.v0, from.v0);
  from.v1 = TableLookupBytes(bytes.v1, from.v1);
  return from;
}
// Partial index vector
template <typename T, typename TI, size_t NI>
HWY_API Vec128<TI, NI> TableLookupBytes(Vec256<T> bytes,
                                        const Vec128<TI, NI> from) {
  // First expand to full 128, then 256.
  const auto from_256 = ZeroExtendVector(Full256<TI>(), Vec128<TI>{from.raw});
  const auto tbl_full = TableLookupBytes(bytes, from_256);
  // Shrink to 128, then partial.
  return Vec128<TI, NI>{LowerHalf(Full128<TI>(), tbl_full).raw};
}

// Partial table vector
template <typename T, size_t N, typename TI>
HWY_API Vec256<TI> TableLookupBytes(Vec128<T, N> bytes,
                                    const Vec256<TI> from) {
  // First expand to full 128, then 256.
  const auto bytes_256 = ZeroExtendVector(Full256<T>(), Vec128<T>{bytes.raw});
  return TableLookupBytes(bytes_256, from);
}

// Partial both are handled by wasm_128.

template <class V, class VI>
HWY_API VI TableLookupBytesOr0(V bytes, VI from) {
  // wasm out-of-bounds policy already zeros, so TableLookupBytes is fine.
  return TableLookupBytes(bytes, from);
}

// ------------------------------ Hard-coded shuffles

template <typename T>
HWY_API Vec256<T> Shuffle01(Vec256<T> v) {
  v.v0 = Shuffle01(v.v0);
  v.v1 = Shuffle01(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle2301(Vec256<T> v) {
  v.v0 = Shuffle2301(v.v0);
  v.v1 = Shuffle2301(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle1032(Vec256<T> v) {
  v.v0 = Shuffle1032(v.v0);
  v.v1 = Shuffle1032(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle0321(Vec256<T> v) {
  v.v0 = Shuffle0321(v.v0);
  v.v1 = Shuffle0321(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle2103(Vec256<T> v) {
  v.v0 = Shuffle2103(v.v0);
  v.v1 = Shuffle2103(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle0123(Vec256<T> v) {
  v.v0 = Shuffle0123(v.v0);
  v.v1 = Shuffle0123(v.v1);
  return v;
}

// Used by generic_ops-inl.h
namespace detail {

template <typename T>
HWY_API Vec256<T> ShuffleTwo2301(Vec256<T> a, const Vec256<T> b) {
  a.v0 = ShuffleTwo2301(a.v0, b.v0);
  a.v1 = ShuffleTwo2301(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> ShuffleTwo1230(Vec256<T> a, const Vec256<T> b) {
  a.v0 = ShuffleTwo1230(a.v0, b.v0);
  a.v1 = ShuffleTwo1230(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> ShuffleTwo3012(Vec256<T> a, const Vec256<T> b) {
  a.v0 = ShuffleTwo3012(a.v0, b.v0);
  a.v1 = ShuffleTwo3012(a.v1, b.v1);
  return a;
}

}  // namespace detail

// ------------------------------ TableLookupLanes

// Returned by SetTableIndices for use by TableLookupLanes.
template <typename T>
struct Indices256 {
  __v128_u i0;
  __v128_u i1;
};

template <class D, typename T = TFromD<D>, typename TI>
HWY_API Indices256<T> IndicesFromVec(D /* tag */, Vec256<TI> vec) {
  static_assert(sizeof(T) == sizeof(TI), "Index size must match lane");
  Indices256<T> ret;
  ret.i0 = vec.v0.raw;
  ret.i1 = vec.v1.raw;
  return ret;
}

template <class D, HWY_IF_V_SIZE_D(D, 32), typename TI>
HWY_API Indices256<TFromD<D>> SetTableIndices(D d, const TI* idx) {
  const Rebind<TI, decltype(d)> di;
  return IndicesFromVec(d, LoadU(di, idx));
}

template <typename T>
HWY_API Vec256<T> TableLookupLanes(const Vec256<T> v, Indices256<T> idx) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  const auto idx_i0 = IndicesFromVec(dh, Vec128<T>{idx.i0});
  const auto idx_i1 = IndicesFromVec(dh, Vec128<T>{idx.i1});
  Vec256<T> result;
  result.v0 = TwoTablesLookupLanes(v.v0, v.v1, idx_i0);
  result.v1 = TwoTablesLookupLanes(v.v0, v.v1, idx_i1);
  return result;
}

template <typename T>
HWY_API Vec256<T> TableLookupLanesOr0(Vec256<T> v, Indices256<T> idx) {
  // The wasm out-of-bounds behavior already zeroes lanes, so plain
  // TableLookupLanes suffices (calling ourselves here would recurse forever).
  return TableLookupLanes(v, idx);
}
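// Example (illustrative comment only): a lane permutation via SetTableIndices
// and TableLookupLanes, here reversing the eight i32 lanes:
//
//   const Full256<int32_t> d;
//   alignas(32) const int32_t idx[8] = {7, 6, 5, 4, 3, 2, 1, 0};
//   const auto rev = TableLookupLanes(Iota(d, 0), SetTableIndices(d, idx));
//   // rev holds 7,6,5,4,3,2,1,0.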
template <typename T>
HWY_API Vec256<T> TwoTablesLookupLanes(const Vec256<T> a, const Vec256<T> b,
                                       Indices256<T> idx) {
  const DFromV<decltype(a)> d;
  const Half<decltype(d)> dh;
  const RebindToUnsigned<decltype(d)> du;
  using TU = MakeUnsigned<T>;
  constexpr size_t kLanesPerVect = 32 / sizeof(TU);

  Vec256<TU> vi;
  vi.v0 = Vec128<TU>{idx.i0};
  vi.v1 = Vec128<TU>{idx.i1};
  const auto vmod = vi & Set(du, TU{kLanesPerVect - 1});
  const auto is_lo = RebindMask(d, vi == vmod);

  const auto idx_i0 = IndicesFromVec(dh, vmod.v0);
  const auto idx_i1 = IndicesFromVec(dh, vmod.v1);
  Vec256<T> result_lo;
  Vec256<T> result_hi;
  result_lo.v0 = TwoTablesLookupLanes(a.v0, a.v1, idx_i0);
  result_lo.v1 = TwoTablesLookupLanes(a.v0, a.v1, idx_i1);
  result_hi.v0 = TwoTablesLookupLanes(b.v0, b.v1, idx_i0);
  result_hi.v1 = TwoTablesLookupLanes(b.v0, b.v1, idx_i1);
  return IfThenElse(is_lo, result_lo, result_hi);
}

// ------------------------------ Reverse
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> Reverse(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v1 = Reverse(dh, v.v0);  // note reversed v1 member order
  ret.v0 = Reverse(dh, v.v1);
  return ret;
}

// ------------------------------ Reverse2
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> Reverse2(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = Reverse2(dh, v.v0);
  v.v1 = Reverse2(dh, v.v1);
  return v;
}

// ------------------------------ Reverse4

// Each block has only 2 lanes, so swap blocks and their lanes.
template <class D, typename T = TFromD<D>, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> Reverse4(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Reverse2(dh, v.v1);  // swapped
  ret.v1 = Reverse2(dh, v.v0);
  return ret;
}

template <class D, typename T = TFromD<D>, HWY_IF_NOT_T_SIZE(T, 8)>
HWY_API Vec256<T> Reverse4(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = Reverse4(dh, v.v0);
  v.v1 = Reverse4(dh, v.v1);
  return v;
}

// ------------------------------ Reverse8

template <class D, typename T = TFromD<D>, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> Reverse8(D /* tag */, Vec256<T> /* v */) {
  HWY_ASSERT(0);  // don't have 8 u64 lanes
}

// Each block has only 4 lanes, so swap blocks and their lanes.
template <class D, typename T = TFromD<D>, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec256<T> Reverse8(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Reverse4(dh, v.v1);  // swapped
  ret.v1 = Reverse4(dh, v.v0);
  return ret;
}

template <class D, typename T = TFromD<D>,
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))>
HWY_API Vec256<T> Reverse8(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = Reverse8(dh, v.v0);
  v.v1 = Reverse8(dh, v.v1);
  return v;
}

// ------------------------------ InterleaveLower

template <typename T>
HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
  a.v0 = InterleaveLower(a.v0, b.v0);
  a.v1 = InterleaveLower(a.v1, b.v1);
  return a;
}

// wasm_128 already defines a template with D, V, V args.
// ------------------------------ InterleaveUpper (UpperHalf)

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> InterleaveUpper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  a.v0 = InterleaveUpper(dh, a.v0, b.v0);
  a.v1 = InterleaveUpper(dh, a.v1, b.v1);
  return a;
}

// ------------------------------ ZipLower/ZipUpper defined in wasm_128

// ================================================== COMBINE

// ------------------------------ Combine (InterleaveLower)
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> Combine(D /* d */, Vec128<T> hi, Vec128<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi;
  ret.v0 = lo;
  return ret;
}

// ------------------------------ ZeroExtendVector (Combine)
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ZeroExtendVector(D d, Vec128<T> lo) {
  const Half<decltype(d)> dh;
  return Combine(d, Zero(dh), lo);
}

// ------------------------------ ZeroExtendResizeBitCast

namespace detail {

template <size_t kFromVectSize, class DTo, class DFrom,
          HWY_IF_LANES_LE(kFromVectSize, 16)>
HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast(
    hwy::SizeTag<kFromVectSize> /* from_size_tag */,
    hwy::SizeTag<32> /* to_size_tag */, DTo d_to, DFrom d_from,
    VFromD<DFrom> v) {
  const Half<decltype(d_to)> dh_to;
  return ZeroExtendVector(d_to, ZeroExtendResizeBitCast(dh_to, d_from, v));
}

}  // namespace detail

// ------------------------------ ConcatLowerLower
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatLowerLower(D /* tag */, Vec256<T> hi, Vec256<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi.v0;
  ret.v0 = lo.v0;
  return ret;
}

// ------------------------------ ConcatUpperUpper
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatUpperUpper(D /* tag */, Vec256<T> hi, Vec256<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi.v1;
  ret.v0 = lo.v1;
  return ret;
}

// ------------------------------ ConcatLowerUpper
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatLowerUpper(D /* tag */, Vec256<T> hi, Vec256<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi.v0;
  ret.v0 = lo.v1;
  return ret;
}

// ------------------------------ ConcatUpperLower
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatUpperLower(D /* tag */, Vec256<T> hi, Vec256<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi.v1;
  ret.v0 = lo.v0;
  return ret;
}

// ------------------------------ ConcatOdd
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatOdd(D d, Vec256<T> hi, Vec256<T> lo) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = ConcatOdd(dh, lo.v1, lo.v0);
  ret.v1 = ConcatOdd(dh, hi.v1, hi.v0);
  return ret;
}

// ------------------------------ ConcatEven
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatEven(D d, Vec256<T> hi, Vec256<T> lo) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = ConcatEven(dh, lo.v1, lo.v0);
  ret.v1 = ConcatEven(dh, hi.v1, hi.v0);
  return ret;
}

// ------------------------------ DupEven
template <typename T>
HWY_API Vec256<T> DupEven(Vec256<T> v) {
  v.v0 = DupEven(v.v0);
  v.v1 = DupEven(v.v1);
  return v;
}

// ------------------------------ DupOdd
template <typename T>
HWY_API Vec256<T> DupOdd(Vec256<T> v) {
  v.v0 = DupOdd(v.v0);
  v.v1 = DupOdd(v.v1);
  return v;
}

// ------------------------------ OddEven
template <typename T>
HWY_API Vec256<T> OddEven(Vec256<T> a, const Vec256<T> b) {
  a.v0 = OddEven(a.v0, b.v0);
  a.v1 = OddEven(a.v1, b.v1);
  return a;
}

// ------------------------------ OddEvenBlocks
template <typename T>
HWY_API Vec256<T> OddEvenBlocks(Vec256<T> odd, Vec256<T> even) {
  odd.v0 = even.v0;
  return odd;
}

// ------------------------------ SwapAdjacentBlocks
template <typename T>
HWY_API Vec256<T> SwapAdjacentBlocks(Vec256<T> v) {
  Vec256<T> ret;
  ret.v0 = v.v1;  // swapped order
  ret.v1 = v.v0;
  return ret;
}

// ------------------------------ ReverseBlocks
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ReverseBlocks(D /* tag */, const Vec256<T> v) {
  return SwapAdjacentBlocks(v);  // 2 blocks, so Swap = Reverse
}
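// Example (illustrative comment only): ConcatEven/ConcatOdd deinterleave
// packed pairs (x0 y0 x1 y1 ...) into separate x and y vectors:
//
//   const Full256<float> d;
//   // Given xy0 = x0 y0 .. x3 y3 and xy1 = x4 y4 .. x7 y7:
//   const Vec256<float> x = ConcatEven(d, xy1, xy0);  // x0 .. x7
//   const Vec256<float> y = ConcatOdd(d, xy1, xy0);   // y0 .. y7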
// ================================================== CONVERT

// ------------------------------ Promotions (part w/ narrow lanes -> full)

namespace detail {

// Unsigned: zero-extend.
template <class D, HWY_IF_U16_D(D)>
HWY_API Vec128<uint16_t> PromoteUpperTo(D /* tag */, Vec128<uint8_t> v) {
  return Vec128<uint16_t>{wasm_u16x8_extend_high_u8x16(v.raw)};
}
template <class D, HWY_IF_U32_D(D)>
HWY_API Vec128<uint32_t> PromoteUpperTo(D /* tag */, Vec128<uint8_t> v) {
  return Vec128<uint32_t>{
      wasm_u32x4_extend_high_u16x8(wasm_u16x8_extend_high_u8x16(v.raw))};
}
template <class D, HWY_IF_I16_D(D)>
HWY_API Vec128<int16_t> PromoteUpperTo(D /* tag */, Vec128<uint8_t> v) {
  return Vec128<int16_t>{wasm_u16x8_extend_high_u8x16(v.raw)};
}
template <class D, HWY_IF_I32_D(D)>
HWY_API Vec128<int32_t> PromoteUpperTo(D /* tag */, Vec128<uint8_t> v) {
  return Vec128<int32_t>{
      wasm_u32x4_extend_high_u16x8(wasm_u16x8_extend_high_u8x16(v.raw))};
}
template <class D, HWY_IF_U32_D(D)>
HWY_API Vec128<uint32_t> PromoteUpperTo(D /* tag */, Vec128<uint16_t> v) {
  return Vec128<uint32_t>{wasm_u32x4_extend_high_u16x8(v.raw)};
}
template <class D, HWY_IF_U64_D(D)>
HWY_API Vec128<uint64_t> PromoteUpperTo(D /* tag */, Vec128<uint32_t> v) {
  return Vec128<uint64_t>{wasm_u64x2_extend_high_u32x4(v.raw)};
}
template <class D, HWY_IF_I32_D(D)>
HWY_API Vec128<int32_t> PromoteUpperTo(D /* tag */, Vec128<uint16_t> v) {
  return Vec128<int32_t>{wasm_u32x4_extend_high_u16x8(v.raw)};
}
template <class D, HWY_IF_I64_D(D)>
HWY_API Vec128<int64_t> PromoteUpperTo(D /* tag */, Vec128<uint32_t> v) {
  return Vec128<int64_t>{wasm_u64x2_extend_high_u32x4(v.raw)};
}

// Signed: replicate sign bit.
template <class D, HWY_IF_I16_D(D)>
HWY_API Vec128<int16_t> PromoteUpperTo(D /* tag */, Vec128<int8_t> v) {
  return Vec128<int16_t>{wasm_i16x8_extend_high_i8x16(v.raw)};
}
template <class D, HWY_IF_I32_D(D)>
HWY_API Vec128<int32_t> PromoteUpperTo(D /* tag */, Vec128<int8_t> v) {
  return Vec128<int32_t>{
      wasm_i32x4_extend_high_i16x8(wasm_i16x8_extend_high_i8x16(v.raw))};
}
template <class D, HWY_IF_I32_D(D)>
HWY_API Vec128<int32_t> PromoteUpperTo(D /* tag */, Vec128<int16_t> v) {
  return Vec128<int32_t>{wasm_i32x4_extend_high_i16x8(v.raw)};
}
template <class D, HWY_IF_I64_D(D)>
HWY_API Vec128<int64_t> PromoteUpperTo(D /* tag */, Vec128<int32_t> v) {
  return Vec128<int64_t>{wasm_i64x2_extend_high_i32x4(v.raw)};
}

template <class D, HWY_IF_F64_D(D)>
HWY_API Vec128<double> PromoteUpperTo(D dd, Vec128<int32_t> v) {
  // There is no wasm_f64x2_convert_high_i32x4.
  const Full64<int32_t> di32h;
  return PromoteTo(dd, UpperHalf(di32h, v));
}

template <class D, HWY_IF_F32_D(D)>
HWY_API Vec128<float> PromoteUpperTo(D df32, Vec128<float16_t> v) {
  const RebindToSigned<decltype(df32)> di32;
  const RebindToUnsigned<decltype(df32)> du32;
  // Expand to u32 so we can shift.
  const auto bits16 = PromoteUpperTo(du32, Vec128<uint16_t>{v.raw});
  const auto sign = ShiftRight<15>(bits16);
  const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F);
  const auto mantissa = bits16 & Set(du32, 0x3FF);
  const auto subnormal =
      BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) *
                        Set(df32, 1.0f / 16384 / 1024));

  const auto biased_exp32 = biased_exp + Set(du32, 127 - 15);
  const auto mantissa32 = ShiftLeft<23 - 10>(mantissa);
  const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32;
  const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal);
  return BitCast(df32, ShiftLeft<31>(sign) | bits32);
}
template <class D, HWY_IF_F32_D(D)>
HWY_API Vec128<float> PromoteUpperTo(D df32, Vec128<bfloat16_t> v) {
  const Full128<uint16_t> du16;
  const RebindToSigned<decltype(df32)> di32;
  return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v))));
}

}  // namespace detail

template <class D, HWY_IF_V_SIZE_D(D, 32), typename T>
HWY_API VFromD<D> PromoteTo(D d, Vec128<T> v) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = PromoteTo(dh, LowerHalf(v));
  ret.v1 = detail::PromoteUpperTo(dh, v);
  return ret;
}

// 4x promotion: 8-bit to 32-bit or 16-bit to 64-bit
template <class DW, HWY_IF_V_SIZE_D(DW, 32), typename TN,
          HWY_IF_T_SIZE_D(DW, sizeof(TN) * 4)>
HWY_API Vec256<TFromD<DW>> PromoteTo(DW d, Vec64<TN> v) {
  const Half<decltype(d)> dh;
  // 16-bit lanes for UI8->UI32, 32-bit lanes for UI16->UI64
  const Rebind<MakeWide<TN>, decltype(d)> d2;
  const auto v_2x = PromoteTo(d2, v);
  Vec256<TFromD<DW>> ret;
  ret.v0 = PromoteTo(dh, LowerHalf(v_2x));
  ret.v1 = detail::PromoteUpperTo(dh, v_2x);
  return ret;
}

// 8x promotion: 8-bit to 64-bit
template <class DW, HWY_IF_V_SIZE_D(DW, 32), typename TN,
          HWY_IF_T_SIZE_D(DW, sizeof(TN) * 8)>
HWY_API Vec256<TFromD<DW>> PromoteTo(DW d, Vec32<TN> v) {
  const Half<decltype(d)> dh;
  const Repartition<MakeNarrow<TFromD<DW>>, decltype(dh)> d4;  // 32-bit lanes
  const auto v32 = PromoteTo(d4, v);
  Vec256<TFromD<DW>> ret;
  ret.v0 = PromoteTo(dh, LowerHalf(v32));
  ret.v1 = detail::PromoteUpperTo(dh, v32);
  return ret;
}

// ------------------------------ DemoteTo

template <class D, HWY_IF_U16_D(D)>
HWY_API Vec128<uint16_t> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  return Vec128<uint16_t>{wasm_u16x8_narrow_i32x4(v.v0.raw, v.v1.raw)};
}

template <class D, HWY_IF_I16_D(D)>
HWY_API Vec128<int16_t> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  return Vec128<int16_t>{wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw)};
}

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec64<uint8_t> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw);
  return Vec64<uint8_t>{wasm_u8x16_narrow_i16x8(intermediate, intermediate)};
}

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec128<uint8_t> DemoteTo(D /* tag */, Vec256<int16_t> v) {
  return Vec128<uint8_t>{wasm_u8x16_narrow_i16x8(v.v0.raw, v.v1.raw)};
}

template <class D, HWY_IF_I8_D(D)>
HWY_API Vec64<int8_t> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw);
  return Vec64<int8_t>{wasm_i8x16_narrow_i16x8(intermediate, intermediate)};
}

template <class D, HWY_IF_I8_D(D)>
HWY_API Vec128<int8_t> DemoteTo(D /* tag */, Vec256<int16_t> v) {
  return Vec128<int8_t>{wasm_i8x16_narrow_i16x8(v.v0.raw, v.v1.raw)};
}

template <class D, HWY_IF_I32_D(D)>
HWY_API Vec128<int32_t> DemoteTo(D di, Vec256<double> v) {
  const Vec64<int32_t> lo{wasm_i32x4_trunc_sat_f64x2_zero(v.v0.raw)};
  const Vec64<int32_t> hi{wasm_i32x4_trunc_sat_f64x2_zero(v.v1.raw)};
  return Combine(di, hi, lo);
}

template <class D, HWY_IF_F16_D(D)>
HWY_API Vec128<float16_t> DemoteTo(D d16, Vec256<float> v) {
  const Half<decltype(d16)> d16h;
  const Vec64<float16_t> lo = DemoteTo(d16h, v.v0);
  const Vec64<float16_t> hi = DemoteTo(d16h, v.v1);
  return Combine(d16, hi, lo);
}

template <class D, HWY_IF_BF16_D(D)>
HWY_API Vec128<bfloat16_t> DemoteTo(D dbf16, Vec256<float> v) {
  const Half<decltype(dbf16)> dbf16h;
  const Vec64<bfloat16_t> lo = DemoteTo(dbf16h, v.v0);
  const Vec64<bfloat16_t> hi = DemoteTo(dbf16h, v.v1);
  return Combine(dbf16, hi, lo);
}
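// Example (illustrative comment only): promotions widen a narrower vector
// into a full Vec256; the 4x path above goes u8 -> u16 -> u32 in two steps:
//
//   const Full256<uint32_t> d32;
//   const Full64<uint8_t> d8;                        // 8 x u8
//   const Vec256<uint32_t> wide = PromoteTo(d32, Set(d8, 200));
//   // Each of the 8 u32 lanes is 200.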
// For already range-limited input [0, 255].
HWY_API Vec64<uint8_t> U8FromU32(Vec256<uint32_t> v) {
  const Full64<uint8_t> du8;
  const Full256<int32_t> di32;  // no unsigned DemoteTo
  return DemoteTo(du8, BitCast(di32, v));
}

// ------------------------------ Truncations

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec32<uint8_t> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  return Vec32<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 8, 16, 24,
                                           0, 8, 16, 24, 0, 8, 16, 24, 0, 8,
                                           16, 24)};
}

template <class D, HWY_IF_U16_D(D)>
HWY_API Vec64<uint16_t> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  return Vec64<uint16_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 8, 9,
                                            16, 17, 24, 25, 0, 1, 8, 9, 16,
                                            17, 24, 25)};
}

template <class D, HWY_IF_U32_D(D)>
HWY_API Vec128<uint32_t> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  return Vec128<uint32_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 2, 3,
                                             8, 9, 10, 11, 16, 17, 18, 19, 24,
                                             25, 26, 27)};
}

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec64<uint8_t> TruncateTo(D /* tag */, Vec256<uint32_t> v) {
  return Vec64<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 4, 8, 12,
                                           16, 20, 24, 28, 0, 4, 8, 12, 16,
                                           20, 24, 28)};
}

template <class D, HWY_IF_U16_D(D)>
HWY_API Vec128<uint16_t> TruncateTo(D /* tag */, Vec256<uint32_t> v) {
  return Vec128<uint16_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 4, 5,
                                             8, 9, 12, 13, 16, 17, 20, 21, 24,
                                             25, 28, 29)};
}

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec128<uint8_t> TruncateTo(D /* tag */, Vec256<uint16_t> v) {
  return Vec128<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 2, 4, 6, 8,
                                            10, 12, 14, 16, 18, 20, 22, 24,
                                            26, 28, 30)};
}

// ------------------------------ ReorderDemote2To
template <class DBF16, HWY_IF_BF16_D(DBF16)>
HWY_API Vec256<bfloat16_t> ReorderDemote2To(DBF16 dbf16, Vec256<float> a,
                                            Vec256<float> b) {
  const RebindToUnsigned<decltype(dbf16)> du16;
  return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a)));
}

template <class DN, HWY_IF_V_SIZE_D(DN, 32), HWY_IF_NOT_FLOAT_D(DN), class V,
          HWY_IF_SIGNED_V(V),
          HWY_IF_T_SIZE_ONE_OF_D(DN, (1 << 1) | (1 << 2) | (1 << 4)),
          HWY_IF_T_SIZE_V(V, sizeof(TFromD<DN>) * 2)>
HWY_API VFromD<DN> ReorderDemote2To(DN dn, V a, V b) {
  const Half<decltype(dn)> dnh;
  VFromD<DN> demoted;
  demoted.v0 = DemoteTo(dnh, a);
  demoted.v1 = DemoteTo(dnh, b);
  return demoted;
}

template <class DN, HWY_IF_V_SIZE_D(DN, 32), HWY_IF_UNSIGNED_D(DN), class V,
          HWY_IF_UNSIGNED_V(V), HWY_IF_T_SIZE_V(V, sizeof(TFromD<DN>) * 2)>
HWY_API VFromD<DN> ReorderDemote2To(DN dn, V a, V b) {
  const Half<decltype(dn)> dnh;
  VFromD<DN> demoted;
  demoted.v0 = DemoteTo(dnh, a);
  demoted.v1 = DemoteTo(dnh, b);
  return demoted;
}

// ------------------------------ Convert i32 <=> f32 (Round)

template <class DTo, typename TFrom, typename TTo = TFromD<DTo>>
HWY_API Vec256<TTo> ConvertTo(DTo d, const Vec256<TFrom> v) {
  const Half<decltype(d)> dh;
  Vec256<TTo> ret;
  ret.v0 = ConvertTo(dh, v.v0);
  ret.v1 = ConvertTo(dh, v.v1);
  return ret;
}

HWY_API Vec256<int32_t> NearestInt(const Vec256<float> v) {
  return ConvertTo(Full256<int32_t>(), Round(v));
}

// ================================================== MISC

// ------------------------------ LoadMaskBits (TestBit)

// `p` points to at least 8 readable bytes, not all of which need be valid.
template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
HWY_API MFromD<D> LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
  const Half<decltype(d)> dh;
  MFromD<D> ret;
  ret.m0 = LoadMaskBits(dh, bits);
  // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8.
  // Both halves fit in one byte's worth of mask bits.
  constexpr size_t kBitsPerHalf = 16 / sizeof(TFromD<D>);
  const uint8_t bits_upper[8] = {
      static_cast<uint8_t>(bits[0] >> kBitsPerHalf)};
  ret.m1 = LoadMaskBits(dh, bits_upper);
  return ret;
}

template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2))>
HWY_API MFromD<D> LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
  const Half<decltype(d)> dh;
  MFromD<D> ret;
  ret.m0 = LoadMaskBits(dh, bits);
  constexpr size_t kLanesPerHalf = 16 / sizeof(TFromD<D>);
  constexpr size_t kBytesPerHalf = kLanesPerHalf / 8;
  static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes");
  ret.m1 = LoadMaskBits(dh, bits + kBytesPerHalf);
  return ret;
}
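// Example (illustrative comment only): mask bits are packed one bit per lane,
// lane 0 in the least-significant bit. For 8 x i32 the whole mask is 1 byte:
//
//   const Full256<int32_t> d;
//   const uint8_t bits[8] = {0x05};            // lanes 0 and 2 true
//   const auto m = LoadMaskBits(d, bits);
//   // StoreMaskBits (below) writes back the same single byte.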
// ------------------------------ Mask

// `p` points to at least 8 writable bytes.
template <class D, typename T = TFromD<D>,
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))>
HWY_API size_t StoreMaskBits(D d, const Mask256<T> mask, uint8_t* bits) {
  const Half<decltype(d)> dh;
  StoreMaskBits(dh, mask.m0, bits);
  const uint8_t lo = bits[0];
  StoreMaskBits(dh, mask.m1, bits);
  // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8.
  // Both halves fit in one byte's worth of mask bits.
  constexpr size_t kBitsPerHalf = 16 / sizeof(T);
  bits[0] = static_cast<uint8_t>(lo | (bits[0] << kBitsPerHalf));
  return (kBitsPerHalf * 2 + 7) / 8;
}

template <class D, typename T = TFromD<D>,
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))>
HWY_API size_t StoreMaskBits(D d, const Mask256<T> mask, uint8_t* bits) {
  const Half<decltype(d)> dh;
  constexpr size_t kLanesPerHalf = 16 / sizeof(T);
  constexpr size_t kBytesPerHalf = kLanesPerHalf / 8;
  static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes");
  StoreMaskBits(dh, mask.m0, bits);
  StoreMaskBits(dh, mask.m1, bits + kBytesPerHalf);
  return kBytesPerHalf * 2;
}

template <class D, typename T = TFromD<D>>
HWY_API size_t CountTrue(D d, const Mask256<T> m) {
  const Half<decltype(d)> dh;
  return CountTrue(dh, m.m0) + CountTrue(dh, m.m1);
}

template <class D, typename T = TFromD<D>>
HWY_API bool AllFalse(D d, const Mask256<T> m) {
  const Half<decltype(d)> dh;
  return AllFalse(dh, m.m0) && AllFalse(dh, m.m1);
}

template <class D, typename T = TFromD<D>>
HWY_API bool AllTrue(D d, const Mask256<T> m) {
  const Half<decltype(d)> dh;
  return AllTrue(dh, m.m0) && AllTrue(dh, m.m1);
}

template <class D, typename T = TFromD<D>>
HWY_API size_t FindKnownFirstTrue(D d, const Mask256<T> mask) {
  const Half<decltype(d)> dh;
  const intptr_t lo = FindFirstTrue(dh, mask.m0);  // not known
  constexpr size_t kLanesPerHalf = 16 / sizeof(T);
  return lo >= 0 ? static_cast<size_t>(lo)
                 : kLanesPerHalf + FindKnownFirstTrue(dh, mask.m1);
}

template <class D, typename T = TFromD<D>>
HWY_API intptr_t FindFirstTrue(D d, const Mask256<T> mask) {
  const Half<decltype(d)> dh;
  const intptr_t lo = FindFirstTrue(dh, mask.m0);
  constexpr int kLanesPerHalf = 16 / sizeof(T);
  if (lo >= 0) return lo;

  const intptr_t hi = FindFirstTrue(dh, mask.m1);
  return hi + (hi >= 0 ? kLanesPerHalf : 0);
}

template <class D, typename T = TFromD<D>>
HWY_API size_t FindKnownLastTrue(D d, const Mask256<T> mask) {
  const Half<decltype(d)> dh;
  const intptr_t hi = FindLastTrue(dh, mask.m1);  // not known
  constexpr size_t kLanesPerHalf = 16 / sizeof(T);
  return hi >= 0 ? kLanesPerHalf + static_cast<size_t>(hi)
                 : FindKnownLastTrue(dh, mask.m0);
}

template <class D, typename T = TFromD<D>>
HWY_API intptr_t FindLastTrue(D d, const Mask256<T> mask) {
  const Half<decltype(d)> dh;
  constexpr int kLanesPerHalf = 16 / sizeof(T);
  const intptr_t hi = FindLastTrue(dh, mask.m1);
  return hi >= 0 ? kLanesPerHalf + hi : FindLastTrue(dh, mask.m0);
}
// ------------------------------ CompressStore
template <class D, typename T = TFromD<D>>
HWY_API size_t CompressStore(Vec256<T> v, const Mask256<T> mask, D d,
                             T* HWY_RESTRICT unaligned) {
  const Half<decltype(d)> dh;
  const size_t count = CompressStore(v.v0, mask.m0, dh, unaligned);
  const size_t count2 = CompressStore(v.v1, mask.m1, dh, unaligned + count);
  return count + count2;
}

// ------------------------------ CompressBlendedStore
template <class D, typename T = TFromD<D>>
HWY_API size_t CompressBlendedStore(Vec256<T> v, const Mask256<T> m, D d,
                                    T* HWY_RESTRICT unaligned) {
  const Half<decltype(d)> dh;
  const size_t count = CompressBlendedStore(v.v0, m.m0, dh, unaligned);
  const size_t count2 =
      CompressBlendedStore(v.v1, m.m1, dh, unaligned + count);
  return count + count2;
}

// ------------------------------ CompressBitsStore
template <class D, typename T = TFromD<D>>
HWY_API size_t CompressBitsStore(Vec256<T> v,
                                 const uint8_t* HWY_RESTRICT bits, D d,
                                 T* HWY_RESTRICT unaligned) {
  const Mask256<T> m = LoadMaskBits(d, bits);
  return CompressStore(v, m, d, unaligned);
}

// ------------------------------ Compress
template <typename T>
HWY_API Vec256<T> Compress(const Vec256<T> v, const Mask256<T> mask) {
  const DFromV<decltype(v)> d;
  alignas(32) T lanes[32 / sizeof(T)] = {};
  (void)CompressStore(v, mask, d, lanes);
  return Load(d, lanes);
}

// ------------------------------ CompressNot
template <typename T>
HWY_API Vec256<T> CompressNot(Vec256<T> v, const Mask256<T> mask) {
  return Compress(v, Not(mask));
}

// ------------------------------ CompressBlocksNot
HWY_API Vec256<uint64_t> CompressBlocksNot(Vec256<uint64_t> v,
                                           Mask256<uint64_t> mask) {
  const Full128<uint64_t> dh;
  // Because the non-selected (mask=1) blocks are undefined, we can return the
  // input unless mask = 01, in which case we must bring down the upper block.
  return AllTrue(dh, AndNot(mask.m1, mask.m0)) ? SwapAdjacentBlocks(v) : v;
}

// ------------------------------ CompressBits
template <typename T>
HWY_API Vec256<T> CompressBits(Vec256<T> v, const uint8_t* HWY_RESTRICT bits) {
  const Mask256<T> m = LoadMaskBits(DFromV<decltype(v)>(), bits);
  return Compress(v, m);
}

// ------------------------------ Expand
template <typename T>
HWY_API Vec256<T> Expand(const Vec256<T> v, const Mask256<T> mask) {
  Vec256<T> ret;
  const Full256<T> d;
  const Half<decltype(d)> dh;
  alignas(32) T lanes[32 / sizeof(T)] = {};
  Store(v, d, lanes);
  ret.v0 = Expand(v.v0, mask.m0);
  ret.v1 = Expand(LoadU(dh, lanes + CountTrue(dh, mask.m0)), mask.m1);
  return ret;
}

// ------------------------------ LoadExpand
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
                             const TFromD<D>* HWY_RESTRICT unaligned) {
  return Expand(LoadU(d, unaligned), mask);
}
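// Example (illustrative comment only): the classic filter loop body packs
// selected lanes contiguously via CompressStore:
//
//   const Full256<float> d;
//   alignas(32) float out[8];
//   const Vec256<float> v = Iota(d, -3.0f);           // -3 .. 4
//   const size_t n = CompressStore(v, v > Zero(d), d, out);
//   // out[0..n) == 1, 2, 3, 4 and n == 4.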
// ------------------------------ LoadInterleaved3/4

// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4.

namespace detail {

// Input:
// 1 0 (<- first block of unaligned)
// 3 2
// 5 4
// Output:
// 3 0
// 4 1
// 5 2
template <class D, typename T = TFromD<D>>
HWY_API void LoadTransposedBlocks3(D d, const T* HWY_RESTRICT unaligned,
                                   Vec256<T>& A, Vec256<T>& B, Vec256<T>& C) {
  const Vec256<T> v10 = LoadU(d, unaligned + 0 * MaxLanes(d));
  const Vec256<T> v32 = LoadU(d, unaligned + 1 * MaxLanes(d));
  const Vec256<T> v54 = LoadU(d, unaligned + 2 * MaxLanes(d));

  A = ConcatUpperLower(d, v32, v10);
  B = ConcatLowerUpper(d, v54, v10);
  C = ConcatUpperLower(d, v54, v32);
}

// Input (128-bit blocks):
// 1 0 (first block of unaligned)
// 3 2
// 5 4
// 7 6
// Output:
// 4 0 (LSB of A)
// 5 1
// 6 2
// 7 3
template <class D, typename T = TFromD<D>>
HWY_API void LoadTransposedBlocks4(D d, const T* HWY_RESTRICT unaligned,
                                   Vec256<T>& vA, Vec256<T>& vB,
                                   Vec256<T>& vC, Vec256<T>& vD) {
  const Vec256<T> v10 = LoadU(d, unaligned + 0 * MaxLanes(d));
  const Vec256<T> v32 = LoadU(d, unaligned + 1 * MaxLanes(d));
  const Vec256<T> v54 = LoadU(d, unaligned + 2 * MaxLanes(d));
  const Vec256<T> v76 = LoadU(d, unaligned + 3 * MaxLanes(d));

  vA = ConcatLowerLower(d, v54, v10);
  vB = ConcatUpperUpper(d, v54, v10);
  vC = ConcatLowerLower(d, v76, v32);
  vD = ConcatUpperUpper(d, v76, v32);
}

}  // namespace detail

// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower)

// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4.

namespace detail {

// Input (128-bit blocks):
// 2 0 (LSB of i)
// 3 1
// Output:
// 1 0
// 3 2
template <class D, typename T = TFromD<D>>
HWY_API void StoreTransposedBlocks2(Vec256<T> i, Vec256<T> j, D d,
                                    T* HWY_RESTRICT unaligned) {
  const Vec256<T> out0 = ConcatLowerLower(d, j, i);
  const Vec256<T> out1 = ConcatUpperUpper(d, j, i);
  StoreU(out0, d, unaligned + 0 * MaxLanes(d));
  StoreU(out1, d, unaligned + 1 * MaxLanes(d));
}

// Input (128-bit blocks):
// 3 0 (LSB of i)
// 4 1
// 5 2
// Output:
// 1 0
// 3 2
// 5 4
template <class D, typename T = TFromD<D>>
HWY_API void StoreTransposedBlocks3(Vec256<T> i, Vec256<T> j, Vec256<T> k,
                                    D d, T* HWY_RESTRICT unaligned) {
  const Vec256<T> out0 = ConcatLowerLower(d, j, i);
  const Vec256<T> out1 = ConcatUpperLower(d, i, k);
  const Vec256<T> out2 = ConcatUpperUpper(d, k, j);
  StoreU(out0, d, unaligned + 0 * MaxLanes(d));
  StoreU(out1, d, unaligned + 1 * MaxLanes(d));
  StoreU(out2, d, unaligned + 2 * MaxLanes(d));
}
// Input (128-bit blocks):
// 4 0 (LSB of i)
// 5 1
// 6 2
// 7 3
// Output:
// 1 0
// 3 2
// 5 4
// 7 6
template <class D, typename T = TFromD<D>>
HWY_API void StoreTransposedBlocks4(Vec256<T> i, Vec256<T> j, Vec256<T> k,
                                    Vec256<T> l, D d,
                                    T* HWY_RESTRICT unaligned) {
  // Write lower halves, then upper.
  const Vec256<T> out0 = ConcatLowerLower(d, j, i);
  const Vec256<T> out1 = ConcatLowerLower(d, l, k);
  StoreU(out0, d, unaligned + 0 * MaxLanes(d));
  StoreU(out1, d, unaligned + 1 * MaxLanes(d));
  const Vec256<T> out2 = ConcatUpperUpper(d, j, i);
  const Vec256<T> out3 = ConcatUpperUpper(d, l, k);
  StoreU(out2, d, unaligned + 2 * MaxLanes(d));
  StoreU(out3, d, unaligned + 3 * MaxLanes(d));
}

}  // namespace detail

// ------------------------------ WidenMulPairwiseAdd
template <class D32, typename T16, typename T32 = TFromD<D32>>
HWY_API Vec256<T32> WidenMulPairwiseAdd(D32 d32, Vec256<T16> a,
                                        Vec256<T16> b) {
  const Half<decltype(d32)> d32h;
  // Assign into a new wide vector; reusing `a` would mix lane widths.
  Vec256<T32> ret;
  ret.v0 = WidenMulPairwiseAdd(d32h, a.v0, b.v0);
  ret.v1 = WidenMulPairwiseAdd(d32h, a.v1, b.v1);
  return ret;
}

// ------------------------------ ReorderWidenMulAccumulate
template <class D32, typename T16, typename T32 = TFromD<D32>>
HWY_API Vec256<T32> ReorderWidenMulAccumulate(D32 d32, Vec256<T16> a,
                                              Vec256<T16> b, Vec256<T32> sum0,
                                              Vec256<T32>& sum1) {
  const Half<decltype(d32)> d32h;
  sum0.v0 = ReorderWidenMulAccumulate(d32h, a.v0, b.v0, sum0.v0, sum1.v0);
  sum0.v1 = ReorderWidenMulAccumulate(d32h, a.v1, b.v1, sum0.v1, sum1.v1);
  return sum0;
}

// ------------------------------ RearrangeToOddPlusEven
template <typename TW>
HWY_API Vec256<TW> RearrangeToOddPlusEven(Vec256<TW> sum0, Vec256<TW> sum1) {
  sum0.v0 = RearrangeToOddPlusEven(sum0.v0, sum1.v0);
  sum0.v1 = RearrangeToOddPlusEven(sum0.v1, sum1.v1);
  return sum0;
}

// ------------------------------ Reductions

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> SumOfLanes(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  const Vec128<T> lo = SumOfLanes(dh, Add(v.v0, v.v1));
  return Combine(d, lo, lo);
}
template <class D, typename T = TFromD<D>>
HWY_API T ReduceSum(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  return ReduceSum(dh, Add(v.v0, v.v1));
}
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> MinOfLanes(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  const Vec128<T> lo = MinOfLanes(dh, Min(v.v0, v.v1));
  return Combine(d, lo, lo);
}
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> MaxOfLanes(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  const Vec128<T> lo = MaxOfLanes(dh, Max(v.v0, v.v1));
  return Combine(d, lo, lo);
}

// ------------------------------ Lt128

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Lt128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Lt128(dh, a.v0, b.v0);
  ret.m1 = Lt128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Lt128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Lt128Upper(dh, a.v0, b.v0);
  ret.m1 = Lt128Upper(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Eq128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Eq128(dh, a.v0, b.v0);
  ret.m1 = Eq128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Eq128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Eq128Upper(dh, a.v0, b.v0);
  ret.m1 = Eq128Upper(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Ne128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Ne128(dh, a.v0, b.v0);
  ret.m1 = Ne128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Ne128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Ne128Upper(dh, a.v0, b.v0);
  ret.m1 = Ne128Upper(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Vec256<T> Min128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Min128(dh, a.v0, b.v0);
  ret.v1 = Min128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Vec256<T> Max128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Max128(dh, a.v0, b.v0);
  ret.v1 = Max128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Vec256<T> Min128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Min128Upper(dh, a.v0, b.v0);
  ret.v1 = Min128Upper(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Vec256<T> Max128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Max128Upper(dh, a.v0, b.v0);
  ret.v1 = Max128Upper(dh, a.v1, b.v1);
  return ret;
}
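// Example (illustrative comment only): reductions combine both halves first,
// then reduce 128 bits, e.g. summing all eight i32 lanes:
//
//   const Full256<int32_t> d;
//   const int32_t total = ReduceSum(d, Iota(d, 1));  // 1 + 2 + ... + 8 == 36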
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();