From 6bf0a5cb5034a7e684dcc3500e841785237ce2dd Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 19:32:43 +0200 Subject: Adding upstream version 1:115.7.0. Signed-off-by: Daniel Baumann --- third_party/highway/hwy/ops/wasm_256-inl.h | 2003 ++++++++++++++++++++++++++++ 1 file changed, 2003 insertions(+) create mode 100644 third_party/highway/hwy/ops/wasm_256-inl.h (limited to 'third_party/highway/hwy/ops/wasm_256-inl.h') diff --git a/third_party/highway/hwy/ops/wasm_256-inl.h b/third_party/highway/hwy/ops/wasm_256-inl.h new file mode 100644 index 0000000000..aa62f05e00 --- /dev/null +++ b/third_party/highway/hwy/ops/wasm_256-inl.h @@ -0,0 +1,2003 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit WASM vectors and operations. Experimental. +// External include guard in highway.h - see comment there. + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/wasm_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +class Vec256 { + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Vec128 v0; + Vec128 v1; +}; + +template +struct Mask256 { + Mask128 m0; + Mask128 m1; +}; + +// ------------------------------ BitCast + +template +HWY_API Vec256 BitCast(Full256 d, Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = BitCast(dh, v.v0); + ret.v1 = BitCast(dh, v.v1); + return ret; +} + +// ------------------------------ Zero + +template +HWY_API Vec256 Zero(Full256 d) { + const Half dh; + Vec256 ret; + ret.v0 = ret.v1 = Zero(dh); + return ret; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". 
+template +HWY_API Vec256 Set(Full256 d, const T2 t) { + const Half dh; + Vec256 ret; + ret.v0 = ret.v1 = Set(dh, static_cast(t)); + return ret; +} + +template +HWY_API Vec256 Undefined(Full256 d) { + const Half dh; + Vec256 ret; + ret.v0 = ret.v1 = Undefined(dh); + return ret; +} + +template +Vec256 Iota(const Full256 d, const T2 first) { + const Half dh; + Vec256 ret; + ret.v0 = Iota(dh, first); + // NB: for floating types the gap between parts might be a bit uneven. + ret.v1 = Iota(dh, AddWithWraparound(hwy::IsFloatTag(), + static_cast(first), Lanes(dh))); + return ret; +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec256 operator+(Vec256 a, const Vec256 b) { + a.v0 += b.v0; + a.v1 += b.v1; + return a; +} + +template +HWY_API Vec256 operator-(Vec256 a, const Vec256 b) { + a.v0 -= b.v0; + a.v1 -= b.v1; + return a; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + Vec256 ret; + ret.v0 = SumsOf8(v.v0); + ret.v1 = SumsOf8(v.v1); + return ret; +} + +template +HWY_API Vec256 SaturatedAdd(Vec256 a, const Vec256 b) { + a.v0 = SaturatedAdd(a.v0, b.v0); + a.v1 = SaturatedAdd(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 SaturatedSub(Vec256 a, const Vec256 b) { + a.v0 = SaturatedSub(a.v0, b.v0); + a.v1 = SaturatedSub(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 AverageRound(Vec256 a, const Vec256 b) { + a.v0 = AverageRound(a.v0, b.v0); + a.v1 = AverageRound(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Abs(Vec256 v) { + v.v0 = Abs(v.v0); + v.v1 = Abs(v.v1); + return v; +} + +// ------------------------------ Shift lanes by constant #bits + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + v.v0 = ShiftLeft(v.v0); + v.v1 = ShiftLeft(v.v1); + return v; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + v.v0 = ShiftRight(v.v0); + v.v1 = ShiftRight(v.v1); + return v; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec256 RotateRight(const Vec256 v) { + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +template +HWY_API Vec256 ShiftLeftSame(Vec256 v, const int bits) { + v.v0 = ShiftLeftSame(v.v0, bits); + v.v1 = ShiftLeftSame(v.v1, bits); + return v; +} + +template +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + v.v0 = ShiftRightSame(v.v0, bits); + v.v1 = ShiftRightSame(v.v1, bits); + return v; +} + +// ------------------------------ Min, Max +template +HWY_API Vec256 Min(Vec256 a, const Vec256 b) { + a.v0 = Min(a.v0, b.v0); + a.v1 = Min(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Max(Vec256 a, const Vec256 b) { + a.v0 = Max(a.v0, b.v0); + a.v1 = Max(a.v1, b.v1); + return a; +} +// ------------------------------ Integer multiplication + +template +HWY_API Vec256 operator*(Vec256 a, const Vec256 b) { + a.v0 *= b.v0; + a.v1 *= b.v1; + return a; +} + +template +HWY_API Vec256 MulHigh(Vec256 a, const Vec256 b) { + a.v0 = MulHigh(a.v0, b.v0); + a.v1 = MulHigh(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 MulFixedPoint15(Vec256 a, const Vec256 b) { + a.v0 = MulFixedPoint15(a.v0, b.v0); + a.v1 = MulFixedPoint15(a.v1, b.v1); + return a; +} + +// Cannot use MakeWide because that returns uint128_t for uint64_t, but we want +// uint64_t. 
+HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} +HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} + +HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} +HWY_API Vec256 MulOdd(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulOdd(a.v0, b.v0); + ret.v1 = MulOdd(a.v1, b.v1); + return ret; +} + +// ------------------------------ Negate +template +HWY_API Vec256 Neg(Vec256 v) { + v.v0 = Neg(v.v0); + v.v1 = Neg(v.v1); + return v; +} + +// ------------------------------ Floating-point division +template +HWY_API Vec256 operator/(Vec256 a, const Vec256 b) { + a.v0 /= b.v0; + a.v1 /= b.v1; + return a; +} + +// Approximate reciprocal +HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { + const Vec256 one = Set(Full256(), 1.0f); + return one / v; +} + +// Absolute value of difference. +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfma? + return mul * x + add; +} + +// Returns add - mul * x +HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { + // TODO(eustas): replace, when implemented in WASM. + return add - mul * x; +} + +// Returns mul * x - sub +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfms? + return mul * x - sub; +} + +// Returns -mul * x - sub +HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { + // TODO(eustas): replace, when implemented in WASM. + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec256 Sqrt(Vec256 v) { + v.v0 = Sqrt(v.v0); + v.v1 = Sqrt(v.v1); + return v; +} + +// Approximate reciprocal square root +HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { + // TODO(eustas): find cheaper a way to calculate this. + const Vec256 one = Set(Full256(), 1.0f); + return one / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +HWY_API Vec256 Round(Vec256 v) { + v.v0 = Round(v.v0); + v.v1 = Round(v.v1); + return v; +} + +// Toward zero, aka truncate +HWY_API Vec256 Trunc(Vec256 v) { + v.v0 = Trunc(v.v0); + v.v1 = Trunc(v.v1); + return v; +} + +// Toward +infinity, aka ceiling +HWY_API Vec256 Ceil(Vec256 v) { + v.v0 = Ceil(v.v0); + v.v1 = Ceil(v.v1); + return v; +} + +// Toward -infinity, aka floor +HWY_API Vec256 Floor(Vec256 v) { + v.v0 = Floor(v.v0); + v.v1 = Floor(v.v1); + return v; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask256 IsNaN(const Vec256 v) { + return v != v; +} + +template +HWY_API Mask256 IsInf(const Vec256 v) { + const Full256 d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. 
+template +HWY_API Mask256 IsFinite(const Vec256 v) { + const Full256 d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask256{Mask128{m.m0.raw}, Mask128{m.m1.raw}}; +} + +template +HWY_API Mask256 TestBit(Vec256 v, Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask256 operator==(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator==(a.v0, b.v0); + m.m1 = operator==(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator!=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator!=(a.v0, b.v0); + m.m1 = operator!=(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator<(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator<(a.v0, b.v0); + m.m1 = operator<(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator>(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator>(a.v0, b.v0); + m.m1 = operator>(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator<=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator<=(a.v0, b.v0); + m.m1 = operator<=(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator>=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator>=(a.v0, b.v0); + m.m1 = operator>=(a.v1, b.v1); + return m; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask256 FirstN(const Full256 d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. 
+ return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +} + +// ================================================== LOGICAL + +template +HWY_API Vec256 Not(Vec256 v) { + v.v0 = Not(v.v0); + v.v1 = Not(v.v1); + return v; +} + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + a.v0 = And(a.v0, b.v0); + a.v1 = And(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + not_mask.v0 = AndNot(not_mask.v0, mask.v0); + not_mask.v1 = AndNot(not_mask.v1, mask.v1); + return not_mask; +} + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + a.v0 = Or(a.v0, b.v0); + a.v1 = Or(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + a.v0 = Xor(a.v0, b.v0); + a.v1 = Xor(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Xor3(Vec256 x1, Vec256 x2, Vec256 x3) { + return Xor(x1, Xor(x2, x3)); +} + +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { + return Or(o1, Or(o2, o3)); +} + +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { + return Or(o, And(a1, a2)); +} + +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Full256()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Full256()), sign)); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + Mask256 m; + m.m0 = MaskFromVec(v.v0); + m.m1 = MaskFromVec(v.v1); + return m; +} + +template +HWY_API Vec256 VecFromMask(Full256 d, Mask256 m) { + const Half dh; + Vec256 v; + v.v0 = VecFromMask(dh, m.m0); + v.v1 = VecFromMask(dh, m.m1); + return v; +} + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + yes.v0 = IfThenElse(mask.m0, yes.v0, no.v0); + yes.v1 = IfThenElse(mask.m1, yes.v1, no.v1); + return yes; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(Full256(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(Full256(), mask), no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + v.v0 = IfNegativeThenElse(v.v0, yes.v0, no.v0); + v.v1 = IfNegativeThenElse(v.v1, yes.v1, no.v1); + return v; +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + return IfThenZeroElse(v < Zero(Full256()), v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + return MaskFromVec(Not(VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) +template +HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { + v.v0 = operator<<(v.v0, bits.v0); + v.v1 = operator<<(v.v1, bits.v1); + return v; +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) +template +HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { + v.v0 = operator>>(v.v0, bits.v0); + v.v1 = operator>>(v.v1, bits.v1); + return v; +} + +// ------------------------------ BroadcastSignBit (compare, VecFromMask) + +template +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight(v); +} +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + const Full256 d; + return VecFromMask(d, v < Zero(d)); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec256 Load(Full256 d, const T* HWY_RESTRICT aligned) { + const Half dh; + Vec256 ret; + ret.v0 = Load(dh, aligned); + ret.v1 = Load(dh, aligned + Lanes(dh)); + return ret; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// LoadU == Load. +template +HWY_API Vec256 LoadU(Full256 d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +template +HWY_API Vec256 LoadDup128(Full256 d, const T* HWY_RESTRICT p) { + const Half dh; + Vec256 ret; + ret.v0 = ret.v1 = Load(dh, p); + return ret; +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec256 v, Full256 d, T* HWY_RESTRICT aligned) { + const Half dh; + Store(v.v0, dh, aligned); + Store(v.v1, dh, aligned + Lanes(dh)); +} + +// StoreU == Store. +template +HWY_API void StoreU(Vec256 v, Full256 d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Stream +template +HWY_API void Stream(Vec256 v, Full256 d, T* HWY_RESTRICT aligned) { + // Same as aligned stores. 
+ Store(v, d, aligned); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Index index_lanes[N]; + Store(index, Full256(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec256 GatherOffset(const Full256 d, const T* HWY_RESTRICT base, + const Vec256 offset) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256(), offset_lanes); + + alignas(32) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec256 GatherIndex(const Full256 d, const T* HWY_RESTRICT base, + const Vec256 index) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(32) Index index_lanes[N]; + Store(index, Full256(), index_lanes); + + alignas(32) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, Full256(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane +template +HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { + Full256 d; + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ LowerHalf + +template +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return v.v0; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + return v.v0; +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec256 ShiftLeftBytes(Full256 d, Vec256 v) { + const Half dh; + v.v0 = ShiftLeftBytes(dh, v.v0); + v.v1 = ShiftLeftBytes(dh, v.v1); + return v; +} + +template +HWY_API Vec256 ShiftLeftBytes(Vec256 v) { + return ShiftLeftBytes(Full256(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { + return ShiftLeftLanes(Full256(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec256 
ShiftRightBytes(Full256 d, Vec256 v) { + const Half dh; + v.v0 = ShiftRightBytes(dh, v.v0); + v.v1 = ShiftRightBytes(dh, v.v1); + return v; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +template +HWY_API Vec128 UpperHalf(Full128 /* tag */, const Vec256 v) { + return v.v1; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full256 d, V hi, V lo) { + const Half dh; + hi.v0 = CombineShiftRightBytes(dh, hi.v0, lo.v0); + hi.v1 = CombineShiftRightBytes(dh, hi.v1, lo.v1); + return hi; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + Vec256 ret; + ret.v0 = Broadcast(v.v0); + ret.v1 = Broadcast(v.v1); + return ret; +} + +// ------------------------------ TableLookupBytes + +// Both full +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, Vec256 from) { + from.v0 = TableLookupBytes(bytes.v0, from.v0); + from.v1 = TableLookupBytes(bytes.v1, from.v1); + return from; +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(const Vec256 bytes, + const Vec128 from) { + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(Full256(), Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(Full128(), tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(const Vec128 bytes, + const Vec256 from) { + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(Full256(), Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by wasm_128. + +template +HWY_API VI TableLookupBytesOr0(const V bytes, VI from) { + // wasm out-of-bounds policy already zeros, so TableLookupBytes is fine. 
+ return TableLookupBytes(bytes, from); +} + +// ------------------------------ Hard-coded shuffles + +template +HWY_API Vec256 Shuffle01(Vec256 v) { + v.v0 = Shuffle01(v.v0); + v.v1 = Shuffle01(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle2301(Vec256 v) { + v.v0 = Shuffle2301(v.v0); + v.v1 = Shuffle2301(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle1032(Vec256 v) { + v.v0 = Shuffle1032(v.v0); + v.v1 = Shuffle1032(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle0321(Vec256 v) { + v.v0 = Shuffle0321(v.v0); + v.v1 = Shuffle0321(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle2103(Vec256 v) { + v.v0 = Shuffle2103(v.v0); + v.v1 = Shuffle2103(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle0123(Vec256 v) { + v.v0 = Shuffle0123(v.v0); + v.v1 = Shuffle0123(v.v1); + return v; +} + +// Used by generic_ops-inl.h +namespace detail { + +template +HWY_API Vec256 Shuffle2301(Vec256 a, const Vec256 b) { + a.v0 = Shuffle2301(a.v0, b.v0); + a.v1 = Shuffle2301(a.v1, b.v1); + return a; +} +template +HWY_API Vec256 Shuffle1230(Vec256 a, const Vec256 b) { + a.v0 = Shuffle1230(a.v0, b.v0); + a.v1 = Shuffle1230(a.v1, b.v1); + return a; +} +template +HWY_API Vec256 Shuffle3012(Vec256 a, const Vec256 b) { + a.v0 = Shuffle3012(a.v0, b.v0); + a.v1 = Shuffle3012(a.v1, b.v1); + return a; +} + +} // namespace detail + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices256 { + __v128_u i0; + __v128_u i1; +}; + +template +HWY_API Indices256 IndicesFromVec(Full256 /* tag */, Vec256 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + Indices256 ret; + ret.i0 = vec.v0.raw; + ret.i1 = vec.v1.raw; + return ret; +} + +template +HWY_API Indices256 SetTableIndices(Full256 d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(const Vec256 v, Indices256 idx) { + using TU = MakeUnsigned; + const Full128 dh; + const Full128 duh; + constexpr size_t kLanesPerHalf = 16 / sizeof(TU); + + const Vec128 vi0{idx.i0}; + const Vec128 vi1{idx.i1}; + const Vec128 mask = Set(duh, static_cast(kLanesPerHalf - 1)); + const Vec128 vmod0 = vi0 & mask; + const Vec128 vmod1 = vi1 & mask; + // If ANDing did not change the index, it is for the lower half. + const Mask128 is_lo0 = RebindMask(dh, vi0 == vmod0); + const Mask128 is_lo1 = RebindMask(dh, vi1 == vmod1); + const Indices128 mod0 = IndicesFromVec(dh, vmod0); + const Indices128 mod1 = IndicesFromVec(dh, vmod1); + + Vec256 ret; + ret.v0 = IfThenElse(is_lo0, TableLookupLanes(v.v0, mod0), + TableLookupLanes(v.v1, mod0)); + ret.v1 = IfThenElse(is_lo1, TableLookupLanes(v.v0, mod1), + TableLookupLanes(v.v1, mod1)); + return ret; +} + +template +HWY_API Vec256 TableLookupLanesOr0(Vec256 v, Indices256 idx) { + // The out of bounds behavior will already zero lanes. + return TableLookupLanesOr0(v, idx); +} + +// ------------------------------ Reverse +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v1 = Reverse(dh, v.v0); // note reversed v1 member order + ret.v0 = Reverse(dh, v.v1); + return ret; +} + +// ------------------------------ Reverse2 +template +HWY_API Vec256 Reverse2(Full256 d, Vec256 v) { + const Half dh; + v.v0 = Reverse2(dh, v.v0); + v.v1 = Reverse2(dh, v.v1); + return v; +} + +// ------------------------------ Reverse4 + +// Each block has only 2 lanes, so swap blocks and their lanes. 
+template +HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = Reverse2(dh, v.v1); // swapped + ret.v1 = Reverse2(dh, v.v0); + return ret; +} + +template +HWY_API Vec256 Reverse4(Full256 d, Vec256 v) { + const Half dh; + v.v0 = Reverse4(dh, v.v0); + v.v1 = Reverse4(dh, v.v1); + return v; +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec256 Reverse8(Full256 /* tag */, Vec256 /* v */) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// Each block has only 4 lanes, so swap blocks and their lanes. +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = Reverse4(dh, v.v1); // swapped + ret.v1 = Reverse4(dh, v.v0); + return ret; +} + +template // 1 or 2 bytes +HWY_API Vec256 Reverse8(Full256 d, Vec256 v) { + const Half dh; + v.v0 = Reverse8(dh, v.v0); + v.v1 = Reverse8(dh, v.v1); + return v; +} + +// ------------------------------ InterleaveLower + +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + a.v0 = InterleaveLower(a.v0, b.v0); + a.v1 = InterleaveLower(a.v1, b.v1); + return a; +} + +// wasm_128 already defines a template with D, V, V args. + +// ------------------------------ InterleaveUpper (UpperHalf) + +template > +HWY_API V InterleaveUpper(Full256 d, V a, V b) { + const Half dh; + a.v0 = InterleaveUpper(dh, a.v0, b.v0); + a.v1 = InterleaveUpper(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(Vec256 a, Vec256 b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) +template +HWY_API Vec256 Combine(Full256 /* d */, Vec128 hi, Vec128 lo) { + Vec256 ret; + ret.v1 = hi; + ret.v0 = lo; + return ret; +} + +// ------------------------------ ZeroExtendVector (Combine) +template +HWY_API Vec256 ZeroExtendVector(Full256 d, Vec128 lo) { + const Half dh; + return Combine(d, Zero(dh), lo); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API Vec256 ConcatLowerLower(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v0; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v1; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v0; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatUpperLower +template +HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v1; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatOdd +template +HWY_API Vec256 ConcatOdd(Full256 d, const Vec256 hi, + const Vec256 lo) { + const 
Half dh; + Vec256 ret; + ret.v0 = ConcatOdd(dh, lo.v1, lo.v0); + ret.v1 = ConcatOdd(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ ConcatEven +template +HWY_API Vec256 ConcatEven(Full256 d, const Vec256 hi, + const Vec256 lo) { + const Half dh; + Vec256 ret; + ret.v0 = ConcatEven(dh, lo.v1, lo.v0); + ret.v1 = ConcatEven(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ DupEven +template +HWY_API Vec256 DupEven(Vec256 v) { + v.v0 = DupEven(v.v0); + v.v1 = DupEven(v.v1); + return v; +} + +// ------------------------------ DupOdd +template +HWY_API Vec256 DupOdd(Vec256 v) { + v.v0 = DupOdd(v.v0); + v.v1 = DupOdd(v.v1); + return v; +} + +// ------------------------------ OddEven +template +HWY_API Vec256 OddEven(Vec256 a, const Vec256 b) { + a.v0 = OddEven(a.v0, b.v0); + a.v1 = OddEven(a.v1, b.v1); + return a; +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + odd.v0 = even.v0; + return odd; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + Vec256 ret; + ret.v0 = v.v1; // swapped order + ret.v1 = v.v0; + return ret; +} + +// ------------------------------ ReverseBlocks +template +HWY_API Vec256 ReverseBlocks(Full256 /* tag */, const Vec256 v) { + return SwapAdjacentBlocks(v); // 2 blocks, so Swap = Reverse +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +namespace detail { + +// Unsigned: zero-extend. +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{ + wasm_u32x4_extend_high_u16x8(wasm_u16x8_extend_high_u8x16(v.raw))}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{ + wasm_u32x4_extend_high_u16x8(wasm_u16x8_extend_high_u8x16(v.raw))}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u32x4_extend_high_u16x8(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u64x2_extend_high_u32x4(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u32x4_extend_high_u16x8(v.raw)}; +} + +// Signed: replicate sign bit. +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_extend_high_i8x16(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{ + wasm_i32x4_extend_high_i16x8(wasm_i16x8_extend_high_i8x16(v.raw))}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_extend_high_i16x8(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_i64x2_extend_high_i32x4(v.raw)}; +} + +HWY_API Vec128 PromoteUpperTo(Full128 dd, + const Vec128 v) { + // There is no wasm_f64x2_convert_high_i32x4. + const Full64 di32h; + return PromoteTo(dd, UpperHalf(di32h, v)); +} + +HWY_API Vec128 PromoteUpperTo(Full128 df32, + const Vec128 v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. 
+ const auto bits16 = PromoteUpperTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +HWY_API Vec128 PromoteUpperTo(Full128 df32, + const Vec128 v) { + const Full128 du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +} // namespace detail + +template +HWY_API Vec256 PromoteTo(Full256 d, const Vec128 v) { + const Half dh; + Vec256 ret; + ret.v0 = PromoteTo(dh, LowerHalf(v)); + ret.v1 = detail::PromoteUpperTo(dh, v); + return ret; +} + +// This is the only 4x promotion from 8 to 32-bit. +template +HWY_API Vec256 PromoteTo(Full256 d, const Vec64 v) { + const Half dh; + const Rebind, decltype(d)> d2; // 16-bit lanes + const auto v16 = PromoteTo(d2, v); + Vec256 ret; + ret.v0 = PromoteTo(dh, LowerHalf(v16)); + ret.v1 = detail::PromoteUpperTo(dh, v16); + return ret; +} + +// ------------------------------ DemoteTo + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 di, const Vec256 v) { + const Vec64 lo{wasm_i32x4_trunc_sat_f64x2_zero(v.v0.raw)}; + const Vec64 hi{wasm_i32x4_trunc_sat_f64x2_zero(v.v1.raw)}; + return Combine(di, hi, lo); +} + +HWY_API Vec128 DemoteTo(Full128 d16, + const Vec256 v) { + const Half d16h; + const Vec64 lo = DemoteTo(d16h, v.v0); + const Vec64 hi = DemoteTo(d16h, v.v1); + return Combine(d16, hi, lo); +} + +HWY_API Vec128 DemoteTo(Full128 dbf16, + const Vec256 v) { + const Half dbf16h; + const Vec64 lo = DemoteTo(dbf16h, v.v0); + const Vec64 hi = DemoteTo(dbf16h, v.v1); + return Combine(dbf16, hi, lo); +} + +// For already range-limited input [0, 255]. 
+HWY_API Vec64 U8FromU32(const Vec256 v) { + const Full64 du8; + const Full256 di32; // no unsigned DemoteTo + return DemoteTo(du8, BitCast(di32, v)); +} + +// ------------------------------ Truncations + +HWY_API Vec32 TruncateTo(Full32 /* tag */, + const Vec256 v) { + return Vec32{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 8, 16, 24, 0, + 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, + 24)}; +} + +HWY_API Vec64 TruncateTo(Full64 /* tag */, + const Vec256 v) { + return Vec64{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 8, 9, 16, + 17, 24, 25, 0, 1, 8, 9, 16, 17, 24, + 25)}; +} + +HWY_API Vec128 TruncateTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 2, 3, 8, + 9, 10, 11, 16, 17, 18, 19, 24, 25, + 26, 27)}; +} + +HWY_API Vec64 TruncateTo(Full64 /* tag */, + const Vec256 v) { + return Vec64{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 4, 8, 12, 16, + 20, 24, 28, 0, 4, 8, 12, 16, 20, 24, + 28)}; +} + +HWY_API Vec128 TruncateTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 4, 5, 8, + 9, 12, 13, 16, 17, 20, 21, 24, 25, + 28, 29)}; +} + +HWY_API Vec128 TruncateTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 2, 4, 6, 8, + 10, 12, 14, 16, 18, 20, 22, 24, 26, + 28, 30)}; +} + +// ------------------------------ ReorderDemote2To +HWY_API Vec256 ReorderDemote2To(Full256 dbf16, + Vec256 a, Vec256 b) { + const RebindToUnsigned du16; + return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a))); +} + +HWY_API Vec256 ReorderDemote2To(Full256 d16, + Vec256 a, Vec256 b) { + const Half d16h; + Vec256 demoted; + demoted.v0 = DemoteTo(d16h, a); + demoted.v1 = DemoteTo(d16h, b); + return demoted; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template +HWY_API Vec256 ConvertTo(Full256 d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = ConvertTo(dh, v.v0); + ret.v1 = ConvertTo(dh, v.v1); + return ret; +} + +HWY_API Vec256 NearestInt(const Vec256 v) { + return ConvertTo(Full256(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template // 4 or 8 bytes +HWY_API Mask256 LoadMaskBits(Full256 d, + const uint8_t* HWY_RESTRICT bits) { + const Half dh; + Mask256 ret; + ret.m0 = LoadMaskBits(dh, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(T); + const uint8_t bits_upper[8] = {static_cast(bits[0] >> kBitsPerHalf)}; + ret.m1 = LoadMaskBits(dh, bits_upper); + return ret; +} + +template // 1 or 2 bytes +HWY_API Mask256 LoadMaskBits(Full256 d, + const uint8_t* HWY_RESTRICT bits) { + const Half dh; + Mask256 ret; + ret.m0 = LoadMaskBits(dh, bits); + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + ret.m1 = LoadMaskBits(dh, bits + kBytesPerHalf); + return ret; +} + +// ------------------------------ Mask + +// `p` points to at least 8 writable bytes. 
+template // 4 or 8 bytes +HWY_API size_t StoreMaskBits(const Full256 d, const Mask256 mask, + uint8_t* bits) { + const Half dh; + StoreMaskBits(dh, mask.m0, bits); + const uint8_t lo = bits[0]; + StoreMaskBits(dh, mask.m1, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(T); + bits[0] = static_cast(lo | (bits[0] << kBitsPerHalf)); + return (kBitsPerHalf * 2 + 7) / 8; +} + +template // 1 or 2 bytes +HWY_API size_t StoreMaskBits(const Full256 d, const Mask256 mask, + uint8_t* bits) { + const Half dh; + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + StoreMaskBits(dh, mask.m0, bits); + StoreMaskBits(dh, mask.m1, bits + kBytesPerHalf); + return kBytesPerHalf * 2; +} + +template +HWY_API size_t CountTrue(const Full256 d, const Mask256 m) { + const Half dh; + return CountTrue(dh, m.m0) + CountTrue(dh, m.m1); +} + +template +HWY_API bool AllFalse(const Full256 d, const Mask256 m) { + const Half dh; + return AllFalse(dh, m.m0) && AllFalse(dh, m.m1); +} + +template +HWY_API bool AllTrue(const Full256 d, const Mask256 m) { + const Half dh; + return AllTrue(dh, m.m0) && AllTrue(dh, m.m1); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full256 d, const Mask256 mask) { + const Half dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); // not known + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + return lo >= 0 ? static_cast(lo) + : kLanesPerHalf + FindKnownFirstTrue(dh, mask.m1); +} + +template +HWY_API intptr_t FindFirstTrue(const Full256 d, const Mask256 mask) { + const Half dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); + const intptr_t hi = FindFirstTrue(dh, mask.m1); + if (lo < 0 && hi < 0) return lo; + constexpr int kLanesPerHalf = 16 / sizeof(T); + return lo >= 0 ? 
lo : hi + kLanesPerHalf; +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const Vec256 v, const Mask256 mask, + Full256 d, T* HWY_RESTRICT unaligned) { + const Half dh; + const size_t count = CompressStore(v.v0, mask.m0, dh, unaligned); + const size_t count2 = CompressStore(v.v1, mask.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const Vec256 v, const Mask256 m, + Full256 d, T* HWY_RESTRICT unaligned) { + const Half dh; + const size_t count = CompressBlendedStore(v.v0, m.m0, dh, unaligned); + const size_t count2 = CompressBlendedStore(v.v1, m.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(const Vec256 v, + const uint8_t* HWY_RESTRICT bits, Full256 d, + T* HWY_RESTRICT unaligned) { + const Mask256 m = LoadMaskBits(d, bits); + return CompressStore(v, m, d, unaligned); +} + +// ------------------------------ Compress + +template +HWY_API Vec256 Compress(const Vec256 v, const Mask256 mask) { + const Full256 d; + alignas(32) T lanes[32 / sizeof(T)] = {}; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +// ------------------------------ CompressNot +template +HWY_API Vec256 CompressNot(Vec256 v, const Mask256 mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + const Full128 dh; + // Because the non-selected (mask=1) blocks are undefined, we can return the + // input unless mask = 01, in which case we must bring down the upper block. + return AllTrue(dh, AndNot(mask.m1, mask.m0)) ? SwapAdjacentBlocks(v) : v; +} + +// ------------------------------ CompressBits + +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + const Mask256 m = LoadMaskBits(Full256(), bits); + return Compress(v, m); +} + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. 
+ +namespace detail { + +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template +HWY_API void LoadTransposedBlocks3(Full256 d, + const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C) { + constexpr size_t N = 32 / sizeof(T); + const Vec256 v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const Vec256 v32 = LoadU(d, unaligned + 1 * N); + const Vec256 v54 = LoadU(d, unaligned + 2 * N); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of A) +// 5 1 +// 6 2 +// 7 3 +template +HWY_API void LoadTransposedBlocks4(Full256 d, + const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C, + Vec256& D) { + constexpr size_t N = 32 / sizeof(T); + const Vec256 v10 = LoadU(d, unaligned + 0 * N); + const Vec256 v32 = LoadU(d, unaligned + 1 * N); + const Vec256 v54 = LoadU(d, unaligned + 2 * N); + const Vec256 v76 = LoadU(d, unaligned + 3 * N); + + A = ConcatLowerLower(d, v54, v10); + B = ConcatUpperUpper(d, v54, v10); + C = ConcatLowerLower(d, v76, v32); + D = ConcatUpperUpper(d, v76, v32); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template +HWY_API void StoreTransposedBlocks2(const Vec256 i, const Vec256 j, + const Full256 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template +HWY_API void StoreTransposedBlocks3(const Vec256 i, const Vec256 j, + const Vec256 k, Full256 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template +HWY_API void StoreTransposedBlocks4(const Vec256 i, const Vec256 j, + const Vec256 k, const Vec256 l, + Full256 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + // Write lower halves, then upper. 
+ const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ ReorderWidenMulAccumulate +template +HWY_API Vec256 ReorderWidenMulAccumulate(Full256 d, Vec256 a, + Vec256 b, Vec256 sum0, + Vec256& sum1) { + const Half dh; + sum0.v0 = ReorderWidenMulAccumulate(dh, a.v0, b.v0, sum0.v0, sum1.v0); + sum0.v1 = ReorderWidenMulAccumulate(dh, a.v1, b.v1, sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec256 RearrangeToOddPlusEven(Vec256 sum0, Vec256 sum1) { + sum0.v0 = RearrangeToOddPlusEven(sum0.v0, sum1.v0); + sum0.v1 = RearrangeToOddPlusEven(sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ Reductions + +template +HWY_API Vec256 SumOfLanes(Full256 d, const Vec256 v) { + const Half dh; + const Vec128 lo = SumOfLanes(dh, Add(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +template +HWY_API Vec256 MinOfLanes(Full256 d, const Vec256 v) { + const Half dh; + const Vec128 lo = MinOfLanes(dh, Min(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +template +HWY_API Vec256 MaxOfLanes(Full256 d, const Vec256 v) { + const Half dh; + const Vec128 lo = MaxOfLanes(dh, Max(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +// ------------------------------ Lt128 + +template +HWY_INLINE Mask256 Lt128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Lt128(dh, a.v0, b.v0); + ret.m1 = Lt128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Lt128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Lt128Upper(dh, a.v0, b.v0); + ret.m1 = Lt128Upper(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Eq128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Eq128(dh, a.v0, b.v0); + ret.m1 = Eq128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Eq128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Eq128Upper(dh, a.v0, b.v0); + ret.m1 = Eq128Upper(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Ne128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Ne128(dh, a.v0, b.v0); + ret.m1 = Ne128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Ne128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Ne128Upper(dh, a.v0, b.v0); + ret.m1 = Ne128Upper(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Vec256 Min128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Min128(dh, a.v0, b.v0); + ret.v1 = Min128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Vec256 Max128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Max128(dh, a.v0, b.v0); + ret.v1 = Max128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Vec256 Min128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Min128Upper(dh, a.v0, b.v0); + ret.v1 = Min128Upper(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Vec256 Max128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Max128Upper(dh, a.v0, b.v0); + ret.v1 = Max128Upper(dh, a.v1, b.v1); + return ret; +} + 
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
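The header above follows one pattern throughout: a 256-bit vector is emulated as a pair of 128-bit halves (v0/v1), and nearly every op simply forwards to the half-width implementation on each half. Below is a minimal standalone sketch of that pattern, separate from the patch itself; the names DemoVec128/DemoVec256 and the Add helper are illustrative stand-ins, not Highway's actual API.

// Sketch of the "pair of halves" emulation pattern used by wasm_256-inl.h.
// DemoVec128 stands in for a native 128-bit register type; DemoVec256 holds
// two such halves and forwards each operation to them independently.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

template <typename T>
struct DemoVec128 {  // stand-in for a native 128-bit register
  std::array<T, 16 / sizeof(T)> lanes;
};

template <typename T>
DemoVec128<T> Add(DemoVec128<T> a, const DemoVec128<T>& b) {
  for (size_t i = 0; i < a.lanes.size(); ++i) a.lanes[i] += b.lanes[i];
  return a;
}

template <typename T>
struct DemoVec256 {   // emulated 256-bit vector
  DemoVec128<T> v0;   // lower half
  DemoVec128<T> v1;   // upper half
};

// The forwarding pattern: apply the 128-bit op to each half independently.
template <typename T>
DemoVec256<T> Add(DemoVec256<T> a, const DemoVec256<T>& b) {
  a.v0 = Add(a.v0, b.v0);
  a.v1 = Add(a.v1, b.v1);
  return a;
}

int main() {
  DemoVec256<int32_t> a{{{0, 1, 2, 3}}, {{4, 5, 6, 7}}};
  DemoVec256<int32_t> b{{{10, 10, 10, 10}}, {{10, 10, 10, 10}}};
  const DemoVec256<int32_t> sum = Add(a, b);
  std::printf("%d %d\n", sum.v0.lanes[0], sum.v1.lanes[3]);  // prints: 10 17
  return 0;
}

Operations that cross the half boundary (e.g. ConcatLowerUpper, SwapAdjacentBlocks, Reverse) are the exceptions in the header: they shuffle or swap the v0/v1 members rather than forwarding lane-wise.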