From 6bf0a5cb5034a7e684dcc3500e841785237ce2dd Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 19:32:43 +0200 Subject: Adding upstream version 1:115.7.0. Signed-off-by: Daniel Baumann --- third_party/highway/hwy/ops/x86_128-inl.h | 7432 +++++++++++++++++++++++++++++ 1 file changed, 7432 insertions(+) create mode 100644 third_party/highway/hwy/ops/x86_128-inl.h (limited to 'third_party/highway/hwy/ops/x86_128-inl.h') diff --git a/third_party/highway/hwy/ops/x86_128-inl.h b/third_party/highway/hwy/ops/x86_128-inl.h new file mode 100644 index 0000000000..ba8d581984 --- /dev/null +++ b/third_party/highway/hwy/ops/x86_128-inl.h @@ -0,0 +1,7432 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit vectors and SSE4 instructions, plus some AVX2 and AVX512-VL +// operations when compiling for those targets. +// External include guard in highway.h - see comment there. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_GCC_ACTUAL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's emmintrin.h - see +// https://github.com/google/highway/issues/710 and pull/902 +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +#include +#include +#if HWY_TARGET == HWY_SSSE3 +#include // SSSE3 +#else +#include // SSE4 +#include // CLMUL +#endif +#include +#include +#include // memcpy + +#include "hwy/ops/shared-inl.h" + +#if HWY_IS_MSAN +#include +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw128 { + using type = __m128i; +}; +template <> +struct Raw128 { + using type = __m128; +}; +template <> +struct Raw128 { + using type = __m128d; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +#if HWY_TARGET <= HWY_AVX3 + +namespace detail { + +// Template arg: sizeof(lane type) +template +struct RawMask128 {}; +template <> +struct RawMask128<1> { + using type = __mmask16; +}; +template <> +struct RawMask128<2> { + using type = __mmask8; +}; +template <> +struct RawMask128<4> { + using type = __mmask8; +}; +template <> +struct RawMask128<8> { + using type = __mmask8; +}; + +} // namespace detail + +template +struct Mask128 { + using Raw = typename detail::RawMask128::type; + + static Mask128 FromBits(uint64_t mask_bits) { + return Mask128{static_cast(mask_bits)}; + } + + Raw raw; +}; + +#else // AVX2 or below + +// FF..FF or 0. +template +struct Mask128 { + typename detail::Raw128::type raw; +}; + +#endif // HWY_TARGET <= HWY_AVX3 + +template +using DFromV = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m128i BitCastToInteger(__m128i v) { return v; } +HWY_INLINE __m128i BitCastToInteger(__m128 v) { return _mm_castps_si128(v); } +HWY_INLINE __m128i BitCastToInteger(__m128d v) { return _mm_castpd_si128(v); } + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __m128i operator()(__m128i v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128 operator()(__m128i v) { return _mm_castsi128_ps(v); } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128d operator()(__m128i v) { return _mm_castsi128_pd(v); } +}; + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128{BitCastFromInteger128()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_si128()}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_ps()}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_pd()}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". 
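// [Editorial usage sketch, not part of the upstream file.] Zero, BitCast and
// the tag types above are typically combined like this (assuming this header
// is reached via hwy/highway.h, inside HWY_NAMESPACE):
//
//   const Full128<float> df;
//   const RebindToUnsigned<decltype(df)> du;  // u32 lanes of the same width
//   const auto bits = BitCast(du, Zero(df));  // Vec128<uint32_t>, all zero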
+template +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { + return Vec128{_mm_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint16_t t) { + return Vec128{_mm_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint32_t t) { + return Vec128{_mm_set1_epi32(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint64_t t) { + return Vec128{ + _mm_set1_epi64x(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { + return Vec128{_mm_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { + return Vec128{_mm_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { + return Vec128{_mm_set1_epi32(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { + return Vec128{ + _mm_set1_epi64x(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const float t) { + return Vec128{_mm_set1_ps(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const double t) { + return Vec128{_mm_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec128 Undefined(Simd /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec128{_mm_undefined_si128()}; +} +template +HWY_API Vec128 Undefined(Simd /* tag */) { + return Vec128{_mm_undefined_ps()}; +} +template +HWY_API Vec128 Undefined(Simd /* tag */) { + return Vec128{_mm_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ GetLane + +// Gets the single value stored in a vector/part. +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFF); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFFFF); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw)); +} +template +HWY_API float GetLane(const Vec128 v) { + return _mm_cvtss_f32(v.raw); +} +template +HWY_API uint64_t GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + alignas(16) uint64_t lanes[2]; + Store(v, Simd(), lanes); + return lanes[0]; +#else + return static_cast(_mm_cvtsi128_si64(v.raw)); +#endif +} +template +HWY_API int64_t GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + alignas(16) int64_t lanes[2]; + Store(v, Simd(), lanes); + return lanes[0]; +#else + return _mm_cvtsi128_si64(v.raw); +#endif +} +template +HWY_API double GetLane(const Vec128 v) { + return _mm_cvtsd_f64(v.raw); +} + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{_mm_and_si128(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_and_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
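// [Editorial usage sketch, not part of the upstream file.] Set splats a
// scalar and GetLane reads lane 0 back; And operates on whole registers
// (assuming hwy/highway.h is included, inside HWY_NAMESPACE):
//
//   const Full128<int32_t> d;
//   const auto v = Set(d, 7);               // 7 in all four lanes
//   const int32_t first = GetLane(v);       // 7
//   const auto masked = And(v, Set(d, 3));  // 3 in all lanes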
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{_mm_andnot_si128(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return Vec128{_mm_andnot_ps(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return Vec128{_mm_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{_mm_or_si128(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Or(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_or_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Or(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{_mm_xor_si128(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Xor(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_xor_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Xor(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not +template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_TARGET <= HWY_AVX3 + const __m128i vu = BitCast(du, v).raw; + return BitCast(d, VU{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(d, VU{_mm_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Xor3 +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast( + d, VU{_mm_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require 
BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Neg(hwy::FloatTag /*tag*/, const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_INLINE Vec128 Neg(hwy::NonFloatTag /*tag*/, const Vec128 v) { + return Zero(DFromV()) - v; +} + +} // namespace detail + +template +HWY_INLINE Vec128 Neg(const Vec128 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (reaches breakpoint) + const auto zero = Zero(DFromV()); + return Vec128{_mm_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec128{_mm_abs_epi8(v.raw)}; +#endif +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi16(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi32(v.raw)}; +} +// i64 is implemented after BroadcastSignBit. +template +HWY_API Vec128 Abs(const Vec128 v) { + const Vec128 mask{_mm_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(DFromV(), mask); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + const Vec128 mask{_mm_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(DFromV(), mask); +} + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m128i out = _mm_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, VFromD{out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. 
+ return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(DFromV()), sign)); +#endif +} + +// ================================================== MASK + +namespace detail { + +template +HWY_INLINE void MaybeUnpoison(T* HWY_RESTRICT unaligned, size_t count) { + // Workaround for MSAN not marking compressstore as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#else + (void)unaligned; + (void)count; +#endif +} + +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_ps(mask.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
+ return Vec128{_mm_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +// ------------------------------ Mask logical + +// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently. +#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) +#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC_ACTUAL >= 700 || \ + HWY_COMPILER_CLANG >= 800 +#define HWY_COMPILER_HAS_MASK_INTRINSICS 1 +#else +#define HWY_COMPILER_HAS_MASK_INTRINSICS 0 +#endif +#endif // HWY_COMPILER_HAS_MASK_INTRINSICS + +namespace detail { + +template +HWY_INLINE Mask128 And(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template 
+HWY_INLINE Mask128 Or(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0x3)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0x3)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + return detail::Or(hwy::SizeTag(), a, 
b); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Not(const Mask128 m) { + // Flip only the valid bits. + // TODO(janwas): use _knot intrinsics if N >= 8. + return Xor(m, Mask128::FromBits((1ull << N) - 1)); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +#else // AVX2 or below + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(const Simd /* tag */, + const Mask128 v) { + return Vec128{v.raw}; +} + +#if HWY_TARGET == HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const auto vmask = VecFromMask(DFromV(), mask); + return Or(And(vmask, yes), AndNot(vmask, no)); +} + +#else // HWY_TARGET == HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, + const Vec128 no) { + return Vec128{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, + const Vec128 no) { + return Vec128{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. 
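// [Editorial usage sketch, not part of the upstream file.] The shift counts
// above are template arguments, i.e. compile-time constants (assuming
// hwy/highway.h is included, inside HWY_NAMESPACE):
//
//   const Full128<uint32_t> d;
//   const auto v = Set(d, 0x0F0Fu);
//   const auto l = ShiftLeft<4>(v);   // 0xF0F0 in every lane
//   const auto r = ShiftRight<8>(v);  // 0x000F in every lane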
+ +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + return Vec128{_mm_shuffle_epi8(bytes.raw, from.raw)}; +} + +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; x86 anyway zeroes if >= 0x80. +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. +namespace detail { + +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {1, 0, 7, 6}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0302, 0x0100, 0x0f0e, 0x0d0c}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0, 3, 6, 5}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0100, 0x0706, 0x0d0c, 0x0b0a}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {2, 1, 4, 7}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0504, 0x0302, 0x0908, 0x0f0e}; + return Vec128{TableLookupBytes(ba, 
Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 1)}; +} + +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. 
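// [Editorial usage sketch, not part of the upstream file.] The Shuffle*
// helpers above permute the 32-bit lanes of a full vector. Writing lanes
// most-significant first (3,2,1,0), and assuming hwy/highway.h is included:
//
//   const Full128<uint32_t> d;
//   const auto v = Iota(d, 0);        // lanes 3,2,1,0
//   const auto s = Shuffle1032(v);    // 1,0,3,2 (64-bit halves swapped)
//   const auto r = Shuffle0123(v);    // 0,1,2,3 (reversed)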
+ +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<1> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<2> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<4> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<8> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Signed/float < +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpgt_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return 
Mask128{_mm_cmpgt_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<1> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<2> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<4> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<8> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi8(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi16(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi32(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi64(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ps(_mm_movm_epi32(v.raw))}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_pd(_mm_movm_epi64(v.raw))}; +} + +template +HWY_API Vec128 VecFromMask(Simd /* tag */, + const Mask128 v) { + return VecFromMask(v); +} + +#else // AVX2 or below + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
+ +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + const Simd d; + return MaskFromVec(BitCast(Simd(), VecFromMask(d, m))); +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +#else + return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; +#endif +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + // Same as signed ==; avoid duplicating the SSSE3 version. + const DFromV d; + RebindToUnsigned du; + return RebindMask(d, BitCast(du, a) == BitCast(du, b)); +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. 
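// [Editorial usage sketch, not part of the upstream file.] On this target,
// comparisons yield all-ones/all-zeros lanes, which combine naturally with
// IfThenElse (assuming hwy/highway.h is included, inside HWY_NAMESPACE):
//
//   const Full128<int32_t> d;
//   const auto a = Iota(d, 0);                     // lanes 3,2,1,0
//   const auto eq2 = (a == Set(d, 2));             // true only in lane 2
//   const auto sel = IfThenElse(eq2, a, Zero(d));  // lanes 0,2,0,0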
+template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpneq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpneq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +namespace detail { + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi8(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // See https://stackoverflow.com/questions/65166174/: + const Simd d; + const RepartitionToNarrow d32; + const Vec128 m_eq32{Eq(BitCast(d32, a), BitCast(d32, b)).raw}; + const Vec128 m_gt32{Gt(BitCast(d32, a), BitCast(d32, b)).raw}; + // If a.upper is greater, upper := true. Otherwise, if a.upper == b.upper: + // upper := b-a (unsigned comparison result of lower). Otherwise: upper := 0. + const __m128i upper = OrAnd(m_gt32, m_eq32, Sub(b, a)).raw; + // Duplicate upper to lower half. 
+ return Mask128{_mm_shuffle_epi32(upper, _MM_SHUFFLE(3, 3, 1, 1))}; +#else + return Mask128{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 +#endif +} + +template +HWY_INLINE Mask128 Gt(hwy::UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); + const auto sa = BitCast(di, Xor(a, msb)); + const auto sb = BitCast(di, Xor(b, msb)); + return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +} + +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpge_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpge_pd(a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask128 FirstN(const Simd d, size_t num) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + const uint64_t all = (1ull << N) - 1; + // BZHI only looks at the lower 8 bits of num! + const uint64_t bits = (num > 255) ? all : _bzhi_u64(all, num); + return Mask128::FromBits(bits); +#else + const RebindToSigned di; // Signed comparisons are cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +#endif +} + +template +using MFromD = decltype(FirstN(D(), 0)); + +// ================================================== MEMORY (1) + +// Clang static analysis claims the memory immediately after a partial vector +// store is uninitialized, and also flags the input to partial loads (at least +// for loadl_pd) as "garbage". This is a false alarm because msan does not +// raise errors. We work around this by using CopyBytes instead of intrinsics, +// but only for the analyzer to avoid potentially bad code generation. +// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. 
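// [Editorial usage sketch, not part of the upstream file.] FirstN above is
// the usual way to build a mask for a partial final loop iteration (assuming
// hwy/highway.h is included, inside HWY_NAMESPACE):
//
//   const Full128<uint8_t> d;                      // 16 lanes
//   const auto m = FirstN(d, 3);                   // lanes 0..2 true
//   const auto v = IfThenElseZero(m, Set(d, uint8_t{1}));
//   // v is 1 in lanes 0..2 and 0 elsewhere.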
+#ifndef HWY_SAFE_PARTIAL_LOAD_STORE +#if defined(__clang_analyzer__) || \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +#else +#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#endif +#endif // HWY_SAFE_PARTIAL_LOAD_STORE + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Full128 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{_mm_load_si128(reinterpret_cast(aligned))}; +} +HWY_API Vec128 Load(Full128 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ps(aligned)}; +} +HWY_API Vec128 Load(Full128 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec128{_mm_load_pd(aligned)}; +} + +template +HWY_API Vec128 LoadU(Full128 /* tag */, const T* HWY_RESTRICT p) { + return Vec128{_mm_loadu_si128(reinterpret_cast(p))}; +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const float* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ps(p)}; +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const double* HWY_RESTRICT p) { + return Vec128{_mm_loadu_pd(p)}; +} + +template +HWY_API Vec64 Load(Full64 /* tag */, const T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = _mm_setzero_si128(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_loadl_epi64(reinterpret_cast(p))}; +#endif +} + +HWY_API Vec128 Load(Full64 /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<8>(p, &v); // not same size + return Vec128{v}; +#else + const __m128 hi = _mm_setzero_ps(); + return Vec128{_mm_loadl_pi(hi, reinterpret_cast(p))}; +#endif +} + +HWY_API Vec64 Load(Full64 /* tag */, + const double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128d v = _mm_setzero_pd(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_load_sd(p)}; +#endif +} + +HWY_API Vec128 Load(Full32 /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<4>(p, &v); // not same size + return Vec128{v}; +#else + return Vec128{_mm_load_ss(p)}; +#endif +} + +// Any <= 32 bit except +template +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { + constexpr size_t kSize = sizeof(T) * N; +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes(p, &v); // not same size + return Vec128{v}; +#else + int32_t bits = 0; + CopyBytes(p, &bits); // not same size + return Vec128{_mm_cvtsi32_si128(bits)}; +#endif +} + +// For < 128 bit, LoadU == Load. +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// Returns a vector with lane i=[0, N) set to "first" + i. 
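// [Editorial usage sketch, not part of the upstream file.] Load requires
// 16-byte alignment for full vectors; LoadU does not (assuming hwy/highway.h
// is included, inside HWY_NAMESPACE):
//
//   HWY_ALIGN float in[4] = {1.0f, 2.0f, 3.0f, 4.0f};
//   const Full128<float> d;
//   const auto a = Load(d, in);    // aligned load
//   const auto u = LoadU(d, in);   // unaligned load also permitted here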
+template +HWY_API Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi16(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, + Simd /* tag */, + const float* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, + Simd /* tag */, + const double* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_pd(m.raw, p)}; +} + +#elif HWY_TARGET == HWY_AVX2 + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return Vec128{_mm_maskload_epi32(p_p, m.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return Vec128{_mm_maskload_epi64(p_p, m.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const float* HWY_RESTRICT p) { + const Vec128 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec128{_mm_maskload_ps(p, mi.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const double* HWY_RESTRICT p) { + const Vec128 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec128{_mm_maskload_pd(p, mi.raw)}; +} + +// There is no maskload_epi8/16, so blend instead. +template // 1 or 2 bytes +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, Load(d, p)); +} + +#else // <= SSE4 + +// Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). 
+template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, Load(d, p)); +} + +#endif + +// ------------------------------ Store + +template +HWY_API void Store(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT aligned) { + _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +HWY_API void Store(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT aligned) { + _mm_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT aligned) { + _mm_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT p) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT p) { + _mm_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT p) { + _mm_storeu_pd(p, v.raw); +} + +template +HWY_API void Store(Vec64 v, Full64 /* tag */, T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_epi64(reinterpret_cast<__m128i*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec128 v, Full64 /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec64 v, Full64 /* tag */, + double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pd(p, v.raw); +#endif +} + +// Any <= 32 bit except +template +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { + CopyBytes(&v, p); // not same size +} +HWY_API void Store(const Vec128 v, Full32 /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<4>(&v, p); // not same size +#else + _mm_store_ss(p, v.raw); +#endif +} + +// For < 128 bit, StoreU == Store. +template +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ BlendedStore + +namespace detail { + +// There is no maskload_epi8/16 with which we could safely implement +// BlendedStore. Manual blending is also unsafe because loading a full vector +// that crosses the array end causes asan faults. Resort to scalar code; the +// caller should instead use memcpy, assuming m is FirstN(d, n). +template +HWY_API void ScalarMaskedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + const RebindToSigned di; // for testing mask if T=bfloat16_t. 
+ using TI = TFromD; + alignas(16) TI buf[N]; + alignas(16) TI mask[N]; + Store(BitCast(di, v), di, buf); + Store(BitCast(di, VecFromMask(d, m)), di, mask); + for (size_t i = 0; i < N; ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + _mm_mask_storeu_epi8(p, m.raw, v.raw); +} +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + _mm_mask_storeu_epi16(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd, float* HWY_RESTRICT p) { + _mm_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd, double* HWY_RESTRICT p) { + _mm_mask_storeu_pd(p, m.raw, v.raw); +} + +#elif HWY_TARGET == HWY_AVX2 + +template // 1 or 2 bytes +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + detail::ScalarMaskedStore(v, m, d, p); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 4) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 2) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd d, float* HWY_RESTRICT p) { + using T = float; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 4) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + const Vec128, N> mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm_maskstore_ps(p, mi.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd d, double* HWY_RESTRICT p) { + using T = double; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 2) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + const Vec128, N> mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm_maskstore_pd(p, mi.raw, v.raw); +} + +#else // <= SSE4 + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + // Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). 
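+  // Fall back to the scalar loop above, which writes only the selected lanes
+  // and goes through the normal cache hierarchy.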
+ detail::ScalarMaskedStore(v, m, d, p); +} + +#endif // SSE4 + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128{_mm_sad_epu8(v.raw, _mm_setzero_si128())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. 
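+// For example, uint8_t lanes holding 200 and 100 produce 255 (not the
+// wrapped-around 44), and int8_t lanes holding 100 and 100 produce 127.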
+ +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epu16(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epi16(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epu32(a.raw, b.raw)}; +} + +#if HWY_TARGET == HWY_SSSE3 + +template // N=1 or 2 +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Set(Simd(), + static_cast(GetLane(a)) * GetLane(b)); +} +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) int32_t a_lanes[4]; + alignas(16) int32_t b_lanes[4]; + const Full128 di32; + Store(a, di32, a_lanes); + Store(b, di32, b_lanes); + alignas(16) int64_t mul[2]; + mul[0] = static_cast(a_lanes[0]) * b_lanes[0]; + mul[1] = static_cast(a_lanes[2]) * b_lanes[2]; + return Load(Full128(), mul); +} + +#else // HWY_TARGET == HWY_SSSE3 + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epi32(a.raw, b.raw)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // Not as inefficient as it looks: _mm_mullo_epi32 has 10 cycle latency. + // 64-bit right shift would also work but also needs port 5, so no benefit. + // Notation: x=don't care, z=0. 
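+  // The first MulEven forms the full 64-bit products of lanes 0 and 2. The
+  // (3, 3, 1, 1) shuffles move the odd lanes of a and b into even positions so
+  // that a second MulEven yields the products of lanes 1 and 3. The final
+  // shuffles keep only the low 32 bits of each product, and the unpack
+  // restores the original lane order.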
+ const __m128i a_x3x1 = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x2x0 = MulEven(a, b); + const __m128i b_x3x1 = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x3x1 = + MulEven(Vec128{a_x3x1}, Vec128{b_x3x1}); + // We could _mm_slli_epi64 by 32 to get 3z1z and OR with z2z0, but generating + // the latter requires one more instruction or a constant. + const __m128i mul_20 = + _mm_shuffle_epi32(mullo_x2x0.raw, _MM_SHUFFLE(2, 0, 2, 0)); + const __m128i mul_31 = + _mm_shuffle_epi32(mullo_x3x1.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(mul_20, mul_31)}; +#else + return Vec128{_mm_mullo_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + // Same as unsigned; avoid duplicating the SSSE3 code. + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +// ------------------------------ RotateRight (ShiftRight, Or) + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<15>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<31>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Vec128{_mm_srai_epi64(v.raw, 63)}; +#elif HWY_TARGET == HWY_AVX2 || HWY_TARGET == HWY_SSE4 + return VecFromMask(v < Zero(d)); +#else + // Efficient Lt() requires SSE4.2 and BLENDVPD requires SSE4.1. 32-bit shift + // avoids generating a zero. 
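+  // ShiftRight<31> broadcasts the sign bit within each 32-bit half; the
+  // (3, 3, 1, 1) shuffle then copies the sign of the upper half into both
+  // halves of each 64-bit lane.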
+ const RepartitionToNarrow d32; + const auto sign = ShiftRight<31>(BitCast(d32, v)); + return Vec128{ + _mm_shuffle_epi32(sign.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +#endif +} + +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_abs_epi64(v.raw)}; +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srai_epi64(v.raw, kBits)}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +// ------------------------------ ZeroIfNegative (BroadcastSignBit) +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + static_assert(IsFloat(), "Only works for float"); + const DFromV d; +#if HWY_TARGET == HWY_SSSE3 + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#else + const auto mask = MaskFromVec(v); // MSB is sufficient for BLENDVPS +#endif + return IfThenElse(mask, Zero(d), v); +} + +// ------------------------------ IfNegativeThenElse +template +HWY_API Vec128 IfNegativeThenElse(const Vec128 v, + const Vec128 yes, + const Vec128 no) { + // int8: IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToFloat df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. + return BitCast(d, IfThenElse(MaskFromVec(BitCast(df, v)), BitCast(df, yes), + BitCast(df, no))); +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
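+  // There is no 8-bit shift instruction, so shift 16-bit lanes and then clear
+  // the low `bits` bits of each byte, which received bits from the byte below.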
+ const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator*(const Vec64 a, const Vec64 b) { + return Vec64{_mm_mul_sd(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator/(const Vec64 a, const Vec64 b) { + return Vec64{_mm_div_sd(a.raw, b.raw)}; +} + +// Approximate reciprocal +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ss(v.raw)}; +} + +// Absolute value of difference. 
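+// For example, lanes holding 2.0f and 5.0f give 3.0f in either argument order.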
+template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128{_mm_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128{_mm_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128{_mm_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128{_mm_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128{_mm_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128{_mm_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_ps(v.raw)}; +} +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_ss(v.raw)}; +} +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_pd(v.raw)}; +} +HWY_API Vec64 Sqrt(const Vec64 v) { + return Vec64{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; +} + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128{_mm_rsqrt_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128{_mm_rsqrt_ss(v.raw)}; +} + +// ------------------------------ Min (Gt, IfThenElse) + +namespace detail { + +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MinU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epu8(a.raw, b.raw)}; +} +template 
+HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128{_mm_min_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128{_mm_min_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epu64(a.raw, b.raw)}; +#else + return detail::MinU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +namespace detail { +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MaxU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128{_mm_max_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128{_mm_max_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epu64(a.raw, b.raw)}; +#else + return detail::MaxU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return 
Vec128{_mm_max_pd(a.raw, b.raw)}; +} + +// ================================================== MEMORY (2) + +// ------------------------------ Non-temporal stores + +// On clang6, we see incorrect code generated for _mm_stream_pi, so +// round even partial vectors up to 16 bytes. +template +HWY_API void Stream(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + _mm_stream_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +template +HWY_API void Stream(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT aligned) { + _mm_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(const Vec128 v, Simd /* tag */, + double* HWY_RESTRICT aligned) { + _mm_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unfortunately the GCC/Clang intrinsics do not accept int64_t*. +using GatherIndex64 = long long int; // NOLINT(runtime/int) +static_assert(sizeof(GatherIndex64) == 8, "Must be 64-bit type"); + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 4) { + _mm_i32scatter_epi32(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 index) { + if (N == 4) { + _mm_i32scatter_epi32(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, index.raw, v.raw, 4); + } +} + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 2) { + _mm_i64scatter_epi64(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 index) { + if (N == 2) { + _mm_i64scatter_epi64(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, index.raw, v.raw, 8); + } +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +template +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, + float* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 4) { + _mm_i32scatter_ps(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_ps(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, + float* HWY_RESTRICT base, + const Vec128 index) { + if (N == 4) { + _mm_i32scatter_ps(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + 
_mm_mask_i32scatter_ps(base, mask, index.raw, v.raw, 4); + } +} + +template +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, + double* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 2) { + _mm_i64scatter_pd(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, + double* HWY_RESTRICT base, + const Vec128 index) { + if (N == 2) { + _mm_i64scatter_pd(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, index.raw, v.raw, 8); + } +} +#else // HWY_TARGET <= HWY_AVX3 + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather (Load/Store) + +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +#else + +namespace detail { + +template +HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<4> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<4> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), index.raw, 4)}; +} + +template +HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<8> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<8> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), index.raw, 8)}; +} + +} // namespace detail + +template +HWY_API Vec128 GatherOffset(Simd d, const T* HWY_RESTRICT base, + const 
Vec128 offset) { + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec128 GatherIndex(Simd d, const T* HWY_RESTRICT base, + const Vec128 index) { + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, + const float* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i32gather_ps(base, offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const float* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i32gather_ps(base, index.raw, 4)}; +} + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, + const double* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i64gather_pd(base, offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const double* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i64gather_pd(base, index.raw, 8)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +// Returns upper/lower half of a vector. +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return LowerHalf(Simd(), v); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec128{_mm_slli_si128(v.raw, kBytes)}; +} + +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // For partial vectors, clear upper lanes so we shift in zeros. + if (N != 16 / sizeof(T)) { + const Vec128 vfull{v.raw}; + v = Vec128{IfThenElseZero(FirstN(Full128(), N), vfull).raw}; + } + return Vec128{_mm_srli_si128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). 
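+// The overloads below duplicate the upper 64 bits into the lower half (via
+// unpackhi/movehl with the same vector as both operands) and return only that
+// lower half.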
+template +HWY_API Vec64 UpperHalf(Half> /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_epi64(v.raw, v.raw)}; +} +HWY_API Vec128 UpperHalf(Full64 /* tag */, Vec128 v) { + return Vec128{_mm_movehl_ps(v.raw, v.raw)}; +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_pd(v.raw, v.raw)}; +} + +// Partial +template +HWY_API Vec128 UpperHalf(Half> /* tag */, + Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); + return Vec128{upper.raw}; +} + +// ------------------------------ ExtractLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const int pair = _mm_extract_epi16(v.raw, kLane / 2); + constexpr int kShift = kLane & 1 ? 8 : 0; + return static_cast((pair >> kShift) & 0xFF); +#else + return static_cast(_mm_extract_epi8(v.raw, kLane) & 0xFF); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + return static_cast(_mm_extract_epi16(v.raw, kLane) & 0xFFFF); +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + return static_cast(_mm_extract_epi32(v.raw, kLane)); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 || HWY_ARCH_X86_32 + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + return static_cast(_mm_extract_epi64(v.raw, kLane)); +#endif +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) float lanes[4]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + // Bug in the intrinsic, returns int but should be float. + const int32_t bits = _mm_extract_ps(v.raw, kLane); + float ret; + CopySameSize(&bits, &ret); + return ret; +#endif +} + +// There is no extract_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE double ExtractLane(const Vec128 v) { + static_assert(kLane == 0, "Lane index out of bounds"); + return GetLane(v); +} + +template +HWY_INLINE double ExtractLane(const Vec128 v) { + static_assert(kLane < 2, "Lane index out of bounds"); + const Half> dh; + return kLane == 0 ? GetLane(v) : GetLane(UpperHalf(dh, v)); +} + +} // namespace detail + +// Requires one overload per vector length because ExtractLane<3> may be a +// compile error if it calls _mm_extract_epi64. 
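+// If the index is a compile-time constant (detected via __builtin_constant_p
+// on GCC/Clang), dispatch to the immediate-operand intrinsic; otherwise store
+// the vector to a stack array and index that.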
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + return Vec128{_mm_insert_epi8(v.raw, t, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{_mm_insert_epi16(v.raw, t, kLane)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) T lanes[4]; + const DFromV d; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + MakeSigned ti; + CopySameSize(&t, &ti); // don't just cast because T might be float. 
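+  // CopySameSize copies the bit pattern, so a float lane value reaches the
+  // integer intrinsic with its bits unchanged rather than value-converted.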
+ return Vec128{_mm_insert_epi32(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 || HWY_ARCH_X86_32 + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + MakeSigned ti; + CopySameSize(&t, &ti); // don't just cast because T might be float. + return Vec128{_mm_insert_epi64(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + alignas(16) float lanes[4]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + return Vec128{_mm_insert_ps(v.raw, _mm_set_ss(t), kLane << 4)}; +#endif +} + +// There is no insert_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane == 0, "Lane index out of bounds"); + return Set(DFromV(), t); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + const DFromV d; + const Vec128 vt = Set(d, t); + if (kLane == 0) { + return Vec128{_mm_shuffle_pd(vt.raw, v.raw, 2)}; + } + return Vec128{_mm_shuffle_pd(v.raw, vt.raw, 0)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls _mm_insert_epi64. + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return 
detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full128 d, V hi, V lo) { + const Repartition d8; + return BitCast(d, Vec128{_mm_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +template > +HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full128 d_full8; + using V8 = VFromD; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(d_full8, hi8, lo8); + return V{BitCast(Full128(), r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128{_mm_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128{_mm_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 
0xEE : 0x44)}; +} + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 3 * kLane)}; +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template +struct Indices128 { + __m128i raw; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, N)))); +#endif + +#if HWY_TARGET <= HWY_AVX2 + (void)d; + return Indices128{vec.raw}; +#else + const Repartition d8; + using V8 = VFromD; + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + + // Broadcast each lane index to all 4 bytes of T + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif +} + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(N))))); +#else + (void)d; +#endif + + // No change - even without AVX3, we can shuffle+blend. + return Indices128{vec.raw}; +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + const DFromV d; + const RebindToFloat df; + const Vec128 perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)}; + return BitCast(d, perm); +#else + return TableLookupBytes(v, Vec128{idx.raw}); +#endif +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + return Vec128{_mm_permutevar_ps(v.raw, idx.raw)}; +#else + const DFromV df; + const RebindToSigned di; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +#endif +} + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 /* idx */) { + return v; +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const Full128 d; + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + // There is no _mm_permute[x]var_epi64. + vidx += vidx; // bit1 is the decider (unusual) + const Full128 df; + return BitCast( + d, Vec128{_mm_permutevar_pd(BitCast(df, v).raw, vidx.raw)}); +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. 
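+  // For example, vidx == {0, 1} (identity) makes both lanes of `same` below
+  // all-ones, so v is returned unchanged; vidx == {1, 0} makes them zero,
+  // which selects the swapped Shuffle01(v).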
+ const Full128 di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + vidx += vidx; // bit1 is the decider (unusual) + return Vec128{_mm_permutevar_pd(v.raw, vidx.raw)}; +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const Full128 d; + const Full128 di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return v; +} + +// Two lanes: shuffle +template +HWY_API Vec128 Reverse(Full64 /* tag */, const Vec128 v) { + return Vec128{Shuffle2301(Vec128{v.raw}).raw}; +} + +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + if (N == 1) return v; + if (N == 2) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); + } + const RebindToSigned di; + alignas(16) constexpr int16_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + const Vec128 idx = Load(di, kReverse + (N == 8 ? 0 : 4)); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +#endif +} + +// ------------------------------ Reverse2 + +// Single lane: no change +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return v; +} + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + alignas(16) const T kShuffle[16] = {1, 0, 3, 2, 5, 4, 7, 6, + 9, 8, 11, 10, 13, 12, 15, 14}; + return TableLookupBytes(v, Load(d, kShuffle)); +} + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + const RebindToSigned di; + // 4x 16-bit: a single shufflelo suffices. 
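+  // _MM_SHUFFLE(0, 1, 2, 3) selects lanes 3, 2, 1, 0 of the low 64 bits, i.e.
+  // it reverses the four 16-bit lanes; the upper half is irrelevant because a
+  // vector with N == 4 only occupies the lower 64 bits.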
+ if (N == 4) { + return BitCast(d, Vec128{_mm_shufflelo_epi16( + BitCast(di, v).raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET <= HWY_AVX3 + alignas(16) constexpr int16_t kReverse4[8] = {3, 2, 1, 0, 7, 6, 5, 4}; + const Vec128 idx = Load(di, kReverse4); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +// 4x 32-bit: use Shuffle0123 +template +HWY_API Vec128 Reverse4(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, Vec128 /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec128 idx = Load(di, kReverse8); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec128 Reverse8(Simd /* tag */, Vec128 /* v */) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; +} + +// Additional overload for the optional tag (also for 256/512). +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. 
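+// For example, with uint32_t lanes a = {a0, a1, a2, a3} and b = {b0, b1, b2,
+// b3} (lane 0 listed first), InterleaveLower returns {a0, b0, a1, b1} and
+// InterleaveUpper returns {a2, b2, a3, b3}. The public overloads that take a
+// d tag follow the namespace below.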
+namespace detail { + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +// Full +template > +HWY_API V InterleaveUpper(Full128 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template > +HWY_API V InterleaveUpper(Simd d, V a, V b) { + const Half d2; + return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template +HWY_API Vec128 Combine(Simd d, Vec128 hi_half, + Vec128 lo_half) { + const Half d2; + const RebindToUnsigned du2; + // Treat half-width input as one lane, and expand to two lanes. 
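+ // With each half viewed as a single unsigned lane, one InterleaveLower
+ // places lo_half in the lower half of the result and hi_half in the upper.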
+ using VU = Vec128, 2>; + const VU lo{BitCast(du2, lo_half).raw}; + const VU hi{BitCast(du2, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 ZeroExtendVector(hwy::NonFloatTag /*tag*/, + Full128 /* d */, Vec64 lo) { + return Vec128{_mm_move_epi64(lo.raw)}; +} + +template +HWY_INLINE Vec128 ZeroExtendVector(hwy::FloatTag /*tag*/, Full128 d, + Vec64 lo) { + const RebindToUnsigned du; + return BitCast(d, ZeroExtendVector(du, BitCast(Half(), lo))); +} + +} // namespace detail + +template +HWY_API Vec128 ZeroExtendVector(Full128 d, Vec64 lo) { + return detail::ZeroExtendVector(hwy::IsFloatTag(), d, lo); +} + +template +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { + return IfThenElseZero(FirstN(d, N / 2), Vec128{lo.raw}); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec128 ConcatLowerLower(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec128 ConcatUpperUpper(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_API Vec128 ConcatLowerUpper(Full128 d, const Vec128 hi, + const Vec128 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec128 ConcatUpperLower(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dd; +#if HWY_TARGET == HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd(BitCast(dd, lo).raw, BitCast(dd, hi).raw, + _MM_SHUFFLE2(1, 0))}); +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _pd can do 3/cycle. + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +HWY_API Vec128 ConcatUpperLower(Full128 d, Vec128 hi, + Vec128 lo) { +#if HWY_TARGET == HWY_SSSE3 + (void)d; + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 2, 1, 0))}; +#else + // _mm_shuffle_ps has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + const RepartitionToWide dd; + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +HWY_API Vec128 ConcatUpperLower(Full128 /* tag */, + Vec128 hi, Vec128 lo) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_shuffle_pd(lo.raw, hi.raw, _MM_SHUFFLE2(1, 0))}; +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. 
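+ // Blend immediate bit i selects lane i from the second operand, so bit 0 = 1
+ // takes the lower lane from lo and keeps the upper lane of hi.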
+ return Vec128{_mm_blend_pd(hi.raw, lo.raw, 1)}; +#endif +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API Vec128 ConcatLowerLower(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatUpperUpper(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); + return Vec128{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API Vec64 ConcatOdd(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[8] = {1, 3, 5, 7}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactOddU8)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 8-bit x4 +template +HWY_API Vec32 ConcatOdd(Simd d, Vec32 hi, Vec32 lo) { + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[4] = {1, 3}; + const Vec32 shuf = BitCast(d, Load(Full32(), kCompactOddU8)); + const Vec32 L = TableLookupBytes(lo, shuf); + const Vec32 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +} + +// 16-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns + // 0xFFFF8000, which correctly saturates to 0x8000. + const Repartition dw; + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); + return Vec128{_mm_packs_epi32(uL.raw, uH.raw)}; +} + +// 16-bit x4 +template +HWY_API Vec64 ConcatOdd(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. 
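+ // Byte offsets {2,3, 6,7} are the odd u16 lanes (1 and 3) of each input.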
+ alignas(16) const uint8_t kCompactOddU16[8] = {2, 3, 6, 7}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactOddU16)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 32-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(3, 1, 3, 1))}); +} +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, + Vec128 lo) { + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; +} + +// Any type x2 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec128 mask = Set(dw, 0x00FF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return Vec128{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API Vec64 ConcatEven(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[8] = {0, 2, 4, 6}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactEvenU8)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 8-bit x4 +template +HWY_API Vec32 ConcatEven(Simd d, Vec32 hi, Vec32 lo) { + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[4] = {0, 2}; + const Vec32 shuf = BitCast(d, Load(Full32(), kCompactEvenU8)); + const Vec32 L = TableLookupBytes(lo, shuf); + const Vec32 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +} + +// 16-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { +#if HWY_TARGET <= HWY_SSE4 + // Isolate lower 16 bits per u32 so we can pack. + const Repartition dw; + const Vec128 mask = Set(dw, 0x0000FFFF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return Vec128{_mm_packus_epi32(uL.raw, uH.raw)}; +#else + // packs_epi32 saturates 0x8000 to 0x7FFF. Instead ConcatEven within the two + // inputs, then concatenate them. + alignas(16) const T kCompactEvenU16[8] = {0x0100, 0x0504, 0x0908, 0x0D0C}; + const Vec128 shuf = BitCast(d, Load(d, kCompactEvenU16)); + const Vec128 L = TableLookupBytes(lo, shuf); + const Vec128 H = TableLookupBytes(hi, shuf); + return ConcatLowerLower(d, H, L); +#endif +} + +// 16-bit x4 +template +HWY_API Vec64 ConcatEven(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. 
+ alignas(16) const uint8_t kCompactEvenU16[8] = {0, 1, 4, 5}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactEvenU16)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 32-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, + Vec128 lo) { + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; +} + +// Any T x2 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + const Repartition d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, + 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +#else + return Vec128{_mm_blend_epi16(a.raw, b.raw, 0x55)}; +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i odd = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128i even = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(even, odd)}; +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _ps can do 3/cycle. + const DFromV d; + const RebindToFloat df; + return BitCast(d, Vec128{_mm_blend_ps(BitCast(df, a).raw, + BitCast(df, b).raw, 5)}); +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + // Same as ConcatUpperLower for full vectors; do not call that because this + // is more efficient for 64x1 vectors. + const DFromV d; + const RebindToFloat dd; +#if HWY_TARGET == HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd( + BitCast(dd, b).raw, BitCast(dd, a).raw, _MM_SHUFFLE2(1, 0))}); +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. 
+ return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, a).raw, + BitCast(dd, b).raw, 1)}); +#endif +} + +template +HWY_API Vec128 OddEven(Vec128 a, Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // SHUFPS must fill the lower half of the output from one input, so we + // need another shuffle. Unpack avoids another immediate byte. + const __m128 odd = _mm_shuffle_ps(a.raw, a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128 even = _mm_shuffle_ps(b.raw, b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_ps(even, odd)}; +#else + return Vec128{_mm_blend_ps(a.raw, b.raw, 5)}; +#endif +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ Shl (ZipLower, Mul) + +// Use AVX2/3 variable shifts where available, otherwise multiply by powers of +// two from loading float exponents, which is considerably faster (according +// to LLVM-MCA) than scalar or testing bits: https://gcc.godbolt.org/z/9G7Y9v. + +namespace detail { +#if HWY_TARGET > HWY_AVX3 // AVX2 or older + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const RepartitionToWide dw; + const Rebind df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // See comment below. + const Vec128 bits0{_mm_cvtps_epi32(BitCast(df, f0).raw)}; + const Vec128 bits1{_mm_cvtps_epi32(BitCast(df, f1).raw)}; + return Vec128, N>{_mm_packus_epi32(bits0.raw, bits1.raw)}; +} + +// Same, for 32-bit shifts. +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const auto exp = ShiftLeft<23>(v); + const auto f = exp + Set(d, 0x3F800000); // 1.0f + // Do not use ConvertTo because we rely on the native 0x80..00 overflow + // behavior. cvt instead of cvtt should be equivalent, but avoids test + // failure under GCC 10.2.1. 
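+ // For example, v = 5: (5 << 23) + 0x3F800000 = 0x42000000 = 32.0f, and
+ // cvtps_epi32 returns 32 = 2^5. For v = 31, 2^31 exceeds the int32 range and
+ // the conversion returns 0x80000000, which is exactly 2^31 as a u32 lane.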
+ return Vec128, N>{_mm_cvtps_epi32(_mm_castsi128_ps(f.raw))}; +} + +#endif // HWY_TARGET > HWY_AVX3 + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_sllv_epi16(v.raw, bits.raw)}; +#else + return v * Pow2(bits); +#endif +} +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { + return Vec128{_mm_sll_epi16(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return v * Pow2(bits); +#else + return Vec128{_mm_sllv_epi32(v.raw, bits.raw)}; +#endif +} +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sll_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const Vec128 out0{_mm_sll_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_sll_epi64(v.raw, bits1)}; + return ConcatUpperLower(Full128(), out1, out0); +#else + return Vec128{_mm_sllv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 Shl(hwy::UnsignedTag /*tag*/, Vec64 v, + Vec64 bits) { + return Vec64{_mm_sll_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec128 Shl(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (mul, mask, BroadcastSignBit) + +// Use AVX2+ variable shifts except for SSSE3/SSE4 or 16-bit. There, we use +// widening multiplication by powers of two obtained by loading float exponents, +// followed by a constant right-shift. This is still faster than a scalar or +// bit-test approach: https://gcc.godbolt.org/z/9G7Y9v. + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srlv_epi16(in.raw, bits.raw)}; +#else + const Simd d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + const auto out = MulHigh(in, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), in, out); +#endif +} +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { + return Vec128{_mm_srl_epi16(in.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // 32x32 -> 64 bit mul, then shift right by 32. + const Simd d32; + // Move odd lanes into position for the second mul. Shuffle more gracefully + // handles N=1 than repartitioning to u64 and shifting 32 bits right. + const Vec128 in31{_mm_shuffle_epi32(in.raw, 0x31)}; + // For bits=0, we cannot mul by 2^32, so fix the result later. + const auto mul = detail::Pow2(Set(d32, 32) - bits); + const auto out20 = ShiftRight<32>(MulEven(in, mul)); // z 2 z 0 + const Vec128 mul31{_mm_shuffle_epi32(mul.raw, 0x31)}; + // No need to shift right, already in the correct position. + const auto out31 = BitCast(d32, MulEven(in31, mul31)); // 3 ? 1 ? 
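+ // OddEven takes the odd lanes (3, 1) from out31 and the even lanes (2, 0)
+ // from out20, so every lane ends up with its own shifted result.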
+ const Vec128 out = OddEven(out31, BitCast(d32, out20)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d32), in, out); +#else + return Vec128{_mm_srlv_epi32(in.raw, bits.raw)}; +#endif +} +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { + return Vec128{_mm_srl_epi32(in.raw, bits.raw)}; +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const Vec128 out0{_mm_srl_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_srl_epi64(v.raw, bits1)}; + return ConcatUpperLower(Full128(), out1, out0); +#else + return Vec128{_mm_srlv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64{_mm_srl_epi64(v.raw, bits.raw)}; +} + +#if HWY_TARGET > HWY_AVX3 // AVX2 or older +namespace detail { + +// Also used in x86_256-inl.h. +template +HWY_INLINE V SignedShr(const DI di, const V v, const V count_i) { + const RebindToUnsigned du; + const auto count = BitCast(du, count_i); // same type as value to shift + // Clear sign and restore afterwards. This is preferable to shifting the MSB + // downwards because Shr is somewhat more expensive than Shl. + const auto sign = BroadcastSignBit(v); + const auto abs = BitCast(du, v ^ sign); // off by one, but fixed below + return BitCast(di, abs >> count) ^ sign; +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sra_epi16(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi32(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sra_epi32(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +} + +HWY_INLINE Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + const Half> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template > +HWY_API V ReorderWidenMulAccumulate(Simd df32, VFromD a, + VFromD b, const V sum0, V& sum1) { + // TODO(janwas): _mm_dpbf16_ps when available + const RebindToUnsigned du32; + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo. Using shift/and instead of Zip + // leads to the odd/even order that RearrangeToOddPlusEven prefers. 
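+ // A bfloat16 is the upper half of a binary32, so the even (lower-half) bf16
+ // lanes become valid f32 via ShiftLeft<16>, and the odd (upper-half) lanes
+ // only need their low 16 bits cleared.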
+ using VU32 = VFromD; + const VU32 odd = Set(du32, 0xFFFF0000u); + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. +template +HWY_API Vec128 ReorderWidenMulAccumulate( + Simd /*d32*/, Vec128 a, + Vec128 b, const Vec128 sum0, + Vec128& /*sum1*/) { + return sum0 + Vec128{_mm_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + return Vec128{_mm_unpacklo_epi8(v.raw, zero)}; +#else + return Vec128{_mm_cvtepu8_epi16(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; +#else + return Vec128{_mm_cvtepu16_epi32(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; +#else + return Vec128{_mm_cvtepu32_epi64(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + const __m128i u16 = _mm_unpacklo_epi8(v.raw, zero); + return Vec128{_mm_unpacklo_epi16(u16, zero)}; +#else + return Vec128{_mm_cvtepu8_epi32(v.raw)}; +#endif +} + +// Unsigned to signed: same plus cast. +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} + +// Signed: replicate sign bit. 
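+// On SSSE3, unpacking v with itself duplicates each lane into both halves of a
+// double-width lane; an arithmetic ShiftRight by the original lane width then
+// fills the upper half with copies of the sign bit, i.e. sign extension.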
+template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<8>(Vec128{_mm_unpacklo_epi8(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi8_epi16(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<16>(Vec128{_mm_unpacklo_epi16(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi16_epi32(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<32>(Vec128{_mm_unpacklo_epi32(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi32_epi64(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); + const __m128i x4 = _mm_unpacklo_epi16(x2, x2); + return ShiftRight<24>(Vec128{x4}); +#else + return Vec128{_mm_cvtepi8_epi32(v.raw)}; +#endif +} + +// Workaround for origin tracking bug in Clang msan prior to 11.0 +// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") +#if HWY_IS_MSAN && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) +#define HWY_INLINE_F16 HWY_NOINLINE +#else +#define HWY_INLINE_F16 HWY_INLINE +#endif +template +HWY_INLINE_F16 Vec128 PromoteTo(Simd df32, + const Vec128 v) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +#else + (void)df32; + return Vec128{_mm_cvtph_ps(v.raw)}; +#endif +} + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtps_pd(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_pd(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const Simd di32; + const Simd du16; + const auto zero_if_neg = AndNot(ShiftRight<31>(v), v); + const auto too_big = VecFromMask(di32, Gt(v, Set(di32, 0xFFFF))); + const auto clamped = Or(zero_if_neg, too_big); + // Lower 2 bytes from each 32-bit lane; same as return type for fewer casts. 
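+ // PSHUFB zeroes any destination byte whose index has the high bit set, so
+ // the 0x8080 entries below clear the upper half of the result.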
+ alignas(16) constexpr uint16_t kLower2Bytes[16] = { + 0x0100, 0x0504, 0x0908, 0x0D0C, 0x8080, 0x8080, 0x8080, 0x8080}; + const auto lo2 = Load(du16, kLower2Bytes); + return Vec128{TableLookupBytes(BitCast(du16, clamped), lo2).raw}; +#else + return Vec128{_mm_packus_epi32(v.raw, v.raw)}; +#endif +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packs_epi32(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128{_mm_packus_epi16(i16, i16)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packus_epi16(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128{_mm_packs_epi16(i16, i16)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packs_epi16(v.raw, v.raw)}; +} + +// Work around MSVC warning for _mm_cvtps_ph (8 is actually a valid immediate). +// clang-cl requires a non-empty string, so we 'ignore' the irrelevant -Wmain. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wmain") + +template +HWY_API Vec128 DemoteTo(Simd df16, + const Vec128 v) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return BitCast(df16, DemoteTo(du16, bits16)); +#else + (void)df16; + return Vec128{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +#endif +} + +HWY_DIAGNOSTICS(pop) + +template +HWY_API Vec128 DemoteTo(Simd dbf16, + const Vec128 v) { + // TODO(janwas): _mm_cvtneps_pbh once we have avx512bf16. + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + // TODO(janwas): _mm_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned du16; + const Repartition du32; + const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +// Specializations for partial vectors because packs_epi32 sets lanes above 2*N. +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. 
+ const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Full128 /*d16*/, + Vec128 a, Vec128 b) { + return Vec128{_mm_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtpd_ps(v.raw)}; +} + +namespace detail { + +// For well-defined float->int demotion in all x86_*-inl.h. + +template +HWY_INLINE auto ClampF64ToI32Max(Simd d, decltype(Zero(d)) v) + -> decltype(Zero(d)) { + // The max can be exactly represented in binary64, so clamping beforehand + // prevents x86 conversion from raising an exception and returning 80..00. + return Min(v, Set(d, 2147483647.0)); +} + +// For ConvertTo float->int of same size, clamping before conversion would +// change the result because the max integer value is not exactly representable. +// Instead detect the overflow result after conversion and fix it. +template > +HWY_INLINE auto FixConversionOverflow(DI di, VFromD original, + decltype(Zero(di).raw) converted_raw) + -> VFromD { + // Combinations of original and output sign: + // --: normal <0 or -huge_val to 80..00: OK + // -+: -0 to 0 : OK + // +-: +huge_val to 80..00 : xor with FF..FF to get 7F..FF + // ++: normal >0 : OK + const auto converted = VFromD{converted_raw}; + const auto sign_wrong = AndNot(BitCast(di, original), converted); +#if HWY_COMPILER_GCC_ACTUAL + // Critical GCC 11 compiler bug (possibly also GCC 10): omits the Xor; also + // Add() if using that instead. Work around with one more instruction. + const RebindToUnsigned du; + const VFromD mask = BroadcastSignBit(sign_wrong); + const VFromD max = BitCast(di, ShiftRight<1>(BitCast(du, mask))); + return IfVecThenElse(mask, max, converted); +#else + return Xor(converted, BroadcastSignBit(sign_wrong)); +#endif +} + +} // namespace detail + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto clamped = detail::ClampF64ToI32Max(Simd(), v); + return Vec128{_mm_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const Simd d32; + const Simd d8; + alignas(16) static constexpr uint32_t k8From32[4] = { + 0x0C080400u, 0x0C080400u, 0x0C080400u, 0x0C080400u}; + // Also replicate bytes into all 32 bit lanes for safety. 
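+ // Each byte of 0x0C080400 is a PSHUFB index (0, 4, 8, 12) selecting byte 0
+ // of one u32 lane; repeating the constant writes the same packed 4 bytes
+ // into every 32-bit lane of quad.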
+ const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + return LowerHalf(LowerHalf(BitCast(d8, quad))); +} + +// ------------------------------ Truncations + +template * = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + static_assert(!IsSigned() && !IsSigned(), "Unsigned only"); + const Repartition> d; + const auto v1 = BitCast(d, v); + return Vec128{v1.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d8; + alignas(16) static constexpr uint8_t kMap[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + return LowerHalf(LowerHalf(LowerHalf(TableLookupBytes(v, Load(d8, kMap))))); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d16; + alignas(16) static constexpr uint16_t kMap[8] = { + 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d16, kMap)))); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x88)}; +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + alignas(16) static constexpr uint8_t kMap[16] = { + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu, + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d, kMap)))); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_ps(v.raw)}; +} + +template +HWY_API Vec128 ConvertTo(HWY_MAYBE_UNUSED Simd df, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_cvtepu32_ps(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned du32; + const RebindToSigned d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +#endif +} + +template +HWY_API Vec128 ConvertTo(Simd dd, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dd; + return Vec128{_mm_cvtepi64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! 
+#endif +} + +template +HWY_API Vec128 ConvertTo(HWY_MAYBE_UNUSED Simd dd, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_cvtepu64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFF); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest/highest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + auto uint64_to_double128_fast = [&dd](VU w) HWY_ATTR { + w = Or(w, VU{detail::BitCastToInteger(Set(dd, 0x0010000000000000).raw)}); + return BitCast(dd, w) - Set(dd, 0x0010000000000000); + }; + + const auto v_lo_dbl = uint64_to_double128_fast(v_lo); + return MulAdd(cnst2_32_dbl, uint64_to_double128_fast(v_hi), v_lo_dbl); +#endif +} + +// Truncates (rounds toward zero). +template +HWY_API Vec128 ConvertTo(const Simd di, + const Vec128 v) { + return detail::FixConversionOverflow(di, v, _mm_cvttps_epi32(v.raw)); +} + +// Full (partial handled below) +HWY_API Vec128 ConvertTo(Full128 di, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 && HWY_ARCH_X86_64 + return detail::FixConversionOverflow(di, v, _mm_cvttpd_epi64(v.raw)); +#elif HWY_ARCH_X86_64 + const __m128i i0 = _mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw)); + const Half> dd2; + const __m128i i1 = _mm_cvtsi64_si128(_mm_cvttsd_si64(UpperHalf(dd2, v).raw)); + return detail::FixConversionOverflow(di, v, _mm_unpacklo_epi64(i0, i1)); +#else + using VI = VFromD; + const VI k0 = Zero(di); + const VI k1 = Set(di, 1); + const VI k51 = Set(di, 51); + + // Exponent indicates whether the number can be represented as int64_t. + const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); + const VI exp = biased_exp - Set(di, 0x3FF); + const auto in_range = exp < Set(di, 63); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + const VI shift_mnt = Max(k51 - exp, k0); + const VI shift_int = Max(exp - k51, k0); + const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); + // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. + const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); + // For inputs larger than 2^52, insert zeros at the bottom. + const VI shifted = int52 << shift_int; + // Restore the one bit lost when shifting in the implicit 1-bit. + const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, restored, limit); + + // If the input was negative, negate the integer (two's complement). 
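+ // sign_mask is 0 or all-ones (-1): (x ^ 0) - 0 = x, while
+ // (x ^ -1) - (-1) = ~x + 1 = -x, a branch-free conditional negation.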
+ return (magnitude ^ sign_mask) - sign_mask; +#endif +} +HWY_API Vec64 ConvertTo(Full64 di, const Vec64 v) { + // Only need to specialize for non-AVX3, 64-bit (single scalar op) +#if HWY_TARGET > HWY_AVX3 && HWY_ARCH_X86_64 + const Vec64 i0{_mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw))}; + return detail::FixConversionOverflow(di, v, i0.raw); +#else + (void)di; + const auto full = ConvertTo(Full128(), Vec128{v.raw}); + return Vec64{full.raw}; +#endif +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + const Simd di; + return detail::FixConversionOverflow(di, v, _mm_cvtps_epi32(v.raw)); +} + +// ------------------------------ Floating-point rounding (ConvertTo) + +#if HWY_TARGET == HWY_SSSE3 + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + // Rely on rounding after addition with a large value such that no mantissa + // bits remain (assuming the current mode is nearest-even). We may need a + // compiler flag for precise floating-point to prevent "optimizing" this out. + const Simd df; + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +namespace detail { + +// Truncating to integer and converting back to float is correct except when the +// input magnitude is large, in which case the input was already an integer +// (because mantissa >> exponent is zero). +template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. 
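+ // VecFromMask is all-ones (-1) where int_f > v; converted to float that is
+ // -1.0, so int_f + neg1 subtracts 1 exactly in those lanes.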
+ const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#else + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +#endif // !HWY_SSSE3 + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_fpclass_ps_mask(v.raw, 0x81)}; +#else + return Mask128{_mm_cmpunord_ps(v.raw, v.raw)}; +#endif +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_fpclass_pd_mask(v.raw, 0x81)}; +#else + return Mask128{_mm_cmpunord_pd(v.raw, v.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_ps_mask(v.raw, 0x18)}; +} +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_pd_mask(v.raw, 0x18)}; +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask128{_mm_fpclass_ps_mask(v.raw, 0x99)}); +} +template +HWY_API Mask128 IsFinite(const Vec128 v) { + return Not(Mask128{_mm_fpclass_pd_mask(v.raw, 0x99)}); +} + +#else + +template +HWY_API Mask128 IsInf(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // Shift left to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). MSVC seems to generate + // incorrect code if we instead add vu + vu. 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(ShiftLeft<1>(vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET != HWY_SSSE3 + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenc_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenclast_si128(state.raw, round_key.raw)}; +} + +template +HWY_API Vec128 CLMulLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x00)}; +} + +template +HWY_API Vec128 CLMulUpper(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x11)}; +} + +#endif // !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET != HWY_SSSE3 + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const Vec128 vbits{_mm_cvtsi32_si128(static_cast(mask_bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
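+// Bit i of bits[0] (LSB first) controls lane i: e.g. bits[0] = 0x05 sets lanes
+// 0 and 2 of a 4-lane mask and clears the others.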
+template +HWY_API Mask128 LoadMaskBits(Simd d, + const uint8_t* HWY_RESTRICT bits) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return Mask128::FromBits(mask_bits); +#else + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::LoadMaskBits(d, mask_bits); +#endif +} + +template +struct CompressIsPartition { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 supports native compress, but a table-based approach allows + // 'partitioning' (also moving mask=false lanes to the top), which helps + // vqsort. This is only feasible for eight or less lanes, i.e. sizeof(T) == 8 + // on AVX3. For simplicity, we only use tables for 64-bit lanes (not AVX3 + // u32x8 etc.). + enum { value = (sizeof(T) == 8) }; +#else + // generic_ops-inl does not guarantee IsPartition for 8-bit. + enum { value = (sizeof(T) != 1) }; +#endif +}; + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Simd /* tag */, + const Mask128 mask, uint8_t* bits) { + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (N < 8) { + const int mask_bits = (1 << N) - 1; + bits[0] = static_cast(bits[0] & mask_bits); + } + + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +template +HWY_API size_t CountTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return PopCount(mask_bits); +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint32_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint32_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return mask_bits == 0; +} + +template +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + // Cannot use _kortestc because we may have less than 8 mask bits. + return mask_bits == (1u << N) - 1; +} + +// ------------------------------ Compress + +// 8-16 bit Compress, CompressStore defined in x86_512 because they use Vec512. + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return Vec128{_mm_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + HWY_DASSERT(mask.raw < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. 
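+ // Row i is the byte shuffle for mask.raw == i; only mask == 2 (keep lane 1,
+ // drop lane 0) must move the upper 8 bytes to the front, the other three
+ // rows are the identity.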
+ alignas(16) constexpr uint8_t u8_indices[64] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Full128 d; + const Repartition d8; + const auto index = Load(d8, u8_indices + 16 * mask.raw); + return BitCast(d, TableLookupBytes(BitCast(d8, v), index)); +} + +// ------------------------------ CompressNot (Compress) + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // See CompressIsPartition, PrintCompressNot64x2NibbleTables + alignas(16) constexpr uint64_t packed_array[16] = {0x00000010, 0x00000001, + 0x00000010, 0x00000010}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm_permutexvar_epi64 will ignore the upper bits. + const Full128 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(16) constexpr uint64_t shifts[2] = {0, 4}; + const auto indices = Indices128{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + float* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + double* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore (CompressStore) +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + // AVX-512 already does the blending at no extra cost (latency 11, + // rthroughput 2 - same as compress plus store). + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) != 2) { + // We're relying on the mask to blend. Clear the undefined upper bits. + if (N != 16 / sizeof(T)) { + m = And(m, FirstN(d, N)); + } + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + const Vec128 compressed = Compress(v, m); +#if HWY_MEM_OPS_MIGHT_FAULT + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. 
+ alignas(16) T buf[N]; + Store(compressed, d, buf); + memcpy(unaligned, buf, count * sizeof(T)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#else // AVX2 or below + +// ------------------------------ StoreMaskBits + +namespace detail { + +constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) { + return static_cast(static_cast(mask_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + const Simd d; + const auto sign_bits = BitCast(d, VecFromMask(d, mask)).raw; + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const auto sign_bits = _mm_packs_epi16(mask.raw, _mm_setzero_si128()); + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + const Simd d; + const Simd df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_ps(sign_bits.raw)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 mask) { + const Simd d; + const Simd df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_pd(sign_bits.raw)); +} + +// Returns the lowest N of the _mm_movemask* bits. +template +constexpr uint64_t OnlyActive(uint64_t mask_bits) { + return ((N * sizeof(T)) == 16) ? mask_bits : mask_bits & ((1ull << N) - 1); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Simd /* tag */, + const Mask128 mask, uint8_t* bits) { + constexpr size_t kNumBytes = (N + 7) / 8; + const uint64_t mask_bits = detail::BitsFromMask(mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return detail::BitsFromMask(mask) == 0; +} + +template +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { + constexpr uint64_t kAllBits = + detail::OnlyActive((1ull << (16 / sizeof(T))) - 1); + return detail::BitsFromMask(mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(const Simd /* tag */, + const Mask128 mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. 
+template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Simd du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 
2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 
4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Simd du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
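+  // Same layout as the IndicesFromBits table above, but generated for the
+  // complemented mask: a set bit now means the lane is placed after all lanes
+  // whose bits are clear.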
+ alignas(16) constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 
6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 
12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
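+  // mask_bits selects one of the four 16-byte rows below. For CompressNot,
+  // only mask_bits == 1 (lane 0 is the lane being compressed away) swaps the
+  // two halves; the remaining rows keep the original byte order.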
+ alignas(16) constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_API Vec128 CompressBits(Vec128 v, uint64_t mask_bits) { + const Simd d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template +HWY_API Vec128 CompressNotBits(Vec128 v, uint64_t mask_bits) { + const Simd d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
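+  // DupEven broadcasts mask[0] into both lanes and DupOdd broadcasts mask[1],
+  // so AndNot(maskL, maskH) is all-ones exactly for mask = {0, 1}, which
+  // selects the swapped Shuffle01(v) below.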
+ const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return detail::CompressBits(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNotBits(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = detail::BitsFromMask(m); + HWY_DASSERT(mask_bits < (1ull << N)); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = detail::BitsFromMask(m); + HWY_DASSERT(mask_bits < (1ull << N)); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + BlendedStore(compressed, FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). 
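+  // Note that StoreU below writes a full vector: lanes at index count and
+  // above receive compressed padding instead of being left untouched.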
+ const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} + +// u32/i32/f32: + +// N=2 +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return v10 + Shuffle2301(v10); +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Min(v10, Shuffle2301(v10)); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Max(v10, Shuffle2301(v10)); +} + +// N=4 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = v3210 + v1032; + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// u64/i64/f64: + +// N=2 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. 
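+  // Each 32-bit lane of sum holds the total in its lower 16 bits; ShiftLeft
+  // copies it into the upper 16 bits and OddEven merges the two, so every
+  // 16-bit lane of the result holds the sum.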
+ return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +// u8, N=8, N=16: +HWY_API Vec64 SumOfLanes(hwy::SizeTag<1> /* tag */, Vec64 v) { + const Full64 d; + return Set(d, static_cast(GetLane(SumsOf8(v)) & 0xFF)); +} +HWY_API Vec128 SumOfLanes(hwy::SizeTag<1> /* tag */, + Vec128 v) { + const Full128 d; + Vec128 sums = SumOfLanes(hwy::SizeTag<8>(), SumsOf8(v)); + return Set(d, static_cast(GetLane(sums) & 0xFF)); +} + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto is_neg = v < Zero(d); + + // Sum positive and negative lanes separately, then combine to get the result. + const auto positive = SumsOf8(BitCast(du, IfThenZeroElse(is_neg, v))); + const auto negative = SumsOf8(BitCast(du, IfThenElseZero(is_neg, Abs(v)))); + return Set(d, static_cast(GetLane( + SumOfLanes(hwy::SizeTag<8>(), positive - negative)) & + 0xFF)); +} + +#if HWY_TARGET <= HWY_SSE4 +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + using V = decltype(v); + return Broadcast<0>(V{_mm_minpos_epu16(v.raw)}); +} +HWY_API Vec64 MinOfLanes(hwy::SizeTag<1> /* tag */, Vec64 v) { + const Full64 d; + const Full128 d16; + return TruncateTo(d, MinOfLanes(hwy::SizeTag<2>(), PromoteTo(d16, v))); +} +HWY_API Vec128 MinOfLanes(hwy::SizeTag<1> tag, + Vec128 v) { + const Half> d; + Vec64 result = + Min(MinOfLanes(tag, UpperHalf(d, v)), MinOfLanes(tag, LowerHalf(d, v))); + return Combine(DFromV(), result, result); +} + +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> tag, Vec128 v) { + const Vec128 m(Set(DFromV(), LimitsMax())); + return m - MinOfLanes(tag, m - v); +} +HWY_API Vec64 MaxOfLanes(hwy::SizeTag<1> tag, Vec64 v) { + const Vec64 m(Set(DFromV(), LimitsMax())); + return m - MinOfLanes(tag, m - v); +} +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<1> tag, Vec128 v) { + const Vec128 m(Set(DFromV(), LimitsMax())); + return m - MinOfLanes(tag, m - v); +} +#elif HWY_TARGET == HWY_SSSE3 +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128 v) { + const DFromV d; + const RepartitionToWide d16; + const RepartitionToWide d32; + Vec128 vm = Max(v, Reverse2(d, v)); + vm = Max(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); + vm = Max(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); + if (N > 8) { + const RepartitionToWide d64; + vm = Max(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); + } + return vm; +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128 v) { + const DFromV d; + const RepartitionToWide d16; + const RepartitionToWide d32; + Vec128 vm = Min(v, Reverse2(d, v)); + vm = Min(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); + vm = Min(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); + if (N > 8) { + const RepartitionToWide d64; + vm = Min(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); + } + return vm; +} +#endif + +// Implement min/max of i8 in terms of u8 by toggling the sign bit. 
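+// XOR with 0x80 maps int8 monotonically onto uint8 ([-128, 127] -> [0, 255]),
+// so an unsigned reduction followed by another XOR yields the signed result.
+// For example, Min of {-1, 3}: 0xFF ^ 0x80 = 0x7F and 0x03 ^ 0x80 = 0x83; the
+// unsigned Min is 0x7F, and 0x7F ^ 0x80 = 0xFF = -1.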
+template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<1> tag, + const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto mask = SignBit(du); + const auto vu = Xor(BitCast(du, v), mask); + return BitCast(d, Xor(MinOfLanes(tag, vu), mask)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<1> tag, + const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto mask = SignBit(du); + const auto vu = Xor(BitCast(du, v), mask); + return BitCast(d, Xor(MaxOfLanes(tag, vu), mask)); +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the same value in each lane. +template +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ Lt128 + +namespace detail { + +// Returns vector-mask for Lt128. Also used by x86_256/x86_512. +template > +HWY_INLINE V Lt128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const V ltHL = VecFromMask(d, Lt(a, b)); + const V ltLX = ShiftLeftLanes<1>(ltHL); + const V vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +} + +// Returns vector-mask for Eq128. Also used by x86_256/x86_512. 
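+// Each 128-bit block is equal iff both of its u64 halves are equal. Reverse2
+// swaps the halves, so the And below combines each half's comparison with its
+// neighbor's.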
+template > +HWY_INLINE V Eq128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +} + +template > +HWY_INLINE V Ne128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +} + +template > +HWY_INLINE V Lt128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template > +HWY_INLINE V Eq128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template > +HWY_INLINE V Ne128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template > +HWY_API MFromD Lt128(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template > +HWY_API MFromD Eq128(D d, const V a, const V b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template > +HWY_API MFromD Ne128(D d, const V a, const V b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template > +HWY_API MFromD Lt128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD Eq128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD Ne128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. +template > +HWY_API V Min128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template > +HWY_API V Max128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template > +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template > +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) -- cgit v1.2.3