// Copyright 2019 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when
// compiling for that target.
// External include guard in highway.h - see comment there.

// WARNING: most operations do not cross 128-bit block boundaries. In
// particular, "Broadcast", pack and zip behavior may be surprising.

// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL
#include "hwy/base.h"

// Avoid uninitialized warnings in GCC's avx512fintrin.h - see
// https://github.com/google/highway/issues/710
HWY_DIAGNOSTICS(push)
#if HWY_COMPILER_GCC_ACTUAL
HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized")
HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized")
#endif

// Must come before HWY_COMPILER_CLANGCL
#include <immintrin.h>  // AVX2+

#if HWY_COMPILER_CLANGCL
// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
// including these headers when _MSC_VER is defined, like when using clang-cl.
// Include these directly here.
#include <avxintrin.h>
// avxintrin defines __m256i and must come before avx2intrin.
#include <avx2intrin.h>
#include <bmi2intrin.h>  // _pext_u64
#include <f16cintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif  // HWY_COMPILER_CLANGCL

#include <stddef.h>
#include <stdint.h>
#include <string.h>  // memcpy

#if HWY_IS_MSAN
#include <sanitizer/msan_interface.h>
#endif

// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/x86_128-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

namespace detail {

template <typename T>
struct Raw256 {
  using type = __m256i;
};
template <>
struct Raw256<float> {
  using type = __m256;
};
template <>
struct Raw256<double> {
  using type = __m256d;
};

}  // namespace detail

template <typename T>
class Vec256 {
  using Raw = typename detail::Raw256<T>::type;

 public:
  using PrivateT = T;                                  // only for DFromV
  static constexpr size_t kPrivateN = 32 / sizeof(T);  // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
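  // Illustrative sketch (not part of the class): these forward to the
  // non-member operators, so e.g. operator/= is only usable for f32/f64:
  //   const Full256<float> d;
  //   Vec256<float> v = Set(d, 2.0f);
  //   v += Set(d, 1.0f);  // 3.0f in all lanes
  //   v /= Set(d, 4.0f);  // 0.75f; would not compile for integer lanes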
HWY_INLINE Vec256& operator*=(const Vec256 other) { return *this = (*this * other); } HWY_INLINE Vec256& operator/=(const Vec256 other) { return *this = (*this / other); } HWY_INLINE Vec256& operator+=(const Vec256 other) { return *this = (*this + other); } HWY_INLINE Vec256& operator-=(const Vec256 other) { return *this = (*this - other); } HWY_INLINE Vec256& operator&=(const Vec256 other) { return *this = (*this & other); } HWY_INLINE Vec256& operator|=(const Vec256 other) { return *this = (*this | other); } HWY_INLINE Vec256& operator^=(const Vec256 other) { return *this = (*this ^ other); } Raw raw; }; #if HWY_TARGET <= HWY_AVX3 namespace detail { // Template arg: sizeof(lane type) template struct RawMask256 {}; template <> struct RawMask256<1> { using type = __mmask32; }; template <> struct RawMask256<2> { using type = __mmask16; }; template <> struct RawMask256<4> { using type = __mmask8; }; template <> struct RawMask256<8> { using type = __mmask8; }; } // namespace detail template struct Mask256 { using Raw = typename detail::RawMask256::type; static Mask256 FromBits(uint64_t mask_bits) { return Mask256{static_cast(mask_bits)}; } Raw raw; }; #else // AVX2 // FF..FF or 0. template struct Mask256 { typename detail::Raw256::type raw; }; #endif // HWY_TARGET <= HWY_AVX3 template using Full256 = Simd; // ------------------------------ BitCast namespace detail { HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; } HWY_INLINE __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); } HWY_INLINE __m256i BitCastToInteger(__m256d v) { return _mm256_castpd_si256(v); } template HWY_INLINE Vec256 BitCastToByte(Vec256 v) { return Vec256{BitCastToInteger(v.raw)}; } // Cannot rely on function overloading because return types differ. template struct BitCastFromInteger256 { HWY_INLINE __m256i operator()(__m256i v) { return v; } }; template <> struct BitCastFromInteger256 { HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); } }; template <> struct BitCastFromInteger256 { HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); } }; template HWY_INLINE Vec256 BitCastFromByte(Full256 /* tag */, Vec256 v) { return Vec256{BitCastFromInteger256()(v.raw)}; } } // namespace detail template HWY_API Vec256 BitCast(Full256 d, Vec256 v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } // ------------------------------ Set // Returns an all-zero vector. template HWY_API Vec256 Zero(Full256 /* tag */) { return Vec256{_mm256_setzero_si256()}; } HWY_API Vec256 Zero(Full256 /* tag */) { return Vec256{_mm256_setzero_ps()}; } HWY_API Vec256 Zero(Full256 /* tag */) { return Vec256{_mm256_setzero_pd()}; } // Returns a vector with all lanes set to "t". 
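// Example (illustrative sketch): Set broadcasts one scalar to all lanes, and
// BitCast (above) reinterprets lane bits without any conversion.
//   const Full256<uint32_t> du;
//   const Full256<float> df;
//   const auto bits = Set(du, 0x3F800000u);  // bit pattern of 1.0f
//   const auto ones = BitCast(df, bits);     // 1.0f in all lanes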
HWY_API Vec256 Set(Full256 /* tag */, const uint8_t t) { return Vec256{_mm256_set1_epi8(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const uint16_t t) { return Vec256{_mm256_set1_epi16(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const uint32_t t) { return Vec256{_mm256_set1_epi32(static_cast(t))}; } HWY_API Vec256 Set(Full256 /* tag */, const uint64_t t) { return Vec256{ _mm256_set1_epi64x(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const int8_t t) { return Vec256{_mm256_set1_epi8(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const int16_t t) { return Vec256{_mm256_set1_epi16(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const int32_t t) { return Vec256{_mm256_set1_epi32(t)}; } HWY_API Vec256 Set(Full256 /* tag */, const int64_t t) { return Vec256{ _mm256_set1_epi64x(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const float t) { return Vec256{_mm256_set1_ps(t)}; } HWY_API Vec256 Set(Full256 /* tag */, const double t) { return Vec256{_mm256_set1_pd(t)}; } HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") // Returns a vector with uninitialized elements. template HWY_API Vec256 Undefined(Full256 /* tag */) { // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC // generate an XOR instruction. return Vec256{_mm256_undefined_si256()}; } HWY_API Vec256 Undefined(Full256 /* tag */) { return Vec256{_mm256_undefined_ps()}; } HWY_API Vec256 Undefined(Full256 /* tag */) { return Vec256{_mm256_undefined_pd()}; } HWY_DIAGNOSTICS(pop) // ================================================== LOGICAL // ------------------------------ And template HWY_API Vec256 And(Vec256 a, Vec256 b) { return Vec256{_mm256_and_si256(a.raw, b.raw)}; } HWY_API Vec256 And(const Vec256 a, const Vec256 b) { return Vec256{_mm256_and_ps(a.raw, b.raw)}; } HWY_API Vec256 And(const Vec256 a, const Vec256 b) { return Vec256{_mm256_and_pd(a.raw, b.raw)}; } // ------------------------------ AndNot // Returns ~not_mask & mask. 
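// Example (illustrative sketch): AndNot can clear selected bits, e.g. the
// sign bit of f32 lanes:
//   const Full256<float> d;
//   const auto v = Set(d, -1.5f);
//   const auto abs = AndNot(SignBit(d), v);  // +1.5f in all lanes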
template HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { return Vec256{_mm256_andnot_si256(not_mask.raw, mask.raw)}; } HWY_API Vec256 AndNot(const Vec256 not_mask, const Vec256 mask) { return Vec256{_mm256_andnot_ps(not_mask.raw, mask.raw)}; } HWY_API Vec256 AndNot(const Vec256 not_mask, const Vec256 mask) { return Vec256{_mm256_andnot_pd(not_mask.raw, mask.raw)}; } // ------------------------------ Or template HWY_API Vec256 Or(Vec256 a, Vec256 b) { return Vec256{_mm256_or_si256(a.raw, b.raw)}; } HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { return Vec256{_mm256_or_ps(a.raw, b.raw)}; } HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { return Vec256{_mm256_or_pd(a.raw, b.raw)}; } // ------------------------------ Xor template HWY_API Vec256 Xor(Vec256 a, Vec256 b) { return Vec256{_mm256_xor_si256(a.raw, b.raw)}; } HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { return Vec256{_mm256_xor_ps(a.raw, b.raw)}; } HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { return Vec256{_mm256_xor_pd(a.raw, b.raw)}; } // ------------------------------ Not template HWY_API Vec256 Not(const Vec256 v) { using TU = MakeUnsigned; #if HWY_TARGET <= HWY_AVX3 const __m256i vu = BitCast(Full256(), v).raw; return BitCast(Full256(), Vec256{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)}); #else return Xor(v, BitCast(Full256(), Vec256{_mm256_set1_epi32(-1)})); #endif } // ------------------------------ Xor3 template HWY_API Vec256 Xor3(Vec256 x1, Vec256 x2, Vec256 x3) { #if HWY_TARGET <= HWY_AVX3 const Full256 d; const RebindToUnsigned du; using VU = VFromD; const __m256i ret = _mm256_ternarylogic_epi64( BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); return BitCast(d, VU{ret}); #else return Xor(x1, Xor(x2, x3)); #endif } // ------------------------------ Or3 template HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { #if HWY_TARGET <= HWY_AVX3 const Full256 d; const RebindToUnsigned du; using VU = VFromD; const __m256i ret = _mm256_ternarylogic_epi64( BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); return BitCast(d, VU{ret}); #else return Or(o1, Or(o2, o3)); #endif } // ------------------------------ OrAnd template HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { #if HWY_TARGET <= HWY_AVX3 const Full256 d; const RebindToUnsigned du; using VU = VFromD; const __m256i ret = _mm256_ternarylogic_epi64( BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); return BitCast(d, VU{ret}); #else return Or(o, And(a1, a2)); #endif } // ------------------------------ IfVecThenElse template HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { #if HWY_TARGET <= HWY_AVX3 const Full256 d; const RebindToUnsigned du; using VU = VFromD; return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, BitCast(du, no).raw, 0xCA)}); #else return IfThenElse(MaskFromVec(mask), yes, no); #endif } // ------------------------------ Operator overloads (internal-only if float) template HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { return And(a, b); } template HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { return Or(a, b); } template HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { return Xor(a, b); } // ------------------------------ PopulationCount // 8/16 require BITALG, 32/64 require VPOPCNTDQ. 
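// Example (illustrative sketch): PopulationCount returns the number of set
// bits in each lane; the intrinsics below are used when compiling for the
// AVX3_DL target.
//   const Full256<uint64_t> d;
//   const auto v = Set(d, 0x00FF00FF00FF00FFull);
//   const auto bits = PopulationCount(v);  // 32 in every u64 lane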
#if HWY_TARGET == HWY_AVX3_DL #ifdef HWY_NATIVE_POPCNT #undef HWY_NATIVE_POPCNT #else #define HWY_NATIVE_POPCNT #endif namespace detail { template HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<1> /* tag */, Vec256 v) { return Vec256{_mm256_popcnt_epi8(v.raw)}; } template HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<2> /* tag */, Vec256 v) { return Vec256{_mm256_popcnt_epi16(v.raw)}; } template HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<4> /* tag */, Vec256 v) { return Vec256{_mm256_popcnt_epi32(v.raw)}; } template HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<8> /* tag */, Vec256 v) { return Vec256{_mm256_popcnt_epi64(v.raw)}; } } // namespace detail template HWY_API Vec256 PopulationCount(Vec256 v) { return detail::PopulationCount(hwy::SizeTag(), v); } #endif // HWY_TARGET == HWY_AVX3_DL // ================================================== SIGN // ------------------------------ CopySign template HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); const Full256 d; const auto msb = SignBit(d); #if HWY_TARGET <= HWY_AVX3 const Rebind, decltype(d)> du; // Truth table for msb, magn, sign | bitwise msb ? sign : mag // 0 0 0 | 0 // 0 0 1 | 0 // 0 1 0 | 1 // 0 1 1 | 1 // 1 0 0 | 0 // 1 0 1 | 1 // 1 1 0 | 0 // 1 1 1 | 1 // The lane size does not matter because we are not using predication. const __m256i out = _mm256_ternarylogic_epi32( BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); return BitCast(d, decltype(Zero(du)){out}); #else return Or(AndNot(msb, magn), And(msb, sign)); #endif } template HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { #if HWY_TARGET <= HWY_AVX3 // AVX3 can also handle abs < 0, so no extra action needed. return CopySign(abs, sign); #else return Or(abs, And(SignBit(Full256()), sign)); #endif } // ================================================== MASK #if HWY_TARGET <= HWY_AVX3 // ------------------------------ IfThenElse // Returns mask ? b : a. namespace detail { // Templates for signed/unsigned integer of a particular size. 
template HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<1> /* tag */, Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<2> /* tag */, Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<4> /* tag */, Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<8> /* tag */, Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; } } // namespace detail template HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); } HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_ps(no.raw, mask.raw, yes.raw)}; } HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_pd(no.raw, mask.raw, yes.raw)}; } namespace detail { template HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_epi8(mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_epi16(mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_epi32(mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_epi64(mask.raw, yes.raw)}; } } // namespace detail template HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); } HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_ps(mask.raw, yes.raw)}; } HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_pd(mask.raw, yes.raw)}; } namespace detail { template HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256 mask, Vec256 no) { // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
return Vec256{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; } } // namespace detail template HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); } HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; } HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_API Vec256 ZeroIfNegative(const Vec256 v) { static_assert(IsSigned(), "Only for float"); // AVX3 MaskFromVec only looks at the MSB return IfThenZeroElse(MaskFromVec(v), v); } // ------------------------------ Mask logical namespace detail { template HWY_INLINE Mask256 And(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kand_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 And(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kand_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 And(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kand_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 And(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kand_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 AndNot(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kandn_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 AndNot(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kandn_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 AndNot(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kandn_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 AndNot(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kandn_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 Or(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kor_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(a.raw | b.raw)}; #endif } template HWY_INLINE Mask256 Or(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if 
HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kor_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(a.raw | b.raw)}; #endif } template HWY_INLINE Mask256 Or(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; #endif } template HWY_INLINE Mask256 Or(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; #endif } template HWY_INLINE Mask256 Xor(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxor_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask256 Xor(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxor_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask256 Xor(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask256 Xor(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxnor_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; #endif } template HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxnor_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; #endif } template HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxnor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; #endif } template HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; #else return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; #endif } } // namespace detail template HWY_API Mask256 And(const Mask256 a, Mask256 b) { return detail::And(hwy::SizeTag(), a, b); } template HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { return detail::AndNot(hwy::SizeTag(), a, b); } template HWY_API Mask256 Or(const Mask256 a, Mask256 b) { return detail::Or(hwy::SizeTag(), a, b); } template HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { return detail::Xor(hwy::SizeTag(), a, b); } template HWY_API Mask256 Not(const Mask256 m) { // Flip only the valid bits. constexpr size_t N = 32 / sizeof(T); return Xor(m, Mask256::FromBits((1ull << N) - 1)); } template HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { return detail::ExclusiveNeither(hwy::SizeTag(), a, b); } #else // AVX2 // ------------------------------ Mask // Mask and Vec are the same (true = FF..FF). 
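// Example (illustrative sketch): on AVX2 a Mask256 lane is all-ones (true) or
// all-zero (false), so mask<->vector conversion is free and mask logic reuses
// the vector ops.
//   const Full256<int32_t> d;
//   const auto m = Iota(d, 0) < Set(d, 2);        // lanes 0 and 1 are true
//   const auto v = IfThenElseZero(m, Set(d, 7));  // {7,7,0,0,0,0,0,0}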
template HWY_API Mask256 MaskFromVec(const Vec256 v) { return Mask256{v.raw}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{v.raw}; } template HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { return Vec256{v.raw}; } // ------------------------------ IfThenElse // mask ? yes : no template HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, const Vec256 no) { return Vec256{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; } HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, const Vec256 no) { return Vec256{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; } HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, const Vec256 no) { return Vec256{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; } // mask ? yes : 0 template HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return yes & VecFromMask(Full256(), mask); } // mask ? 0 : no template HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return AndNot(VecFromMask(Full256(), mask), no); } template HWY_API Vec256 ZeroIfNegative(Vec256 v) { static_assert(IsSigned(), "Only for float"); const auto zero = Zero(Full256()); // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes return IfThenElse(MaskFromVec(v), zero, v); } // ------------------------------ Mask logical template HWY_API Mask256 Not(const Mask256 m) { return MaskFromVec(Not(VecFromMask(Full256(), m))); } template HWY_API Mask256 And(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 Or(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); } #endif // HWY_TARGET <= HWY_AVX3 // ================================================== COMPARE #if HWY_TARGET <= HWY_AVX3 // Comparisons set a mask bit to 1 if the condition is true, else 0. 
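// Example (illustrative sketch): on AVX3 targets, comparisons return compact
// __mmask registers; TestBit (below) is true where the single given bit is
// set.
//   const Full256<uint32_t> d;
//   const auto v = Iota(d, 0);              // {0,1,2,...,7}
//   const auto m = TestBit(v, Set(d, 1u));  // odd lanes are true
//   // CountTrue(d, m) == 4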
template HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return Mask256{m.raw}; } namespace detail { template HWY_INLINE Mask256 TestBit(hwy::SizeTag<1> /*tag*/, const Vec256 v, const Vec256 bit) { return Mask256{_mm256_test_epi8_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask256 TestBit(hwy::SizeTag<2> /*tag*/, const Vec256 v, const Vec256 bit) { return Mask256{_mm256_test_epi16_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask256 TestBit(hwy::SizeTag<4> /*tag*/, const Vec256 v, const Vec256 bit) { return Mask256{_mm256_test_epi32_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask256 TestBit(hwy::SizeTag<8> /*tag*/, const Vec256 v, const Vec256 bit) { return Mask256{_mm256_test_epi64_mask(v.raw, bit.raw)}; } } // namespace detail template HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return detail::TestBit(hwy::SizeTag(), v, bit); } // ------------------------------ Equality template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi8_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi16_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi32_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi64_mask(a.raw, b.raw)}; } HWY_API Mask256 operator==(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } HWY_API Mask256 operator==(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } // ------------------------------ Inequality template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpneq_epi8_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpneq_epi16_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpneq_epi32_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpneq_epi64_mask(a.raw, b.raw)}; } HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; } HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; } // ------------------------------ Strict inequality HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi8_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi16_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi32_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi64_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epu8_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epu16_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epu32_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epu64_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return 
Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; } // ------------------------------ Weak inequality HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; } HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; } // ------------------------------ Mask namespace detail { template HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256 v) { return Mask256{_mm256_movepi8_mask(v.raw)}; } template HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256 v) { return Mask256{_mm256_movepi16_mask(v.raw)}; } template HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256 v) { return Mask256{_mm256_movepi32_mask(v.raw)}; } template HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256 v) { return Mask256{_mm256_movepi64_mask(v.raw)}; } } // namespace detail template HWY_API Mask256 MaskFromVec(const Vec256 v) { return detail::MaskFromVec(hwy::SizeTag(), v); } // There do not seem to be native floating-point versions of these instructions. HWY_API Mask256 MaskFromVec(const Vec256 v) { return Mask256{MaskFromVec(BitCast(Full256(), v)).raw}; } HWY_API Mask256 MaskFromVec(const Vec256 v) { return Mask256{MaskFromVec(BitCast(Full256(), v)).raw}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_movm_epi8(v.raw)}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_movm_epi16(v.raw)}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_movm_epi32(v.raw)}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_movm_epi64(v.raw)}; } HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))}; } HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))}; } template HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { return VecFromMask(v); } #else // AVX2 // Comparisons fill a lane with 1-bits if the condition is true, else 0. 
template HWY_API Mask256 RebindMask(Full256 d_to, Mask256 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return MaskFromVec(BitCast(d_to, VecFromMask(Full256(), m))); } template HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return (v & bit) == bit; } // ------------------------------ Equality template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi8(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi16(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi32(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi64(a.raw, b.raw)}; } HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)}; } HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)}; } // ------------------------------ Inequality template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Not(a == b); } HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)}; } HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)}; } // ------------------------------ Strict inequality // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { // Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8 // to perform an unsigned comparison instead of the intended signed. Workaround // is to cast to an explicitly signed type. 
See https://godbolt.org/z/PL7Ujy #if HWY_COMPILER_GCC != 0 && HWY_COMPILER_GCC < 930 #define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1 #else #define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0 #endif HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, Vec256 b) { #if HWY_AVX2_GCC_CMPGT8_WORKAROUND using i8x32 = signed char __attribute__((__vector_size__(32))); return Mask256{static_cast<__m256i>(reinterpret_cast(a.raw) > reinterpret_cast(b.raw))}; #else return Mask256{_mm256_cmpgt_epi8(a.raw, b.raw)}; #endif } HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi16(a.raw, b.raw)}; } HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi32(a.raw, b.raw)}; } HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi64(a.raw, b.raw)}; } template HWY_INLINE Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, Vec256 b) { const Full256 du; const RebindToSigned di; const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); } HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)}; } HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)}; } } // namespace detail template HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return detail::Gt(hwy::TypeTag(), a, b); } // ------------------------------ Weak inequality HWY_API Mask256 operator>=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)}; } HWY_API Mask256 operator>=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)}; } #endif // HWY_TARGET <= HWY_AVX3 // ------------------------------ Reversed comparisons template HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { return b > a; } template HWY_API Mask256 operator<=(const Vec256 a, const Vec256 b) { return b >= a; } // ------------------------------ Min (Gt, IfThenElse) // Unsigned HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epu8(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epu16(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epu32(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_min_epu64(a.raw, b.raw)}; #else const Full256 du; const Full256 di; const auto msb = Set(du, 1ull << 63); const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); return IfThenElse(gt, b, a); #endif } // Signed HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epi8(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epi16(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epi32(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_min_epi64(a.raw, b.raw)}; #else return IfThenElse(a < b, a, b); #endif } // Float HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_ps(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_pd(a.raw, b.raw)}; } // ------------------------------ Max (Gt, IfThenElse) // Unsigned HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return 
Vec256{_mm256_max_epu8(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epu16(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epu32(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_max_epu64(a.raw, b.raw)}; #else const Full256 du; const Full256 di; const auto msb = Set(du, 1ull << 63); const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); return IfThenElse(gt, a, b); #endif } // Signed HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epi8(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epi16(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epi32(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_max_epi64(a.raw, b.raw)}; #else return IfThenElse(a < b, b, a); #endif } // Float HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_ps(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_pd(a.raw, b.raw)}; } // ------------------------------ FirstN (Iota, Lt) template HWY_API Mask256 FirstN(const Full256 d, size_t n) { #if HWY_TARGET <= HWY_AVX3 (void)d; constexpr size_t N = 32 / sizeof(T); #if HWY_ARCH_X86_64 const uint64_t all = (1ull << N) - 1; // BZHI only looks at the lower 8 bits of n! return Mask256::FromBits((n > 255) ? all : _bzhi_u64(all, n)); #else const uint32_t all = static_cast((1ull << N) - 1); // BZHI only looks at the lower 8 bits of n! return Mask256::FromBits( (n > 255) ? all : _bzhi_u32(all, static_cast(n))); #endif // HWY_ARCH_X86_64 #else const RebindToSigned di; // Signed comparisons are cheaper. 
return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(n))); #endif } // ================================================== ARITHMETIC // ------------------------------ Addition // Unsigned HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi8(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi32(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi64(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi8(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi32(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi64(a.raw, b.raw)}; } // Float HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_ps(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_pd(a.raw, b.raw)}; } // ------------------------------ Subtraction // Unsigned HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; } // Float HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_ps(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_pd(a.raw, b.raw)}; } // ------------------------------ SumsOf8 HWY_API Vec256 SumsOf8(const Vec256 v) { return Vec256{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; } // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. // Unsigned HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{_mm256_adds_epu8(a.raw, b.raw)}; } HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{_mm256_adds_epu16(a.raw, b.raw)}; } // Signed HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{_mm256_adds_epi8(a.raw, b.raw)}; } HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; } // ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. 
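// Example (illustrative sketch): saturating ops clamp to the lane range
// instead of wrapping.
//   const Full256<uint8_t> d;
//   const auto big = Set(d, uint8_t{250});
//   const auto sum = SaturatedAdd(big, Set(d, uint8_t{10}));  // 255, not 4
//   const auto diff = SaturatedSub(Set(d, uint8_t{5}), big);  // 0, not 11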
// Unsigned HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{_mm256_subs_epu8(a.raw, b.raw)}; } HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{_mm256_subs_epu16(a.raw, b.raw)}; } // Signed HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{_mm256_subs_epi8(a.raw, b.raw)}; } HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{_mm256_subs_epi16(a.raw, b.raw)}; } // ------------------------------ Average // Returns (a + b + 1) / 2 // Unsigned HWY_API Vec256 AverageRound(const Vec256 a, const Vec256 b) { return Vec256{_mm256_avg_epu8(a.raw, b.raw)}; } HWY_API Vec256 AverageRound(const Vec256 a, const Vec256 b) { return Vec256{_mm256_avg_epu16(a.raw, b.raw)}; } // ------------------------------ Abs (Sub) // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. HWY_API Vec256 Abs(const Vec256 v) { #if HWY_COMPILER_MSVC // Workaround for incorrect codegen? (wrong result) const auto zero = Zero(Full256()); return Vec256{_mm256_max_epi8(v.raw, (zero - v).raw)}; #else return Vec256{_mm256_abs_epi8(v.raw)}; #endif } HWY_API Vec256 Abs(const Vec256 v) { return Vec256{_mm256_abs_epi16(v.raw)}; } HWY_API Vec256 Abs(const Vec256 v) { return Vec256{_mm256_abs_epi32(v.raw)}; } // i64 is implemented after BroadcastSignBit. HWY_API Vec256 Abs(const Vec256 v) { const Vec256 mask{_mm256_set1_epi32(0x7FFFFFFF)}; return v & BitCast(Full256(), mask); } HWY_API Vec256 Abs(const Vec256 v) { const Vec256 mask{_mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; return v & BitCast(Full256(), mask); } // ------------------------------ Integer multiplication // Unsigned HWY_API Vec256 operator*(Vec256 a, Vec256 b) { return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator*(Vec256 a, Vec256 b) { return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator*(Vec256 a, Vec256 b) { return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator*(Vec256 a, Vec256 b) { return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; } // Returns the upper 16 bits of a * b in each lane. HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { return Vec256{_mm256_mulhi_epu16(a.raw, b.raw)}; } HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { return Vec256{_mm256_mulhi_epi16(a.raw, b.raw)}; } HWY_API Vec256 MulFixedPoint15(Vec256 a, Vec256 b) { return Vec256{_mm256_mulhrs_epi16(a.raw, b.raw)}; } // Multiplies even lanes (0, 2 ..) and places the double-wide result into // even and the upper half into its odd neighbor lane. 
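// Example (illustrative sketch): for u32 inputs, each 64-bit product occupies
// an even lane together with its odd neighbor.
//   const Full256<uint32_t> d32;
//   const auto a = Set(d32, 0xFFFFFFFFu);
//   const auto b = Set(d32, 2u);
//   const auto wide = MulEven(a, b);  // Vec256<uint64_t>; each = 0x1FFFFFFFE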
HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { return Vec256{_mm256_mul_epi32(a.raw, b.raw)}; } HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { return Vec256{_mm256_mul_epu32(a.raw, b.raw)}; } // ------------------------------ ShiftLeft template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi16(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi32(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi64(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi16(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi32(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi64(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { const Full256 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); return kBits == 1 ? (v + v) : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); } // ------------------------------ ShiftRight template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srli_epi16(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srli_epi32(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srli_epi64(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { const Full256 d8; // Use raw instead of BitCast to support N=1. const Vec256 shifted{ShiftRight(Vec256{v.raw}).raw}; return shifted & Set(d8, 0xFF >> kBits); } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srai_epi16(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srai_epi32(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { const Full256 di; const Full256 du; const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); return (shifted ^ shifted_sign) - shifted_sign; } // i64 is implemented after BroadcastSignBit. 
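// Note (illustrative sketch) on the signed 8-bit ShiftRight above: x86 has no
// 8-bit arithmetic shift, so lanes are shifted as u8 and the sign is restored
// via (shifted ^ shifted_sign) - shifted_sign. E.g. for kBits = 2 and an input
// lane of 0x80 (-128):
//   shifted      = 0x80 >> 2 = 0x20  (logical shift, zero-filled)
//   shifted_sign = 0x80 >> 2 = 0x20
//   (0x20 ^ 0x20) - 0x20 = -0x20 = 0xE0, i.e. -32, the expected result.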
// ------------------------------ RotateRight template HWY_API Vec256 RotateRight(const Vec256 v) { static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_ror_epi32(v.raw, kBits)}; #else if (kBits == 0) return v; return Or(ShiftRight(v), ShiftLeft(v)); #endif } template HWY_API Vec256 RotateRight(const Vec256 v) { static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_ror_epi64(v.raw, kBits)}; #else if (kBits == 0) return v; return Or(ShiftRight(v), ShiftLeft(v)); #endif } // ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return VecFromMask(v < Zero(Full256())); } HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return ShiftRight<15>(v); } HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return ShiftRight<31>(v); } HWY_API Vec256 BroadcastSignBit(const Vec256 v) { #if HWY_TARGET == HWY_AVX2 return VecFromMask(v < Zero(Full256())); #else return Vec256{_mm256_srai_epi64(v.raw, 63)}; #endif } template HWY_API Vec256 ShiftRight(const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_srai_epi64(v.raw, kBits)}; #else const Full256 di; const Full256 du; const auto right = BitCast(di, ShiftRight(BitCast(du, v))); const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); return right | sign; #endif } HWY_API Vec256 Abs(const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_abs_epi64(v.raw)}; #else const auto zero = Zero(Full256()); return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); #endif } // ------------------------------ IfNegativeThenElse (BroadcastSignBit) HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { // int8: AVX2 IfThenElse only looks at the MSB. return IfThenElse(MaskFromVec(v), yes, no); } template HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { static_assert(IsSigned(), "Only works for signed/float"); const Full256 d; const RebindToSigned di; // 16-bit: no native blendv, so copy sign to lower byte's MSB. v = BitCast(d, BroadcastSignBit(BitCast(di, v))); return IfThenElse(MaskFromVec(v), yes, no); } template HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { static_assert(IsSigned(), "Only works for signed/float"); const Full256 d; const RebindToFloat df; // 32/64-bit: use float IfThenElse, which only looks at the MSB. 
const MFromD msb = MaskFromVec(BitCast(df, v)); return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no))); } // ------------------------------ ShiftLeftSame HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } template HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { const Full256 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); } // ------------------------------ ShiftRightSame (BroadcastSignBit) HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { const Full256 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); return shifted & Set(d8, static_cast(0xFF >> bits)); } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; #else const Full256 di; const Full256 du; const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); return right | sign; #endif } HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { const Full256 di; const Full256 du; const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto shifted_sign = BitCast(di, Set(du, static_cast(0x80 >> bits))); return (shifted ^ shifted_sign) - shifted_sign; } // ------------------------------ Neg (Xor, Sub) // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_INLINE Vec256 Neg(hwy::FloatTag /*tag*/, const Vec256 v) { return Xor(v, SignBit(Full256())); } // Not floating-point template HWY_INLINE Vec256 Neg(hwy::NonFloatTag /*tag*/, const Vec256 v) { return Zero(Full256()) - v; } } // namespace detail template HWY_API Vec256 Neg(const Vec256 v) { return detail::Neg(hwy::IsFloatTag(), v); } // ------------------------------ Floating-point mul / div HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mul_ps(a.raw, b.raw)}; } HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mul_pd(a.raw, 
b.raw)}; } HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { return Vec256{_mm256_div_ps(a.raw, b.raw)}; } HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { return Vec256{_mm256_div_pd(a.raw, b.raw)}; } // Approximate reciprocal HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { return Vec256{_mm256_rcp_ps(v.raw)}; } // Absolute value of difference. HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { return Abs(a - b); } // ------------------------------ Floating-point multiply-add variants // Returns mul * x + add HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { #ifdef HWY_DISABLE_BMI2_FMA return mul * x + add; #else return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; #endif } HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { #ifdef HWY_DISABLE_BMI2_FMA return mul * x + add; #else return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; #endif } // Returns add - mul * x HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { #ifdef HWY_DISABLE_BMI2_FMA return add - mul * x; #else return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; #endif } HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { #ifdef HWY_DISABLE_BMI2_FMA return add - mul * x; #else return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; #endif } // Returns mul * x - sub HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { #ifdef HWY_DISABLE_BMI2_FMA return mul * x - sub; #else return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; #endif } HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { #ifdef HWY_DISABLE_BMI2_FMA return mul * x - sub; #else return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; #endif } // Returns -mul * x - sub HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { #ifdef HWY_DISABLE_BMI2_FMA return Neg(mul * x) - sub; #else return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; #endif } HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { #ifdef HWY_DISABLE_BMI2_FMA return Neg(mul * x) - sub; #else return Vec256{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)}; #endif } // ------------------------------ Floating-point square root // Full precision square root HWY_API Vec256 Sqrt(const Vec256 v) { return Vec256{_mm256_sqrt_ps(v.raw)}; } HWY_API Vec256 Sqrt(const Vec256 v) { return Vec256{_mm256_sqrt_pd(v.raw)}; } // Approximate reciprocal square root HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { return Vec256{_mm256_rsqrt_ps(v.raw)}; } // ------------------------------ Floating-point rounding // Toward nearest integer, tie to even HWY_API Vec256 Round(const Vec256 v) { return Vec256{ _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; } HWY_API Vec256 Round(const Vec256 v) { return Vec256{ _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; } // Toward zero, aka truncate HWY_API Vec256 Trunc(const Vec256 v) { return Vec256{ _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; } HWY_API Vec256 Trunc(const Vec256 v) { return Vec256{ _mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; } // Toward +infinity, aka ceiling HWY_API Vec256 Ceil(const Vec256 v) { return Vec256{ _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; } HWY_API Vec256 Ceil(const Vec256 v) { return Vec256{ _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; } // Toward -infinity, aka floor 
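// Example (illustrative sketch): the four rounding ops applied to lanes
// holding -1.5f:
//   Round(v) -> -2.0f  (nearest, ties to even)
//   Trunc(v) -> -1.0f  (toward zero)
//   Ceil(v)  -> -1.0f  (toward +infinity)
//   Floor(v) -> -2.0f  (toward -infinity)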
HWY_API Vec256 Floor(const Vec256 v) { return Vec256{ _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; } HWY_API Vec256 Floor(const Vec256 v) { return Vec256{ _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; } // ------------------------------ Floating-point classification HWY_API Mask256 IsNaN(const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return Mask256{_mm256_fpclass_ps_mask(v.raw, 0x81)}; #else return Mask256{_mm256_cmp_ps(v.raw, v.raw, _CMP_UNORD_Q)}; #endif } HWY_API Mask256 IsNaN(const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return Mask256{_mm256_fpclass_pd_mask(v.raw, 0x81)}; #else return Mask256{_mm256_cmp_pd(v.raw, v.raw, _CMP_UNORD_Q)}; #endif } #if HWY_TARGET <= HWY_AVX3 HWY_API Mask256 IsInf(const Vec256 v) { return Mask256{_mm256_fpclass_ps_mask(v.raw, 0x18)}; } HWY_API Mask256 IsInf(const Vec256 v) { return Mask256{_mm256_fpclass_pd_mask(v.raw, 0x18)}; } HWY_API Mask256 IsFinite(const Vec256 v) { // fpclass doesn't have a flag for positive, so we have to check for inf/NaN // and negate the mask. return Not(Mask256{_mm256_fpclass_ps_mask(v.raw, 0x99)}); } HWY_API Mask256 IsFinite(const Vec256 v) { return Not(Mask256{_mm256_fpclass_pd_mask(v.raw, 0x99)}); } #else template HWY_API Mask256 IsInf(const Vec256 v) { static_assert(IsFloat(), "Only for float"); const Full256 d; const RebindToSigned di; const VFromD vi = BitCast(di, v); // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); } // Returns whether normal/subnormal/zero. template HWY_API Mask256 IsFinite(const Vec256 v) { static_assert(IsFloat(), "Only for float"); const Full256 d; const RebindToUnsigned du; const RebindToSigned di; // cheaper than unsigned comparison const VFromD vu = BitCast(du, v); // Shift left to clear the sign bit, then right so we can compare with the // max exponent (cannot compare with MaxExponentTimes2 directly because it is // negative and non-negative floats would be greater). MSVC seems to generate // incorrect code if we instead add vu + vu. 
const VFromD exp = BitCast(di, ShiftRight() + 1>(ShiftLeft<1>(vu))); return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); } #endif // HWY_TARGET <= HWY_AVX3 // ================================================== MEMORY // ------------------------------ Load template HWY_API Vec256 Load(Full256 /* tag */, const T* HWY_RESTRICT aligned) { return Vec256{ _mm256_load_si256(reinterpret_cast(aligned))}; } HWY_API Vec256 Load(Full256 /* tag */, const float* HWY_RESTRICT aligned) { return Vec256{_mm256_load_ps(aligned)}; } HWY_API Vec256 Load(Full256 /* tag */, const double* HWY_RESTRICT aligned) { return Vec256{_mm256_load_pd(aligned)}; } template HWY_API Vec256 LoadU(Full256 /* tag */, const T* HWY_RESTRICT p) { return Vec256{_mm256_loadu_si256(reinterpret_cast(p))}; } HWY_API Vec256 LoadU(Full256 /* tag */, const float* HWY_RESTRICT p) { return Vec256{_mm256_loadu_ps(p)}; } HWY_API Vec256 LoadU(Full256 /* tag */, const double* HWY_RESTRICT p) { return Vec256{_mm256_loadu_pd(p)}; } // ------------------------------ MaskedLoad #if HWY_TARGET <= HWY_AVX3 template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT p) { return Vec256{_mm256_maskz_loadu_epi8(m.raw, p)}; } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT p) { return Vec256{_mm256_maskz_loadu_epi16(m.raw, p)}; } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT p) { return Vec256{_mm256_maskz_loadu_epi32(m.raw, p)}; } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT p) { return Vec256{_mm256_maskz_loadu_epi64(m.raw, p)}; } HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const float* HWY_RESTRICT p) { return Vec256{_mm256_maskz_loadu_ps(m.raw, p)}; } HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const double* HWY_RESTRICT p) { return Vec256{_mm256_maskz_loadu_pd(m.raw, p)}; } #else // AVX2 // There is no maskload_epi8/16, so blend instead. template * = nullptr> HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, const T* HWY_RESTRICT p) { return IfThenElseZero(m, LoadU(d, p)); } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT p) { auto pi = reinterpret_cast(p); // NOLINT return Vec256{_mm256_maskload_epi32(pi, m.raw)}; } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT p) { auto pi = reinterpret_cast(p); // NOLINT return Vec256{_mm256_maskload_epi64(pi, m.raw)}; } HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, const float* HWY_RESTRICT p) { const Vec256 mi = BitCast(RebindToSigned(), VecFromMask(d, m)); return Vec256{_mm256_maskload_ps(p, mi.raw)}; } HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, const double* HWY_RESTRICT p) { const Vec256 mi = BitCast(RebindToSigned(), VecFromMask(d, m)); return Vec256{_mm256_maskload_pd(p, mi.raw)}; } #endif // ------------------------------ LoadDup128 // Loads 128 bit and duplicates into both 128-bit halves. This avoids the // 3-cycle cost of moving data between 128-bit halves and avoids port 5. template HWY_API Vec256 LoadDup128(Full256 /* tag */, const T* HWY_RESTRICT p) { #if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 // Workaround for incorrect results with _mm256_broadcastsi128_si256. Note // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the // upper half undefined) is fine because we're overwriting that anyway. // This workaround seems in turn to generate incorrect code in MSVC 2022 // (19.31), so use broadcastsi128 there. 
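// The two-instruction fallback below is equivalent to the broadcast: the cast
// places the 128-bit load in the lower half (upper half undefined) and the
// insert then overwrites the upper half with the same 128 bits.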
const __m128i v128 = LoadU(Full128(), p).raw; return Vec256{ _mm256_inserti128_si256(_mm256_castsi128_si256(v128), v128, 1)}; #else return Vec256{_mm256_broadcastsi128_si256(LoadU(Full128(), p).raw)}; #endif } HWY_API Vec256 LoadDup128(Full256 /* tag */, const float* const HWY_RESTRICT p) { #if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 const __m128 v128 = LoadU(Full128(), p).raw; return Vec256{ _mm256_insertf128_ps(_mm256_castps128_ps256(v128), v128, 1)}; #else return Vec256{_mm256_broadcast_ps(reinterpret_cast(p))}; #endif } HWY_API Vec256 LoadDup128(Full256 /* tag */, const double* const HWY_RESTRICT p) { #if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 const __m128d v128 = LoadU(Full128(), p).raw; return Vec256{ _mm256_insertf128_pd(_mm256_castpd128_pd256(v128), v128, 1)}; #else return Vec256{ _mm256_broadcast_pd(reinterpret_cast(p))}; #endif } // ------------------------------ Store template HWY_API void Store(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); } HWY_API void Store(const Vec256 v, Full256 /* tag */, float* HWY_RESTRICT aligned) { _mm256_store_ps(aligned, v.raw); } HWY_API void Store(const Vec256 v, Full256 /* tag */, double* HWY_RESTRICT aligned) { _mm256_store_pd(aligned, v.raw); } template HWY_API void StoreU(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT p) { _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); } HWY_API void StoreU(const Vec256 v, Full256 /* tag */, float* HWY_RESTRICT p) { _mm256_storeu_ps(p, v.raw); } HWY_API void StoreU(const Vec256 v, Full256 /* tag */, double* HWY_RESTRICT p) { _mm256_storeu_pd(p, v.raw); } // ------------------------------ BlendedStore #if HWY_TARGET <= HWY_AVX3 template HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, T* HWY_RESTRICT p) { _mm256_mask_storeu_epi8(p, m.raw, v.raw); } template HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, T* HWY_RESTRICT p) { _mm256_mask_storeu_epi16(p, m.raw, v.raw); } template HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, T* HWY_RESTRICT p) { _mm256_mask_storeu_epi32(p, m.raw, v.raw); } template HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, T* HWY_RESTRICT p) { _mm256_mask_storeu_epi64(p, m.raw, v.raw); } HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, float* HWY_RESTRICT p) { _mm256_mask_storeu_ps(p, m.raw, v.raw); } HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, double* HWY_RESTRICT p) { _mm256_mask_storeu_pd(p, m.raw, v.raw); } #else // AVX2 // Intel SDM says "No AC# reported for any mask bit combinations". However, AMD // allows AC# if "Alignment checking enabled and: 256-bit memory operand not // 32-byte aligned". Fortunately AC# is not enabled by default and requires both // OS support (CR0) and the application to set rflags.AC. We assume these remain // disabled because x86/x64 code and compiler output often contain misaligned // scalar accesses, which would also fault. // // Caveat: these are slow on AMD Jaguar/Bulldozer. template * = nullptr> HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT p) { // There is no maskload_epi8/16. Blending is also unsafe because loading a // full vector that crosses the array end causes asan faults. Resort to scalar // code; the caller should instead use memcpy, assuming m is FirstN(d, n). 
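// Illustrative caller-side alternative (assuming m is FirstN(d, n), with n
// the caller's lane count):
//   alignas(32) T buf[32 / sizeof(T)];
//   Store(v, d, buf);
//   memcpy(p, buf, n * sizeof(T));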
const RebindToUnsigned du; using TU = TFromD; alignas(32) TU buf[32 / sizeof(T)]; alignas(32) TU mask[32 / sizeof(T)]; Store(BitCast(du, v), du, buf); Store(BitCast(du, VecFromMask(d, m)), du, mask); for (size_t i = 0; i < 32 / sizeof(T); ++i) { if (mask[i]) { CopySameSize(buf + i, p + i); } } } template HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, T* HWY_RESTRICT p) { auto pi = reinterpret_cast(p); // NOLINT _mm256_maskstore_epi32(pi, m.raw, v.raw); } template HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, T* HWY_RESTRICT p) { auto pi = reinterpret_cast(p); // NOLINT _mm256_maskstore_epi64(pi, m.raw, v.raw); } HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, float* HWY_RESTRICT p) { const Vec256 mi = BitCast(RebindToSigned(), VecFromMask(d, m)); _mm256_maskstore_ps(p, mi.raw, v.raw); } HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, double* HWY_RESTRICT p) { const Vec256 mi = BitCast(RebindToSigned(), VecFromMask(d, m)); _mm256_maskstore_pd(p, mi.raw, v.raw); } #endif // ------------------------------ Non-temporal stores template HWY_API void Stream(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), v.raw); } HWY_API void Stream(const Vec256 v, Full256 /* tag */, float* HWY_RESTRICT aligned) { _mm256_stream_ps(aligned, v.raw); } HWY_API void Stream(const Vec256 v, Full256 /* tag */, double* HWY_RESTRICT aligned) { _mm256_stream_pd(aligned, v.raw); } // ------------------------------ Scatter // Work around warnings in the intrinsic definitions (passing -1 as a mask). HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") #if HWY_TARGET <= HWY_AVX3 namespace detail { template HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec256 v, Full256 /* tag */, T* HWY_RESTRICT base, const Vec256 offset) { _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); } template HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec256 v, Full256 /* tag */, T* HWY_RESTRICT base, const Vec256 index) { _mm256_i32scatter_epi32(base, index.raw, v.raw, 4); } template HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec256 v, Full256 /* tag */, T* HWY_RESTRICT base, const Vec256 offset) { _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); } template HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec256 v, Full256 /* tag */, T* HWY_RESTRICT base, const Vec256 index) { _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); } } // namespace detail template HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); } template HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); } HWY_API void ScatterOffset(Vec256 v, Full256 /* tag */, float* HWY_RESTRICT base, const Vec256 offset) { _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); } HWY_API void ScatterIndex(Vec256 v, Full256 /* tag */, float* HWY_RESTRICT base, const Vec256 index) { _mm256_i32scatter_ps(base, index.raw, v.raw, 4); } HWY_API void ScatterOffset(Vec256 v, Full256 /* tag */, double* HWY_RESTRICT base, const Vec256 offset) { _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); } HWY_API void ScatterIndex(Vec256 v, Full256 /* tag */, 
double* HWY_RESTRICT base, const Vec256 index) { _mm256_i64scatter_pd(base, index.raw, v.raw, 8); } #else template HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); constexpr size_t N = 32 / sizeof(T); alignas(32) T lanes[N]; Store(v, d, lanes); alignas(32) Offset offset_lanes[N]; Store(offset, Full256(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); } } template HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); constexpr size_t N = 32 / sizeof(T); alignas(32) T lanes[N]; Store(v, d, lanes); alignas(32) Index index_lanes[N]; Store(index, Full256(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; } } #endif // ------------------------------ Gather namespace detail { template HWY_INLINE Vec256 GatherOffset(hwy::SizeTag<4> /* tag */, Full256 /* tag */, const T* HWY_RESTRICT base, const Vec256 offset) { return Vec256{_mm256_i32gather_epi32( reinterpret_cast(base), offset.raw, 1)}; } template HWY_INLINE Vec256 GatherIndex(hwy::SizeTag<4> /* tag */, Full256 /* tag */, const T* HWY_RESTRICT base, const Vec256 index) { return Vec256{_mm256_i32gather_epi32( reinterpret_cast(base), index.raw, 4)}; } template HWY_INLINE Vec256 GatherOffset(hwy::SizeTag<8> /* tag */, Full256 /* tag */, const T* HWY_RESTRICT base, const Vec256 offset) { return Vec256{_mm256_i64gather_epi64( reinterpret_cast(base), offset.raw, 1)}; } template HWY_INLINE Vec256 GatherIndex(hwy::SizeTag<8> /* tag */, Full256 /* tag */, const T* HWY_RESTRICT base, const Vec256 index) { return Vec256{_mm256_i64gather_epi64( reinterpret_cast(base), index.raw, 8)}; } } // namespace detail template HWY_API Vec256 GatherOffset(Full256 d, const T* HWY_RESTRICT base, const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::GatherOffset(hwy::SizeTag(), d, base, offset); } template HWY_API Vec256 GatherIndex(Full256 d, const T* HWY_RESTRICT base, const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::GatherIndex(hwy::SizeTag(), d, base, index); } HWY_API Vec256 GatherOffset(Full256 /* tag */, const float* HWY_RESTRICT base, const Vec256 offset) { return Vec256{_mm256_i32gather_ps(base, offset.raw, 1)}; } HWY_API Vec256 GatherIndex(Full256 /* tag */, const float* HWY_RESTRICT base, const Vec256 index) { return Vec256{_mm256_i32gather_ps(base, index.raw, 4)}; } HWY_API Vec256 GatherOffset(Full256 /* tag */, const double* HWY_RESTRICT base, const Vec256 offset) { return Vec256{_mm256_i64gather_pd(base, offset.raw, 1)}; } HWY_API Vec256 GatherIndex(Full256 /* tag */, const double* HWY_RESTRICT base, const Vec256 index) { return Vec256{_mm256_i64gather_pd(base, index.raw, 8)}; } HWY_DIAGNOSTICS(pop) // ================================================== SWIZZLE // ------------------------------ LowerHalf template HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_castsi256_si128(v.raw)}; } HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_castps256_ps128(v.raw)}; } HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_castpd256_pd128(v.raw)}; } template HWY_API Vec128 LowerHalf(Vec256 v) { return 
LowerHalf(Full128(), v); } // ------------------------------ UpperHalf template HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_extracti128_si256(v.raw, 1)}; } HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_extractf128_ps(v.raw, 1)}; } HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_extractf128_pd(v.raw, 1)}; } // ------------------------------ ExtractLane (Store) template HWY_API T ExtractLane(const Vec256 v, size_t i) { const Full256 d; HWY_DASSERT(i < Lanes(d)); alignas(32) T lanes[32 / sizeof(T)]; Store(v, d, lanes); return lanes[i]; } // ------------------------------ InsertLane (Store) template HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { const Full256 d; HWY_DASSERT(i < Lanes(d)); alignas(64) T lanes[64 / sizeof(T)]; Store(v, d, lanes); lanes[i] = t; return Load(d, lanes); } // ------------------------------ GetLane (LowerHalf) template HWY_API T GetLane(const Vec256 v) { return GetLane(LowerHalf(v)); } // ------------------------------ ZeroExtendVector // Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper // bits undefined. Although it makes sense for them to be zero (VEX encoded // 128-bit instructions zero the upper lanes to avoid large penalties), a // compiler could decide to optimize out code that relies on this. // // The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the // zeroing, but it is not available on MSVC until 15.7 nor GCC until 10.1. For // older GCC, we can still obtain the desired code thanks to pattern // recognition; note that the expensive insert instruction is not actually // generated, see https://gcc.godbolt.org/z/1MKGaP. #if !defined(HWY_HAVE_ZEXT) #if (HWY_COMPILER_MSVC && HWY_COMPILER_MSVC >= 1915) || \ (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 500) || \ (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1000) #define HWY_HAVE_ZEXT 1 #else #define HWY_HAVE_ZEXT 0 #endif #endif // defined(HWY_HAVE_ZEXT) template HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, Vec128 lo) { #if HWY_HAVE_ZEXT return Vec256{_mm256_zextsi128_si256(lo.raw)}; #else return Vec256{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; #endif } HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, Vec128 lo) { #if HWY_HAVE_ZEXT return Vec256{_mm256_zextps128_ps256(lo.raw)}; #else return Vec256{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)}; #endif } HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, Vec128 lo) { #if HWY_HAVE_ZEXT return Vec256{_mm256_zextpd128_pd256(lo.raw)}; #else return Vec256{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)}; #endif } // ------------------------------ Combine template HWY_API Vec256 Combine(Full256 d, Vec128 hi, Vec128 lo) { const auto lo256 = ZeroExtendVector(d, lo); return Vec256{_mm256_inserti128_si256(lo256.raw, hi.raw, 1)}; } HWY_API Vec256 Combine(Full256 d, Vec128 hi, Vec128 lo) { const auto lo256 = ZeroExtendVector(d, lo); return Vec256{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)}; } HWY_API Vec256 Combine(Full256 d, Vec128 hi, Vec128 lo) { const auto lo256 = ZeroExtendVector(d, lo); return Vec256{_mm256_insertf128_pd(lo256.raw, hi.raw, 1)}; } // ------------------------------ ShiftLeftBytes template HWY_API Vec256 ShiftLeftBytes(Full256 /* tag */, const Vec256 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); // This is the same operation as _mm256_bslli_epi128. 
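// Note the shift is per 128-bit block: e.g. for kBytes = 1, byte 15 does not
// move into byte 16; each block's top byte is discarded and its lowest byte
// becomes zero.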
return Vec256{_mm256_slli_si256(v.raw, kBytes)}; } template HWY_API Vec256 ShiftLeftBytes(const Vec256 v) { return ShiftLeftBytes(Full256(), v); } // ------------------------------ ShiftLeftLanes template HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { return ShiftLeftLanes(Full256(), v); } // ------------------------------ ShiftRightBytes template HWY_API Vec256 ShiftRightBytes(Full256 /* tag */, const Vec256 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); // This is the same operation as _mm256_bsrli_epi128. return Vec256{_mm256_srli_si256(v.raw, kBytes)}; } // ------------------------------ ShiftRightLanes template HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { const Repartition d8; return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ------------------------------ CombineShiftRightBytes // Extracts 128 bits from by skipping the least-significant kBytes. template > HWY_API V CombineShiftRightBytes(Full256 d, V hi, V lo) { const Repartition d8; return BitCast(d, Vec256{_mm256_alignr_epi8( BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); } // ------------------------------ Broadcast/splat any lane // Unsigned template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); if (kLane < 4) { const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); return Vec256{_mm256_unpacklo_epi64(lo, lo)}; } else { const __m256i hi = _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); return Vec256{_mm256_unpackhi_epi64(hi, hi)}; } } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; } // Signed template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); if (kLane < 4) { const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); return Vec256{_mm256_unpacklo_epi64(lo, lo)}; } else { const __m256i hi = _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); return Vec256{_mm256_unpackhi_epi64(hi, hi)}; } } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; } // Float template HWY_API Vec256 Broadcast(Vec256 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)}; } // ------------------------------ Hard-coded shuffles // Notation: let Vec256 have lanes 7,6,5,4,3,2,1,0 (0 is // least-significant). Shuffle0321 rotates four-lane blocks one lane to the // right (the previous least-significant lane is now most-significant => // 47650321). These could also be implemented via CombineShiftRightBytes but // the shuffle_abcd notation is more convenient. // Swap 32-bit halves in 64-bit halves. 
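// e.g. for 32-bit lanes, 7,6,5,4,3,2,1,0 becomes 6,7,4,5,2,3,0,1.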
template HWY_API Vec256 Shuffle2301(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; } HWY_API Vec256 Shuffle2301(const Vec256 v) { return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)}; } // Used by generic_ops-inl.h namespace detail { template HWY_API Vec256 Shuffle2301(const Vec256 a, const Vec256 b) { const Full256 d; const RebindToFloat df; constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, m)}); } template HWY_API Vec256 Shuffle1230(const Vec256 a, const Vec256 b) { const Full256 d; const RebindToFloat df; constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, m)}); } template HWY_API Vec256 Shuffle3012(const Vec256 a, const Vec256 b) { const Full256 d; const RebindToFloat df; constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, m)}); } } // namespace detail // Swap 64-bit halves HWY_API Vec256 Shuffle1032(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; } HWY_API Vec256 Shuffle1032(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; } HWY_API Vec256 Shuffle1032(const Vec256 v) { // Shorter encoding than _mm256_permute_ps. return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)}; } HWY_API Vec256 Shuffle01(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; } HWY_API Vec256 Shuffle01(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; } HWY_API Vec256 Shuffle01(const Vec256 v) { // Shorter encoding than _mm256_permute_pd. return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 5)}; } // Rotate right 32 bits HWY_API Vec256 Shuffle0321(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; } HWY_API Vec256 Shuffle0321(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; } HWY_API Vec256 Shuffle0321(const Vec256 v) { return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x39)}; } // Rotate left 32 bits HWY_API Vec256 Shuffle2103(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; } HWY_API Vec256 Shuffle2103(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; } HWY_API Vec256 Shuffle2103(const Vec256 v) { return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x93)}; } // Reverse HWY_API Vec256 Shuffle0123(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; } HWY_API Vec256 Shuffle0123(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; } HWY_API Vec256 Shuffle0123(const Vec256 v) { return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)}; } // ------------------------------ TableLookupLanes // Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
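// Illustrative usage for u32 lanes (v, d and kIdx are the caller's names,
// not part of this header):
//   const Full256<uint32_t> d;
//   alignas(32) constexpr int32_t kIdx[8] = {7, 6, 5, 4, 3, 2, 1, 0};
//   const auto reversed = TableLookupLanes(v, SetTableIndices(d, kIdx));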
template struct Indices256 { __m256i raw; }; // Native 8x32 instruction: indices remain unchanged template HWY_API Indices256 IndicesFromVec(Full256 /* tag */, Vec256 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD const Full256 di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, static_cast(32 / sizeof(T)))))); #endif return Indices256{vec.raw}; } // 64-bit lanes: convert indices to 8x32 unless AVX3 is available template HWY_API Indices256 IndicesFromVec(Full256 d, Vec256 idx64) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); const Rebind di; (void)di; // potentially unused #if HWY_IS_DEBUG_BUILD HWY_DASSERT(AllFalse(di, Lt(idx64, Zero(di))) && AllTrue(di, Lt(idx64, Set(di, static_cast(32 / sizeof(T)))))); #endif #if HWY_TARGET <= HWY_AVX3 (void)d; return Indices256{idx64.raw}; #else const Repartition df; // 32-bit! // Replicate 64-bit index into upper 32 bits const Vec256 dup = BitCast(di, Vec256{_mm256_moveldup_ps(BitCast(df, idx64).raw)}); // For each idx64 i, idx32 are 2*i and 2*i+1. const Vec256 idx32 = dup + dup + Set(di, TI(1) << 32); return Indices256{idx32.raw}; #endif } template HWY_API Indices256 SetTableIndices(const Full256 d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } template HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; } template HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_permutexvar_epi64(idx.raw, v.raw)}; #else return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; #endif } HWY_API Vec256 TableLookupLanes(const Vec256 v, const Indices256 idx) { return Vec256{_mm256_permutevar8x32_ps(v.raw, idx.raw)}; } HWY_API Vec256 TableLookupLanes(const Vec256 v, const Indices256 idx) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_permutexvar_pd(idx.raw, v.raw)}; #else const Full256 df; const Full256 du; return BitCast(df, Vec256{_mm256_permutevar8x32_epi32( BitCast(du, v).raw, idx.raw)}); #endif } // ------------------------------ SwapAdjacentBlocks template HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { return Vec256{_mm256_permute2x128_si256(v.raw, v.raw, 0x01)}; } HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { return Vec256{_mm256_permute2f128_ps(v.raw, v.raw, 0x01)}; } HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { return Vec256{_mm256_permute2f128_pd(v.raw, v.raw, 0x01)}; } // ------------------------------ Reverse (RotateRight) template HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { alignas(32) constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; return TableLookupLanes(v, SetTableIndices(d, kReverse)); } template HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { alignas(32) constexpr int64_t kReverse[4] = {3, 2, 1, 0}; return TableLookupLanes(v, SetTableIndices(d, kReverse)); } template HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 const RebindToSigned di; alignas(32) constexpr int16_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; const Vec256 idx = Load(di, kReverse); return BitCast(d, Vec256{ _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); #else const RepartitionToWide> du32; const Vec256 rev32 = Reverse(du32, BitCast(du32, v)); return BitCast(d, RotateRight<16>(rev32)); #endif } // ------------------------------ Reverse2 template HWY_API Vec256 Reverse2(Full256 d, const Vec256 v) { const Full256 du32; return BitCast(d, 
RotateRight<16>(BitCast(du32, v))); } template HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { return Shuffle2301(v); } template HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { return Shuffle01(v); } // ------------------------------ Reverse4 (SwapAdjacentBlocks) template HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 const RebindToSigned di; alignas(32) constexpr int16_t kReverse4[16] = {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12}; const Vec256 idx = Load(di, kReverse4); return BitCast(d, Vec256{ _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); #else const RepartitionToWide dw; return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); #endif } template HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { return Shuffle0123(v); } template HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { // Could also use _mm256_permute4x64_epi64. return SwapAdjacentBlocks(Shuffle01(v)); } // ------------------------------ Reverse8 template HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 const RebindToSigned di; alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; const Vec256 idx = Load(di, kReverse8); return BitCast(d, Vec256{ _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); #else const RepartitionToWide dw; return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); #endif } template HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { return Reverse(d, v); } template HWY_API Vec256 Reverse8(Full256 /* tag */, const Vec256 /* v */) { HWY_ASSERT(0); // AVX2 does not have 8 64-bit lanes } // ------------------------------ InterleaveLower // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides // the least-significant lane) and "b". To concatenate two half-width integers // into one, use ZipLower/Upper instead (also works with scalar). HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_ps(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_pd(a.raw, b.raw)}; } // ------------------------------ InterleaveUpper // All functions inside detail lack the required D parameter. 
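// The public InterleaveUpper(d, a, b) defined right after this namespace
// supplies it, e.g. InterleaveUpper(Full256<uint32_t>(), a, b).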
namespace detail { HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_ps(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_pd(a.raw, b.raw)}; } } // namespace detail template > HWY_API V InterleaveUpper(Full256 /* tag */, V a, V b) { return detail::InterleaveUpper(a, b); } // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. template > HWY_API Vec256 ZipLower(Vec256 a, Vec256 b) { return BitCast(Full256(), InterleaveLower(a, b)); } template > HWY_API Vec256 ZipLower(Full256 dw, Vec256 a, Vec256 b) { return BitCast(dw, InterleaveLower(a, b)); } template > HWY_API Vec256 ZipUpper(Full256 dw, Vec256 a, Vec256 b) { return BitCast(dw, InterleaveUpper(Full256(), a, b)); } // ------------------------------ Blocks (LowerHalf, ZeroExtendVector) // _mm256_broadcastsi128_si256 has 7 cycle latency on ICL. // _mm256_permute2x128_si256 is slow on Zen1 (8 uops), so we avoid it (at no // extra cost) for LowerLower and UpperLower. 
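// Notation for the comments below: each vector is written as its two 128-bit
// blocks, i.e. hi = hiH,hiL and lo = loH,loL (upper block, lower block).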
// hiH,hiL loH,loL |-> hiL,loL (= lower halves) template HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, const Vec256 lo) { const Half d2; return Vec256{_mm256_inserti128_si256(lo.raw, LowerHalf(d2, hi).raw, 1)}; } HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, const Vec256 lo) { const Half d2; return Vec256{_mm256_insertf128_ps(lo.raw, LowerHalf(d2, hi).raw, 1)}; } HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, const Vec256 lo) { const Half d2; return Vec256{_mm256_insertf128_pd(lo.raw, LowerHalf(d2, hi).raw, 1)}; } // hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) template HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x21)}; } HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)}; } HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)}; } // hiH,hiL loH,loL |-> hiH,loL (= outer halves) template HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_blend_epi32(hi.raw, lo.raw, 0x0F)}; } HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)}; } HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_blend_pd(hi.raw, lo.raw, 3)}; } // hiH,hiL loH,loL |-> hiH,loH (= upper halves) template HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x31)}; } HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x31)}; } HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x31)}; } // ------------------------------ ConcatOdd template HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET == HWY_AVX3_DL alignas(32) constexpr uint8_t kIdx[32] = { 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi8( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); #else const RepartitionToWide dw; // Unsigned 8-bit shift so we can pack. const Vec256 uH = ShiftRight<8>(BitCast(dw, hi)); const Vec256 uL = ShiftRight<8>(BitCast(dw, lo)); const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); return Vec256{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } template HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(32) constexpr uint16_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi16( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, BitCast(du, hi).raw)}); #else const RepartitionToWide dw; // Unsigned 16-bit shift so we can pack. 
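// Each u32 holds the even u16 lane in its lower half and the odd lane in its
// upper half; the shift moves the odd lane into the lower 16 bits with zero
// above, so packus_epi32 keeps exactly the odd lanes without saturating.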
const Vec256 uH = ShiftRight<16>(BitCast(dw, hi)); const Vec256 uL = ShiftRight<16>(BitCast(dw, lo)); const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); return Vec256{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } template HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi32( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); #else const RebindToFloat df; const Vec256 v3131{_mm256_shuffle_ps( BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(3, 1, 3, 1))}; return Vec256{_mm256_permute4x64_epi64(BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; return Vec256{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; #else const Vec256 v3131{ _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; return BitCast(d, Vec256{_mm256_permute4x64_epi64( BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))}); #endif } template HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi64( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); #else const RebindToFloat df; const Vec256 v31{ _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 15)}; return Vec256{ _mm256_permute4x64_epi64(BitCast(du, v31).raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { #if HWY_TARGET <= HWY_AVX3 const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; return Vec256{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; #else (void)d; const Vec256 v31{_mm256_shuffle_pd(lo.raw, hi.raw, 15)}; return Vec256{ _mm256_permute4x64_pd(v31.raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } // ------------------------------ ConcatEven template HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET == HWY_AVX3_DL alignas(64) constexpr uint8_t kIdx[32] = { 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi8( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); #else const RepartitionToWide dw; // Isolate lower 8 bits per u16 so we can pack. const Vec256 mask = Set(dw, 0x00FF); const Vec256 uH = And(BitCast(dw, hi), mask); const Vec256 uL = And(BitCast(dw, lo), mask); const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); return Vec256{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } template HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint16_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi16( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, BitCast(du, hi).raw)}); #else const RepartitionToWide dw; // Isolate lower 16 bits per u32 so we can pack. 
const Vec256 mask = Set(dw, 0x0000FFFF); const Vec256 uH = And(BitCast(dw, hi), mask); const Vec256 uL = And(BitCast(dw, lo), mask); const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); return Vec256{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } template HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi32( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); #else const RebindToFloat df; const Vec256 v2020{_mm256_shuffle_ps( BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(2, 0, 2, 0))}; return Vec256{_mm256_permute4x64_epi64(BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; return Vec256{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; #else const Vec256 v2020{ _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; return BitCast(d, Vec256{_mm256_permute4x64_epi64( BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); #endif } template HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi64( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); #else const RebindToFloat df; const Vec256 v20{ _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; return Vec256{ _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { #if HWY_TARGET <= HWY_AVX3 const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; return Vec256{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; #else (void)d; const Vec256 v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; return Vec256{ _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } // ------------------------------ DupEven (InterleaveLower) template HWY_API Vec256 DupEven(Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; } HWY_API Vec256 DupEven(Vec256 v) { return Vec256{ _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; } template HWY_API Vec256 DupEven(const Vec256 v) { return InterleaveLower(Full256(), v, v); } // ------------------------------ DupOdd (InterleaveUpper) template HWY_API Vec256 DupOdd(Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; } HWY_API Vec256 DupOdd(Vec256 v) { return Vec256{ _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; } template HWY_API Vec256 DupOdd(const Vec256 v) { return InterleaveUpper(Full256(), v, v); } // ------------------------------ OddEven namespace detail { template HWY_INLINE Vec256 OddEven(hwy::SizeTag<1> /* tag */, const Vec256 a, const Vec256 b) { const Full256 d; const Full256 d8; alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; return IfThenElse(MaskFromVec(BitCast(d, LoadDup128(d8, mask))), b, a); } template HWY_INLINE Vec256 OddEven(hwy::SizeTag<2> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_epi16(a.raw, b.raw, 0x55)}; 
} template HWY_INLINE Vec256 OddEven(hwy::SizeTag<4> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x55)}; } template HWY_INLINE Vec256 OddEven(hwy::SizeTag<8> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x33)}; } } // namespace detail template HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return detail::OddEven(hwy::SizeTag(), a, b); } HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_ps(a.raw, b.raw, 0x55)}; } HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_pd(a.raw, b.raw, 5)}; } // ------------------------------ OddEvenBlocks template Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { return Vec256{_mm256_blend_epi32(odd.raw, even.raw, 0xFu)}; } HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { return Vec256{_mm256_blend_ps(odd.raw, even.raw, 0xFu)}; } HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { return Vec256{_mm256_blend_pd(odd.raw, even.raw, 0x3u)}; } // ------------------------------ ReverseBlocks (ConcatLowerUpper) template HWY_API Vec256 ReverseBlocks(Full256 d, Vec256 v) { return ConcatLowerUpper(d, v, v); } // ------------------------------ TableLookupBytes (ZeroExtendVector) // Both full template HWY_API Vec256 TableLookupBytes(const Vec256 bytes, const Vec256 from) { return Vec256{_mm256_shuffle_epi8(bytes.raw, from.raw)}; } // Partial index vector template HWY_API Vec128 TableLookupBytes(const Vec256 bytes, const Vec128 from) { // First expand to full 128, then 256. const auto from_256 = ZeroExtendVector(Full256(), Vec128{from.raw}); const auto tbl_full = TableLookupBytes(bytes, from_256); // Shrink to 128, then partial. return Vec128{LowerHalf(Full128(), tbl_full).raw}; } // Partial table vector template HWY_API Vec256 TableLookupBytes(const Vec128 bytes, const Vec256 from) { // First expand to full 128, then 256. const auto bytes_256 = ZeroExtendVector(Full256(), Vec128{bytes.raw}); return TableLookupBytes(bytes_256, from); } // Partial both are handled by x86_128. // ------------------------------ Shl (Mul, ZipLower) namespace detail { #if HWY_TARGET > HWY_AVX3 && !HWY_IDE // AVX2 or older // Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. template HWY_INLINE Vec256> Pow2(const Vec256 v) { static_assert(sizeof(T) == 2, "Only for 16-bit"); const Full256 d; const RepartitionToWide dw; const Rebind df; const auto zero = Zero(d); // Move into exponent (this u16 will become the upper half of an f32) const auto exp = ShiftLeft<23 - 16>(v); const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f // Insert 0 into lower halves for reinterpreting as binary32. const auto f0 = ZipLower(dw, zero, upper); const auto f1 = ZipUpper(dw, zero, upper); // Do not use ConvertTo because it checks for overflow, which is redundant // because we only care about v in [0, 16). 
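// Illustrative walk-through: for v = 3, upper = 0x3F80 + (3 << 7) = 0x4100,
// so the assembled binary32 is 0x41000000 = 8.0f, and the truncating
// conversion below yields 2^3 = 8.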
const Vec256 bits0{_mm256_cvttps_epi32(BitCast(df, f0).raw)}; const Vec256 bits1{_mm256_cvttps_epi32(BitCast(df, f1).raw)}; return Vec256>{_mm256_packus_epi32(bits0.raw, bits1.raw)}; } #endif // HWY_TARGET > HWY_AVX3 HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, Vec256 bits) { #if HWY_TARGET <= HWY_AVX3 || HWY_IDE return Vec256{_mm256_sllv_epi16(v.raw, bits.raw)}; #else return v * Pow2(bits); #endif } HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, Vec256 bits) { return Vec256{_mm256_sllv_epi32(v.raw, bits.raw)}; } HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, Vec256 bits) { return Vec256{_mm256_sllv_epi64(v.raw, bits.raw)}; } template HWY_INLINE Vec256 Shl(hwy::SignedTag /*tag*/, Vec256 v, Vec256 bits) { // Signed left shifts are the same as unsigned. const Full256 di; const Full256> du; return BitCast(di, Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); } } // namespace detail template HWY_API Vec256 operator<<(Vec256 v, Vec256 bits) { return detail::Shl(hwy::TypeTag(), v, bits); } // ------------------------------ Shr (MulHigh, IfThenElse, Not) HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { #if HWY_TARGET <= HWY_AVX3 || HWY_IDE return Vec256{_mm256_srlv_epi16(v.raw, bits.raw)}; #else Full256 d; // For bits=0, we cannot mul by 2^16, so fix the result later. auto out = MulHigh(v, detail::Pow2(Set(d, 16) - bits)); // Replace output with input where bits == 0. return IfThenElse(bits == Zero(d), v, out); #endif } HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { return Vec256{_mm256_srlv_epi32(v.raw, bits.raw)}; } HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { return Vec256{_mm256_srlv_epi64(v.raw, bits.raw)}; } HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_srav_epi16(v.raw, bits.raw)}; #else return detail::SignedShr(Full256(), v, bits); #endif } HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { return Vec256{_mm256_srav_epi32(v.raw, bits.raw)}; } HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_srav_epi64(v.raw, bits.raw)}; #else return detail::SignedShr(Full256(), v, bits); #endif } HWY_INLINE Vec256 MulEven(const Vec256 a, const Vec256 b) { const Full256 du64; const RepartitionToNarrow du32; const auto maskL = Set(du64, 0xFFFFFFFFULL); const auto a32 = BitCast(du32, a); const auto b32 = BitCast(du32, b); // Inputs for MulEven: we only need the lower 32 bits const auto aH = Shuffle2301(a32); const auto bH = Shuffle2301(b32); // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need // the even (lower 64 bits of every 128-bit block) results. 
See // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat const auto aLbL = MulEven(a32, b32); const auto w3 = aLbL & maskL; const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); const auto w2 = t2 & maskL; const auto w1 = ShiftRight<32>(t2); const auto t = MulEven(a32, bH) + w2; const auto k = ShiftRight<32>(t); const auto mulH = MulEven(aH, bH) + w1 + k; const auto mulL = ShiftLeft<32>(t) + w3; return InterleaveLower(mulL, mulH); } HWY_INLINE Vec256 MulOdd(const Vec256 a, const Vec256 b) { const Full256 du64; const RepartitionToNarrow du32; const auto maskL = Set(du64, 0xFFFFFFFFULL); const auto a32 = BitCast(du32, a); const auto b32 = BitCast(du32, b); // Inputs for MulEven: we only need bits [95:64] (= upper half of input) const auto aH = Shuffle2301(a32); const auto bH = Shuffle2301(b32); // Same as above, but we're using the odd results (upper 64 bits per block). const auto aLbL = MulEven(a32, b32); const auto w3 = aLbL & maskL; const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); const auto w2 = t2 & maskL; const auto w1 = ShiftRight<32>(t2); const auto t = MulEven(a32, bH) + w2; const auto k = ShiftRight<32>(t); const auto mulH = MulEven(aH, bH) + w1 + k; const auto mulL = ShiftLeft<32>(t) + w3; return InterleaveUpper(du64, mulL, mulH); } // ------------------------------ ReorderWidenMulAccumulate HWY_API Vec256 ReorderWidenMulAccumulate(Full256 /*d32*/, Vec256 a, Vec256 b, const Vec256 sum0, Vec256& /*sum1*/) { return sum0 + Vec256{_mm256_madd_epi16(a.raw, b.raw)}; } // ------------------------------ RearrangeToOddPlusEven HWY_API Vec256 RearrangeToOddPlusEven(const Vec256 sum0, Vec256 /*sum1*/) { return sum0; // invariant already holds } // ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{_mm256_cvtps_pd(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{_mm256_cvtepi32_pd(v.raw)}; } // Unsigned: zero-extend. // Note: these have 3 cycle latency; if inputs are already split across the // 128 bit blocks (in their upper/lower halves), then Zip* would be faster. HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu8_epi16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu8_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu8_epi16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu8_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu16_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu16_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu32_epi64(v.raw)}; } // Signed: replicate sign bit. // Note: these have 3 cycle latency; if inputs are already split across the // 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by // signed shift would be faster. 
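// Sketch of that Zip-based alternative for i16 -> i32 (illustrative, not used
// here), with d32 a RepartitionToWide tag over the i16 descriptor:
//   const auto lo32 = ShiftRight<16>(ZipLower(d32, v, v));
// Each 32-bit lane then holds x in both halves, and the arithmetic shift
// sign-extends the copy already in the lower half; this covers the lanes in
// the lower half of each 128-bit block, per the note above.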
HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepi8_epi16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepi8_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepi16_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepi32_epi64(v.raw)}; } // ------------------------------ Demotions (full -> part w/ narrow lanes) HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw); // Concatenating lower halves of both 128-bit blocks afterward is more // efficient than an extra input with low block = high block of v. return Vec128{ _mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw); return Vec128{ _mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; } HWY_API Vec128 DemoteTo(Full64 /* tag */, const Vec256 v) { const __m256i u16_blocks = _mm256_packus_epi32(v.raw, v.raw); // Concatenate lower 64 bits of each 128-bit block const __m256i u16_concat = _mm256_permute4x64_epi64(u16_blocks, 0x88); const __m128i u16 = _mm256_castsi256_si128(u16_concat); // packus treats the input as signed; we want unsigned. Clear the MSB to get // unsigned saturation to u8. const __m128i i16 = _mm_and_si128(u16, _mm_set1_epi16(0x7FFF)); return Vec128{_mm_packus_epi16(i16, i16)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw); return Vec128{ _mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; } HWY_API Vec128 DemoteTo(Full64 /* tag */, const Vec256 v) { const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); // Concatenate lower 64 bits of each 128-bit block const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); const __m128i i16 = _mm256_castsi256_si128(i16_concat); return Vec128{_mm_packs_epi16(i16, i16)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw); return Vec128{ _mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))}; } // Avoid "value of intrinsic immediate argument '8' is out of range '0 - 7'". // 8 is the correct value of _MM_FROUND_NO_EXC, which is allowed here. 
HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") HWY_API Vec128 DemoteTo(Full128 df16, const Vec256 v) { #ifdef HWY_DISABLE_F16C const RebindToUnsigned du16; const Rebind du; const RebindToSigned di; const auto bits32 = BitCast(du, v); const auto sign = ShiftRight<31>(bits32); const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); const auto k15 = Set(di, 15); const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); const auto is_tiny = exp < Set(di, -24); const auto is_subnormal = exp < Set(di, -14); const auto biased_exp16 = BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + (mantissa32 >> (Set(du, 13) + sub_exp)); const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, ShiftRight<13>(mantissa32)); // <1024 const auto sign16 = ShiftLeft<15>(sign); const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); return BitCast(df16, DemoteTo(du16, bits16)); #else (void)df16; return Vec128{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; #endif } HWY_DIAGNOSTICS(pop) HWY_API Vec128 DemoteTo(Full128 dbf16, const Vec256 v) { // TODO(janwas): _mm256_cvtneps_pbh once we have avx512bf16. const Rebind di32; const Rebind du32; // for logical shift right const Rebind du16; const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); return BitCast(dbf16, DemoteTo(du16, bits_in_32)); } HWY_API Vec256 ReorderDemote2To(Full256 dbf16, Vec256 a, Vec256 b) { // TODO(janwas): _mm256_cvtne2ps_pbh once we have avx512bf16. const RebindToUnsigned du16; const Repartition du32; const Vec256 b_in_even = ShiftRight<16>(BitCast(du32, b)); return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); } HWY_API Vec256 ReorderDemote2To(Full256 /*d16*/, Vec256 a, Vec256 b) { return Vec256{_mm256_packs_epi32(a.raw, b.raw)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { return Vec128{_mm256_cvtpd_ps(v.raw)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const auto clamped = detail::ClampF64ToI32Max(Full256(), v); return Vec128{_mm256_cvttpd_epi32(clamped.raw)}; } // For already range-limited input [0, 255]. HWY_API Vec128 U8FromU32(const Vec256 v) { const Full256 d32; alignas(32) static constexpr uint32_t k8From32[8] = { 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; // Place first four bytes in lo[0], remaining 4 in hi[1]. const auto quad = TableLookupBytes(v, Load(d32, k8From32)); // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. const auto lo = LowerHalf(quad); const auto hi = UpperHalf(Full128(), quad); const auto pair = LowerHalf(lo | hi); return BitCast(Full64(), pair); } // ------------------------------ Truncations namespace detail { // LO and HI each hold four indices of bytes within a 128-bit block. 
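// For example, LO = 0x05040100 encodes byte indices 0,1,4,5 (little-endian
// byte order), i.e. the low halves of the first two u32 lanes of a block.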
template HWY_INLINE Vec128 LookupAndConcatHalves(Vec256 v) { const Full256 d32; #if HWY_TARGET <= HWY_AVX3_DL alignas(32) constexpr uint32_t kMap[8] = { LO, HI, 0x10101010 + LO, 0x10101010 + HI, 0, 0, 0, 0}; const auto result = _mm256_permutexvar_epi8(v.raw, Load(d32, kMap).raw); #else alignas(32) static constexpr uint32_t kMap[8] = {LO, HI, ~0u, ~0u, ~0u, ~0u, LO, HI}; const auto quad = TableLookupBytes(v, Load(d32, kMap)); const auto result = _mm256_permute4x64_epi64(quad.raw, 0xCC); // Possible alternative: // const auto lo = LowerHalf(quad); // const auto hi = UpperHalf(Full128(), quad); // const auto result = lo | hi; #endif return Vec128{_mm256_castsi256_si128(result)}; } // LO and HI each hold two indices of bytes within a 128-bit block. template HWY_INLINE Vec128 LookupAndConcatQuarters(Vec256 v) { const Full256 d16; #if HWY_TARGET <= HWY_AVX3_DL alignas(32) constexpr uint16_t kMap[16] = { LO, HI, 0x1010 + LO, 0x1010 + HI, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; const auto result = _mm256_permutexvar_epi8(v.raw, Load(d16, kMap).raw); return LowerHalf(Vec128{_mm256_castsi256_si128(result)}); #else constexpr uint16_t ff = static_cast(~0u); alignas(32) static constexpr uint16_t kMap[16] = { LO, ff, HI, ff, ff, ff, ff, ff, ff, ff, ff, ff, LO, ff, HI, ff}; const auto quad = TableLookupBytes(v, Load(d16, kMap)); const auto mixed = _mm256_permute4x64_epi64(quad.raw, 0xCC); const auto half = _mm256_castsi256_si128(mixed); return LowerHalf(Vec128{_mm_packus_epi32(half, half)}); #endif } } // namespace detail HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec256 v) { const Full256 d32; #if HWY_TARGET <= HWY_AVX3_DL alignas(32) constexpr uint32_t kMap[8] = {0x18100800u, 0, 0, 0, 0, 0, 0, 0}; const auto result = _mm256_permutexvar_epi8(v.raw, Load(d32, kMap).raw); return LowerHalf(LowerHalf(LowerHalf(Vec256{result}))); #else alignas(32) static constexpr uint32_t kMap[8] = {0xFFFF0800u, ~0u, ~0u, ~0u, 0x0800FFFFu, ~0u, ~0u, ~0u}; const auto quad = TableLookupBytes(v, Load(d32, kMap)); const auto lo = LowerHalf(quad); const auto hi = UpperHalf(Full128(), quad); const auto result = lo | hi; return LowerHalf(LowerHalf(Vec128{result.raw})); #endif } HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec256 v) { const auto result = detail::LookupAndConcatQuarters<0x100, 0x908>(v); return Vec128{result.raw}; } HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec256 v) { const Full256 d32; alignas(32) constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto v32 = TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven)); return LowerHalf(Vec256{v32.raw}); } HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec256 v) { const auto full = detail::LookupAndConcatQuarters<0x400, 0xC08>(v); return Vec128{full.raw}; } HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec256 v) { const auto full = detail::LookupAndConcatHalves<0x05040100, 0x0D0C0908>(v); return Vec128{full.raw}; } HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec256 v) { const auto full = detail::LookupAndConcatHalves<0x06040200, 0x0E0C0A08>(v); return Vec128{full.raw}; } // ------------------------------ Integer <=> fp (ShiftRight, OddEven) HWY_API Vec256 ConvertTo(Full256 /* tag */, const Vec256 v) { return Vec256{_mm256_cvtepi32_ps(v.raw)}; } HWY_API Vec256 ConvertTo(Full256 dd, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 (void)dd; return Vec256{_mm256_cvtepi64_pd(v.raw)}; #else // Based on wim's approach (https://stackoverflow.com/questions/41144668/) const Repartition d32; const Repartition d64; // Toggle MSB of lower 32-bits and 
insert exponent for 2^84 + 2^63 const auto k84_63 = Set(d64, 0x4530000080000000ULL); const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) const auto k52 = Set(d32, 0x43300000); const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); return (v_upper - k84_63_52) + v_lower; // order matters! #endif } HWY_API Vec256 ConvertTo(HWY_MAYBE_UNUSED Full256 df, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_cvtepu32_ps(v.raw)}; #else // Based on wim's approach (https://stackoverflow.com/questions/34066228/) const RebindToUnsigned du32; const RebindToSigned d32; const auto msk_lo = Set(du32, 0xFFFF); const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 // Extract the 16 lowest/highest significant bits of v and cast to signed int const auto v_lo = BitCast(d32, And(v, msk_lo)); const auto v_hi = BitCast(d32, ShiftRight<16>(v)); return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); #endif } HWY_API Vec256 ConvertTo(HWY_MAYBE_UNUSED Full256 dd, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_cvtepu64_pd(v.raw)}; #else // Based on wim's approach (https://stackoverflow.com/questions/41144668/) const RebindToUnsigned d64; using VU = VFromD; const VU msk_lo = Set(d64, 0xFFFFFFFFULL); const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 // Extract the 32 lowest significant bits of v const VU v_lo = And(v, msk_lo); const VU v_hi = ShiftRight<32>(v); auto uint64_to_double256_fast = [&dd](Vec256 w) HWY_ATTR { w = Or(w, Vec256{ detail::BitCastToInteger(Set(dd, 0x0010000000000000).raw)}); return BitCast(dd, w) - Set(dd, 0x0010000000000000); }; const auto v_lo_dbl = uint64_to_double256_fast(v_lo); return MulAdd(cnst2_32_dbl, uint64_to_double256_fast(v_hi), v_lo_dbl); #endif } // Truncates (rounds toward zero). HWY_API Vec256 ConvertTo(Full256 d, const Vec256 v) { return detail::FixConversionOverflow(d, v, _mm256_cvttps_epi32(v.raw)); } HWY_API Vec256 ConvertTo(Full256 di, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return detail::FixConversionOverflow(di, v, _mm256_cvttpd_epi64(v.raw)); #else using VI = decltype(Zero(di)); const VI k0 = Zero(di); const VI k1 = Set(di, 1); const VI k51 = Set(di, 51); // Exponent indicates whether the number can be represented as int64_t. const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); const VI exp = biased_exp - Set(di, 0x3FF); const auto in_range = exp < Set(di, 63); // If we were to cap the exponent at 51 and add 2^52, the number would be in // [2^52, 2^53) and mantissa bits could be read out directly. We need to // round-to-0 (truncate), but changing rounding mode in MXCSR hits a // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead // manually shift the mantissa into place (we already have many of the // inputs anyway). const VI shift_mnt = Max(k51 - exp, k0); const VI shift_int = Max(exp - k51, k0); const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); // For inputs larger than 2^52, insert zeros at the bottom. const VI shifted = int52 << shift_int; // Restore the one bit lost when shifting in the implicit 1-bit. const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. 
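// (limit is LimitsMax for non-negative inputs; for negative inputs the
// all-ones sign mask makes the subtraction wrap limit around to LimitsMin,
// which the conditional negation below maps back to itself.)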
const VI sign_mask = BroadcastSignBit(BitCast(di, v)); const VI limit = Set(di, LimitsMax()) - sign_mask; const VI magnitude = IfThenElse(in_range, restored, limit); // If the input was negative, negate the integer (two's complement). return (magnitude ^ sign_mask) - sign_mask; #endif } HWY_API Vec256 NearestInt(const Vec256 v) { const Full256 di; return detail::FixConversionOverflow(di, v, _mm256_cvtps_epi32(v.raw)); } HWY_API Vec256 PromoteTo(Full256 df32, const Vec128 v) { #ifdef HWY_DISABLE_F16C const RebindToSigned di32; const RebindToUnsigned du32; // Expand to u32 so we can shift. const auto bits16 = PromoteTo(du32, Vec128{v.raw}); const auto sign = ShiftRight<15>(bits16); const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); const auto mantissa = bits16 & Set(du32, 0x3FF); const auto subnormal = BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * Set(df32, 1.0f / 16384 / 1024)); const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); return BitCast(df32, ShiftLeft<31>(sign) | bits32); #else (void)df32; return Vec256{_mm256_cvtph_ps(v.raw)}; #endif } HWY_API Vec256 PromoteTo(Full256 df32, const Vec128 v) { const Rebind du16; const RebindToSigned di32; return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); } // ================================================== CRYPTO #if !defined(HWY_DISABLE_PCLMUL_AES) // Per-target flag to prevent generic_ops-inl.h from defining AESRound. #ifdef HWY_NATIVE_AES #undef HWY_NATIVE_AES #else #define HWY_NATIVE_AES #endif HWY_API Vec256 AESRound(Vec256 state, Vec256 round_key) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_aesenc_epi128(state.raw, round_key.raw)}; #else const Full256 d; const Half d2; return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), AESRound(LowerHalf(state), LowerHalf(round_key))); #endif } HWY_API Vec256 AESLastRound(Vec256 state, Vec256 round_key) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_aesenclast_epi128(state.raw, round_key.raw)}; #else const Full256 d; const Half d2; return Combine(d, AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), AESLastRound(LowerHalf(state), LowerHalf(round_key))); #endif } HWY_API Vec256 CLMulLower(Vec256 a, Vec256 b) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)}; #else const Full256 d; const Half d2; return Combine(d, CLMulLower(UpperHalf(d2, a), UpperHalf(d2, b)), CLMulLower(LowerHalf(a), LowerHalf(b))); #endif } HWY_API Vec256 CLMulUpper(Vec256 a, Vec256 b) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x11)}; #else const Full256 d; const Half d2; return Combine(d, CLMulUpper(UpperHalf(d2, a), UpperHalf(d2, b)), CLMulUpper(LowerHalf(a), LowerHalf(b))); #endif } #endif // HWY_DISABLE_PCLMUL_AES // ================================================== MISC // Returns a vector with lane i=[0, N) set to "first" + i. template HWY_API Vec256 Iota(const Full256 d, const T2 first) { HWY_ALIGN T lanes[32 / sizeof(T)]; for (size_t i = 0; i < 32 / sizeof(T); ++i) { lanes[i] = AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); } return Load(d, lanes); } #if HWY_TARGET <= HWY_AVX3 // ------------------------------ LoadMaskBits // `p` points to at least 8 readable bytes, not all of which need be valid. 
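//
// Bit i of the byte stream corresponds to lane i, least-significant bit
// first; with fewer than 8 lanes only the low N bits of bits[0] are used.
// Round-trip sketch with T = uint64_t, i.e. N = 4 (buffers are illustrative):
//   const Full256<uint64_t> d;
//   const uint8_t in[8] = {0x05};                  // lanes 0 and 2 are true
//   const Mask256<uint64_t> m = LoadMaskBits(d, in);
//   uint8_t out[8];
//   StoreMaskBits(d, m, out);                      // writes one byte: 0x05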
template HWY_API Mask256 LoadMaskBits(const Full256 /* tag */, const uint8_t* HWY_RESTRICT bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; uint64_t mask_bits = 0; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } return Mask256::FromBits(mask_bits); } // ------------------------------ StoreMaskBits // `p` points to at least 8 writable bytes. template HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, uint8_t* bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; CopyBytes(&mask.raw, bits); // Non-full byte, need to clear the undefined upper bits. if (N < 8) { const int mask_bits = static_cast((1ull << N) - 1); bits[0] = static_cast(bits[0] & mask_bits); } return kNumBytes; } // ------------------------------ Mask testing template HWY_API size_t CountTrue(const Full256 /* tag */, const Mask256 mask) { return PopCount(static_cast(mask.raw)); } template HWY_API size_t FindKnownFirstTrue(const Full256 /* tag */, const Mask256 mask) { return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); } template HWY_API intptr_t FindFirstTrue(const Full256 d, const Mask256 mask) { return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) : intptr_t{-1}; } // Beware: the suffix indicates the number of mask bits, not lane size! namespace detail { template HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask32_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask16_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask8_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { return (uint64_t{mask.raw} & 0xF) == 0; } } // namespace detail template HWY_API bool AllFalse(const Full256 /* tag */, const Mask256 mask) { return detail::AllFalse(hwy::SizeTag(), mask); } namespace detail { template HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask32_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFFFFFu; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask16_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFu; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask8_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFu; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { // Cannot use _kortestc because we have less than 8 mask bits. return mask.raw == 0xFu; } } // namespace detail template HWY_API bool AllTrue(const Full256 /* tag */, const Mask256 mask) { return detail::AllTrue(hwy::SizeTag(), mask); } // ------------------------------ Compress // 16-bit is defined in x86_512 so we can use 512-bit vectors. 
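//
// Compress packs the lanes selected by the mask into the lower lanes,
// preserving their relative order; the unselected lanes follow (see
// CompressIsPartition). Sketch with 64-bit lanes (values are illustrative):
//   const Full256<uint64_t> d;
//   const Vec256<uint64_t> v = Iota(d, 0);           // {0, 1, 2, 3}
//   const Mask256<uint64_t> m = Ne(v, Set(d, 1));    // drop lane 1
//   // Compress(v, m) == {0, 2, 3, 1}; CompressStore writes the first 3.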
template HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; } HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { return Vec256{_mm256_maskz_compress_ps(mask.raw, v.raw)}; } template HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { // See CompressIsPartition. alignas(16) constexpr uint64_t packed_array[16] = { // PrintCompress64x4NibbleTables 0x00003210, 0x00003210, 0x00003201, 0x00003210, 0x00003102, 0x00003120, 0x00003021, 0x00003210, 0x00002103, 0x00002130, 0x00002031, 0x00002310, 0x00001032, 0x00001320, 0x00000321, 0x00003210}; // For lane i, shift the i-th 4-bit index down to bits [0, 2) - // _mm256_permutexvar_epi64 will ignore the upper bits. const Full256 d; const RebindToUnsigned du64; const auto packed = Set(du64, packed_array[mask.raw]); alignas(64) constexpr uint64_t shifts[4] = {0, 4, 8, 12}; const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; return TableLookupLanes(v, indices); } // ------------------------------ CompressNot (Compress) // Implemented in x86_512 for lane size != 8. template HWY_API Vec256 CompressNot(Vec256 v, Mask256 mask) { // See CompressIsPartition. alignas(16) constexpr uint64_t packed_array[16] = { // PrintCompressNot64x4NibbleTables 0x00003210, 0x00000321, 0x00001320, 0x00001032, 0x00002310, 0x00002031, 0x00002130, 0x00002103, 0x00003210, 0x00003021, 0x00003120, 0x00003102, 0x00003210, 0x00003201, 0x00003210, 0x00003210}; // For lane i, shift the i-th 4-bit index down to bits [0, 2) - // _mm256_permutexvar_epi64 will ignore the upper bits. const Full256 d; const RebindToUnsigned du64; const auto packed = Set(du64, packed_array[mask.raw]); alignas(32) constexpr uint64_t shifts[4] = {0, 4, 8, 12}; const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; return TableLookupLanes(v, indices); } // ------------------------------ CompressStore // 8-16 bit Compress, CompressStore defined in x86_512 because they use Vec512. template HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, T* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); const size_t count = PopCount(uint64_t{mask.raw}); detail::MaybeUnpoison(unaligned, count); return count; } template HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, T* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); const size_t count = PopCount(uint64_t{mask.raw} & 0xFull); detail::MaybeUnpoison(unaligned, count); return count; } HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, float* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); const size_t count = PopCount(uint64_t{mask.raw}); detail::MaybeUnpoison(unaligned, count); return count; } HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, double* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); const size_t count = PopCount(uint64_t{mask.raw} & 0xFull); detail::MaybeUnpoison(unaligned, count); return count; } // ------------------------------ CompressBlendedStore (CompressStore) template HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT unaligned) { if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) > 2) { // Native (32 or 64-bit) AVX-512 instruction already does the blending at no // extra cost (latency 11, rthroughput 2 - same as compress plus store). 
return CompressStore(v, m, d, unaligned); } else { const size_t count = CountTrue(d, m); BlendedStore(Compress(v, m), FirstN(d, count), d, unaligned); detail::MaybeUnpoison(unaligned, count); return count; } } // ------------------------------ CompressBitsStore (LoadMaskBits) template HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, Full256 d, T* HWY_RESTRICT unaligned) { return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); } #else // AVX2 // ------------------------------ LoadMaskBits (TestBit) namespace detail { // 256 suffix avoids ambiguity with x86_128 without needing HWY_IF_LE128 there. template HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { const RebindToUnsigned du; const Repartition du32; const auto vbits = BitCast(du, Set(du32, static_cast(mask_bits))); // Replicate bytes 8x such that each byte contains the bit that governs it. const Repartition du64; alignas(32) constexpr uint64_t kRep8[4] = { 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull, 0x0303030303030303ull}; const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); alignas(32) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128}; return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(32) constexpr uint16_t kBit[16] = { 1, 2, 4, 8, 16, 32, 64, 128, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; const auto vmask_bits = Set(du, static_cast(mask_bits)); return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(32) constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; const auto vmask_bits = Set(du, static_cast(mask_bits)); return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(32) constexpr uint64_t kBit[8] = {1, 2, 4, 8}; return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); } } // namespace detail // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API Mask256 LoadMaskBits(Full256 d, const uint8_t* HWY_RESTRICT bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; uint64_t mask_bits = 0; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } return detail::LoadMaskBits256(d, mask_bits); } // ------------------------------ StoreMaskBits namespace detail { template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { const Full256 d; const Full256 d8; const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw; // Prevent sign-extension of 32-bit masks because the intrinsic returns int. return static_cast(_mm256_movemask_epi8(sign_bits)); } template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { #if HWY_ARCH_X86_64 const Full256 d; const Full256 d8; const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); const uint64_t sign_bits8 = BitsFromMask(mask8); // Skip the bits from the lower byte of each u16 (better not to use the // same packs_epi16 as SSE4, because that requires an extra swizzle here). return _pext_u64(sign_bits8, 0xAAAAAAAAull); #else // Slow workaround for 32-bit builds, which lack _pext_u64. // Remove useless lower half of each u16 while preserving the sign bit. 
// Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes. const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256()); // Move odd qwords (value zero) to top so they don't affect the mask value. const auto compressed = _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0)); return static_cast(_mm256_movemask_epi8(compressed)); #endif // HWY_ARCH_X86_64 } template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { const Full256 d; const Full256 df; const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; return static_cast(_mm256_movemask_ps(sign_bits)); } template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { const Full256 d; const Full256 df; const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; return static_cast(_mm256_movemask_pd(sign_bits)); } } // namespace detail // `p` points to at least 8 writable bytes. template HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, uint8_t* bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; const uint64_t mask_bits = detail::BitsFromMask(mask); CopyBytes(&mask_bits, bits); return kNumBytes; } // ------------------------------ Mask testing // Specialize for 16-bit lanes to avoid unnecessary pext. This assumes each mask // lane is 0 or ~0. template HWY_API bool AllFalse(const Full256 d, const Mask256 mask) { const Repartition d8; const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); return detail::BitsFromMask(mask8) == 0; } template HWY_API bool AllFalse(const Full256 /* tag */, const Mask256 mask) { // Cheaper than PTEST, which is 2 uop / 3L. return detail::BitsFromMask(mask) == 0; } template HWY_API bool AllTrue(const Full256 d, const Mask256 mask) { const Repartition d8; const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); return detail::BitsFromMask(mask8) == (1ull << 32) - 1; } template HWY_API bool AllTrue(const Full256 /* tag */, const Mask256 mask) { constexpr uint64_t kAllBits = (1ull << (32 / sizeof(T))) - 1; return detail::BitsFromMask(mask) == kAllBits; } template HWY_API size_t CountTrue(const Full256 d, const Mask256 mask) { const Repartition d8; const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); return PopCount(detail::BitsFromMask(mask8)) >> 1; } template HWY_API size_t CountTrue(const Full256 /* tag */, const Mask256 mask) { return PopCount(detail::BitsFromMask(mask)); } template HWY_API size_t FindKnownFirstTrue(const Full256 /* tag */, const Mask256 mask) { const uint64_t mask_bits = detail::BitsFromMask(mask); return Num0BitsBelowLS1Bit_Nonzero64(mask_bits); } template HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, const Mask256 mask) { const uint64_t mask_bits = detail::BitsFromMask(mask); return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; } // ------------------------------ Compress, CompressBits namespace detail { template HWY_INLINE Vec256 IndicesFromBits(Full256 d, uint64_t mask_bits) { const RebindToUnsigned d32; // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT // of SetTableIndices would require 8 KiB, a large part of L1D. The other // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) // and unavailable in 32-bit builds. We instead compress each index into 4 // bits, for a total of 1 KiB. 
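// For example, packed_array[0b110] == 0x765430a9: reading nibbles from the
// low end gives lane indices {1, 2, 0, 3, 4, 5, 6, 7}, i.e. the two selected
// lanes first, then the rest in order. Bit 3 of each nibble is ignored by the
// permute and doubles as the FirstN marker used by CompressBlendedStore.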
alignas(16) constexpr uint32_t packed_array[256] = { // PrintCompress32x8Tables 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8, 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98, 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8, 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98, 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8, 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98, 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8, 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98, 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8, 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98, 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8, 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98, 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8, 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98, 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8, 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98, 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8, 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98, 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8, 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98, 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8, 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98, 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8, 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98, 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 0x65310fca, 0x6531fca8, 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 0x6520fcb9, 0x652fcb98, 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8, 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98, 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8, 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98, 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8, 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98, 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8, 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98, 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8, 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98, 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8, 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98, 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8, 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98, 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8, 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98, 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98}; // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing // latency, it may be faster to use LoadDup128 and PSHUFB. 
const auto packed = Set(d32, packed_array[mask_bits]); alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; return packed >> Load(d32, shifts); } template HWY_INLINE Vec256 IndicesFromBits(Full256 d, uint64_t mask_bits) { const Repartition d32; // For 64-bit, we still need 32-bit indices because there is no 64-bit // permutevar, but there are only 4 lanes, so we can afford to skip the // unpacking and load the entire index vector directly. alignas(32) constexpr uint32_t u32_indices[128] = { // PrintCompress64x4PairTables 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 10, 11, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 12, 13, 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 6, 7, 14, 15, 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 2, 3, 4, 5, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 4, 5, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 2, 3, 10, 11, 12, 13, 14, 15, 0, 1, 8, 9, 10, 11, 12, 13, 14, 15}; return Load(d32, u32_indices + 8 * mask_bits); } template HWY_INLINE Vec256 IndicesFromNotBits(Full256 d, uint64_t mask_bits) { const RebindToUnsigned d32; // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT // of SetTableIndices would require 8 KiB, a large part of L1D. The other // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) // and unavailable in 32-bit builds. We instead compress each index into 4 // bits, for a total of 1 KiB. alignas(16) constexpr uint32_t packed_array[256] = { // PrintCompressNot32x8Tables 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9, 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca, 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9, 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb, 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9, 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba, 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9, 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec, 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9, 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea, 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9, 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb, 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9, 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba, 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9, 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd, 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 0xed98fcba, 0xedafcb98, 0xeda8fcb9, 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca, 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9, 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb, 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9, 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba, 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9, 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc, 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 0xfc98edba, 0xfcaedb98, 0xfca8edb9, 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 0xfcb9eda8, 0xfcb98eda, 0xfcbaed98, 0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9, 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 
0xfda9ecb8, 0xfda98ecb, 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9, 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba, 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9, 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e, 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9, 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca, 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9, 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db, 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9, 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba, 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9, 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c, 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9, 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a, 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98}; // No need to mask because <_mm256_permutevar8x32_epi32> ignores bits 3..31. // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing // latency, it may be faster to use LoadDup128 and PSHUFB. const auto packed = Set(d32, packed_array[mask_bits]); alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; return packed >> Load(d32, shifts); } template HWY_INLINE Vec256 IndicesFromNotBits(Full256 d, uint64_t mask_bits) { const Repartition d32; // For 64-bit, we still need 32-bit indices because there is no 64-bit // permutevar, but there are only 4 lanes, so we can afford to skip the // unpacking and load the entire index vector directly. alignas(32) constexpr uint32_t u32_indices[128] = { // PrintCompressNot64x4PairTables 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 8, 9, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 8, 9, 10, 11, 14, 15, 12, 13, 10, 11, 14, 15, 8, 9, 12, 13, 8, 9, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 8, 9, 14, 15, 8, 9, 12, 13, 10, 11, 14, 15, 12, 13, 8, 9, 10, 11, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 8, 9, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}; return Load(d32, u32_indices + 8 * mask_bits); } template HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { const Full256 d; const Repartition du32; HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is // no instruction for 4x64). const Indices256 indices{IndicesFromBits(d, mask_bits).raw}; return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); } // LUTs are infeasible for 2^16 possible masks, so splice together two // half-vector Compress. template HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { const Full256 d; const RebindToUnsigned du; const auto vu16 = BitCast(du, v); // (required for float16_t inputs) const Half duh; const auto half0 = LowerHalf(duh, vu16); const auto half1 = UpperHalf(duh, vu16); const uint64_t mask_bits0 = mask_bits & 0xFF; const uint64_t mask_bits1 = mask_bits >> 8; const auto compressed0 = detail::CompressBits(half0, mask_bits0); const auto compressed1 = detail::CompressBits(half1, mask_bits1); alignas(32) uint16_t all_true[16] = {}; // Store mask=true lanes, left to right. 
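// The true lanes of the lower half land at the front of all_true; the upper
// half's true lanes are appended immediately after them, so the concatenation
// of all selected lanes is contiguous.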
const size_t num_true0 = PopCount(mask_bits0); Store(compressed0, duh, all_true); StoreU(compressed1, duh, all_true + num_true0); if (hwy::HWY_NAMESPACE::CompressIsPartition::value) { // Store mask=false lanes, right to left. The second vector fills the upper // half with right-aligned false lanes. The first vector is shifted // rightwards to overwrite the true lanes of the second. alignas(32) uint16_t all_false[16] = {}; const size_t num_true1 = PopCount(mask_bits1); Store(compressed1, duh, all_false + 8); StoreU(compressed0, duh, all_false + num_true1); const auto mask = FirstN(du, num_true0 + num_true1); return BitCast(d, IfThenElse(mask, Load(du, all_true), Load(du, all_false))); } else { // Only care about the mask=true lanes. return BitCast(d, Load(du, all_true)); } } template // 4 or 8 bytes HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { const Full256 d; const Repartition du32; HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is // no instruction for 4x64). const Indices256 indices{IndicesFromNotBits(d, mask_bits).raw}; return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); } // LUTs are infeasible for 2^16 possible masks, so splice together two // half-vector Compress. template HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { // Compress ensures only the lower 16 bits are set, so flip those. return Compress(v, mask_bits ^ 0xFFFF); } } // namespace detail template HWY_API Vec256 Compress(Vec256 v, Mask256 m) { return detail::Compress(v, detail::BitsFromMask(m)); } template HWY_API Vec256 CompressNot(Vec256 v, Mask256 m) { return detail::CompressNot(v, detail::BitsFromMask(m)); } HWY_API Vec256 CompressBlocksNot(Vec256 v, Mask256 mask) { return CompressNot(v, mask); } template HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; uint64_t mask_bits = 0; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } return detail::Compress(v, mask_bits); } // ------------------------------ CompressStore, CompressBitsStore template HWY_API size_t CompressStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(m); const size_t count = PopCount(mask_bits); StoreU(detail::Compress(v, mask_bits), d, unaligned); detail::MaybeUnpoison(unaligned, count); return count; } template // 4 or 8 bytes HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(m); const size_t count = PopCount(mask_bits); const Repartition du32; HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is // no instruction for 4x64). Nibble MSB encodes FirstN. 
const Vec256 idx_and_mask = detail::IndicesFromBits(d, mask_bits); // Shift nibble MSB into MSB const Mask256 mask32 = MaskFromVec(ShiftLeft<28>(idx_and_mask)); // First cast to unsigned (RebindMask cannot change lane size) const Mask256> mask_u{mask32.raw}; const Mask256 mask = RebindMask(d, mask_u); const Vec256 compressed = BitCast(d, TableLookupLanes(BitCast(du32, v), Indices256{idx_and_mask.raw})); BlendedStore(compressed, mask, d, unaligned); detail::MaybeUnpoison(unaligned, count); return count; } template HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(m); const size_t count = PopCount(mask_bits); const Vec256 compressed = detail::Compress(v, mask_bits); #if HWY_MEM_OPS_MIGHT_FAULT // true if HWY_IS_MSAN // BlendedStore tests mask for each lane, but we know that the mask is // FirstN, so we can just copy. alignas(32) T buf[16]; Store(compressed, d, buf); memcpy(unaligned, buf, count * sizeof(T)); #else BlendedStore(compressed, FirstN(d, count), d, unaligned); #endif return count; } template HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, Full256 d, T* HWY_RESTRICT unaligned) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; uint64_t mask_bits = 0; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } const size_t count = PopCount(mask_bits); StoreU(detail::Compress(v, mask_bits), d, unaligned); detail::MaybeUnpoison(unaligned, count); return count; } #endif // HWY_TARGET <= HWY_AVX3 // ------------------------------ LoadInterleaved3/4 // Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. namespace detail { // Input: // 1 0 (<- first block of unaligned) // 3 2 // 5 4 // Output: // 3 0 // 4 1 // 5 2 template HWY_API void LoadTransposedBlocks3(Full256 d, const T* HWY_RESTRICT unaligned, Vec256& A, Vec256& B, Vec256& C) { constexpr size_t N = 32 / sizeof(T); const Vec256 v10 = LoadU(d, unaligned + 0 * N); // 1 0 const Vec256 v32 = LoadU(d, unaligned + 1 * N); const Vec256 v54 = LoadU(d, unaligned + 2 * N); A = ConcatUpperLower(d, v32, v10); B = ConcatLowerUpper(d, v54, v10); C = ConcatUpperLower(d, v54, v32); } // Input (128-bit blocks): // 1 0 (first block of unaligned) // 3 2 // 5 4 // 7 6 // Output: // 4 0 (LSB of A) // 5 1 // 6 2 // 7 3 template HWY_API void LoadTransposedBlocks4(Full256 d, const T* HWY_RESTRICT unaligned, Vec256& A, Vec256& B, Vec256& C, Vec256& D) { constexpr size_t N = 32 / sizeof(T); const Vec256 v10 = LoadU(d, unaligned + 0 * N); const Vec256 v32 = LoadU(d, unaligned + 1 * N); const Vec256 v54 = LoadU(d, unaligned + 2 * N); const Vec256 v76 = LoadU(d, unaligned + 3 * N); A = ConcatLowerLower(d, v54, v10); B = ConcatUpperUpper(d, v54, v10); C = ConcatLowerLower(d, v76, v32); D = ConcatUpperUpper(d, v76, v32); } } // namespace detail // ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) // Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. 
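//
// Usage sketch of the public API built on these overloads (assuming the
// StoreInterleaved2 wrapper from generic_ops; identifiers are illustrative):
// interleave two planes into {a0, b0, a1, b1, ...}.
//   const Full256<uint8_t> d;
//   const Vec256<uint8_t> a = Iota(d, 0);
//   const Vec256<uint8_t> b = Set(d, 0xFF);
//   HWY_ALIGN uint8_t out[2 * 32];
//   StoreInterleaved2(a, b, d, out);   // {0, 0xFF, 1, 0xFF, ...}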
namespace detail { // Input (128-bit blocks): // 2 0 (LSB of i) // 3 1 // Output: // 1 0 // 3 2 template HWY_API void StoreTransposedBlocks2(const Vec256 i, const Vec256 j, const Full256 d, T* HWY_RESTRICT unaligned) { constexpr size_t N = 32 / sizeof(T); const auto out0 = ConcatLowerLower(d, j, i); const auto out1 = ConcatUpperUpper(d, j, i); StoreU(out0, d, unaligned + 0 * N); StoreU(out1, d, unaligned + 1 * N); } // Input (128-bit blocks): // 3 0 (LSB of i) // 4 1 // 5 2 // Output: // 1 0 // 3 2 // 5 4 template HWY_API void StoreTransposedBlocks3(const Vec256 i, const Vec256 j, const Vec256 k, Full256 d, T* HWY_RESTRICT unaligned) { constexpr size_t N = 32 / sizeof(T); const auto out0 = ConcatLowerLower(d, j, i); const auto out1 = ConcatUpperLower(d, i, k); const auto out2 = ConcatUpperUpper(d, k, j); StoreU(out0, d, unaligned + 0 * N); StoreU(out1, d, unaligned + 1 * N); StoreU(out2, d, unaligned + 2 * N); } // Input (128-bit blocks): // 4 0 (LSB of i) // 5 1 // 6 2 // 7 3 // Output: // 1 0 // 3 2 // 5 4 // 7 6 template HWY_API void StoreTransposedBlocks4(const Vec256 i, const Vec256 j, const Vec256 k, const Vec256 l, Full256 d, T* HWY_RESTRICT unaligned) { constexpr size_t N = 32 / sizeof(T); // Write lower halves, then upper. const auto out0 = ConcatLowerLower(d, j, i); const auto out1 = ConcatLowerLower(d, l, k); StoreU(out0, d, unaligned + 0 * N); StoreU(out1, d, unaligned + 1 * N); const auto out2 = ConcatUpperUpper(d, j, i); const auto out3 = ConcatUpperUpper(d, l, k); StoreU(out2, d, unaligned + 2 * N); StoreU(out3, d, unaligned + 3 * N); } } // namespace detail // ------------------------------ Reductions namespace detail { // Returns sum{lane[i]} in each lane. "v3210" is a replicated 128-bit block. // Same logic as x86/128.h, but with Vec256 arguments. template HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const auto v1032 = Shuffle1032(v3210); const auto v31_20_31_20 = v3210 + v1032; const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); return v20_31_20_31 + v31_20_31_20; } template HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const auto v1032 = Shuffle1032(v3210); const auto v31_20_31_20 = Min(v3210, v1032); const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); return Min(v20_31_20_31, v31_20_31_20); } template HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const auto v1032 = Shuffle1032(v3210); const auto v31_20_31_20 = Max(v3210, v1032); const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); return Max(v20_31_20_31, v31_20_31_20); } template HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const auto v01 = Shuffle01(v10); return v10 + v01; } template HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const auto v01 = Shuffle01(v10); return Min(v10, v01); } template HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const auto v01 = Shuffle01(v10); return Max(v10, v01); } HWY_API Vec256 SumOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Full256 d; const RepartitionToWide d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); // Also broadcast into odd lanes. 
return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); } HWY_API Vec256 SumOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Full256 d; const RepartitionToWide d32; // Sign-extend const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); } HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Full256 d; const RepartitionToWide d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); } HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Full256 d; const RepartitionToWide d32; // Sign-extend const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); } HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Full256 d; const RepartitionToWide d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); } HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Full256 d; const RepartitionToWide d32; // Sign-extend const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); } } // namespace detail // Supported for {uif}{32,64},{ui}16. Returns the broadcasted result. template HWY_API Vec256 SumOfLanes(Full256 d, const Vec256 vHL) { const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); return detail::SumOfLanes(hwy::SizeTag(), vLH + vHL); } template HWY_API Vec256 MinOfLanes(Full256 d, const Vec256 vHL) { const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); return detail::MinOfLanes(hwy::SizeTag(), Min(vLH, vHL)); } template HWY_API Vec256 MaxOfLanes(Full256 d, const Vec256 vHL) { const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); return detail::MaxOfLanes(hwy::SizeTag(), Max(vLH, vHL)); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); // Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - // the warning seems to be issued at the call site of intrinsics, i.e. our code. HWY_DIAGNOSTICS(pop)