// Copyright 2019 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 512-bit AVX512 vectors and operations.
// External include guard in highway.h - see comment there.

// WARNING: most operations do not cross 128-bit block boundaries. In
// particular, "Broadcast", pack and zip behavior may be surprising.

// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL
#include "hwy/base.h"

// Avoid uninitialized warnings in GCC's avx512fintrin.h - see
// https://github.com/google/highway/issues/710
HWY_DIAGNOSTICS(push)
#if HWY_COMPILER_GCC_ACTUAL
HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized")
HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized")
#endif

#include <immintrin.h>  // AVX2+

#if HWY_COMPILER_CLANGCL
// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
// including these headers when _MSC_VER is defined, like when using clang-cl.
// Include these directly here.
// clang-format off
#include <smmintrin.h>
#include <avxintrin.h>
#include <avx2intrin.h>
#include <f16cintrin.h>
#include <fmaintrin.h>
#include <avx512fintrin.h>
#include <avx512vlintrin.h>
#include <avx512bwintrin.h>
#include <avx512dqintrin.h>
#include <avx512vlbwintrin.h>
#include <avx512vldqintrin.h>
#include <avx512bitalgintrin.h>
#include <avx512vlbitalgintrin.h>
#include <avx512vpopcntdqintrin.h>
#include <avx512vpopcntdqvlintrin.h>
// clang-format on
#endif  // HWY_COMPILER_CLANGCL

#include <stddef.h>
#include <stdint.h>

#if HWY_IS_MSAN
#include <sanitizer/msan_interface.h>
#endif

// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/x86_256-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

namespace detail {

template <typename T>
struct Raw512 {
  using type = __m512i;
};
template <>
struct Raw512<float> {
  using type = __m512;
};
template <>
struct Raw512<double> {
  using type = __m512d;
};

// Template arg: sizeof(lane type)
template <size_t size>
struct RawMask512 {};
template <>
struct RawMask512<1> {
  using type = __mmask64;
};
template <>
struct RawMask512<2> {
  using type = __mmask32;
};
template <>
struct RawMask512<4> {
  using type = __mmask16;
};
template <>
struct RawMask512<8> {
  using type = __mmask8;
};

}  // namespace detail

template <typename T>
class Vec512 {
  using Raw = typename detail::Raw512<T>::type;

 public:
  using PrivateT = T;                                  // only for DFromV
  static constexpr size_t kPrivateN = 64 / sizeof(T);  // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec512& operator*=(const Vec512 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec512& operator/=(const Vec512 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec512& operator+=(const Vec512 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec512& operator-=(const Vec512 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec512& operator&=(const Vec512 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec512& operator|=(const Vec512 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec512& operator^=(const Vec512 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};

// Mask register: one bit per lane.
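// For example, Mask512<float> covers 16 float lanes, so its raw type is
// __mmask16. Illustrative sketch only (relies on ops defined later in this
// header), showing how comparisons and blends flow through such a mask:
//   const Full512<float> d;
//   const Mask512<float> m = Set(d, 1.0f) == Set(d, 1.0f);  // all 16 bits set
//   const Vec512<float> r = IfThenElse(m, Set(d, 2.0f), Zero(d));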
template struct Mask512 { using Raw = typename detail::RawMask512::type; Raw raw; }; template using Full512 = Simd; // ------------------------------ BitCast namespace detail { HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; } HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); } HWY_INLINE __m512i BitCastToInteger(__m512d v) { return _mm512_castpd_si512(v); } template HWY_INLINE Vec512 BitCastToByte(Vec512 v) { return Vec512{BitCastToInteger(v.raw)}; } // Cannot rely on function overloading because return types differ. template struct BitCastFromInteger512 { HWY_INLINE __m512i operator()(__m512i v) { return v; } }; template <> struct BitCastFromInteger512 { HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); } }; template <> struct BitCastFromInteger512 { HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); } }; template HWY_INLINE Vec512 BitCastFromByte(Full512 /* tag */, Vec512 v) { return Vec512{BitCastFromInteger512()(v.raw)}; } } // namespace detail template HWY_API Vec512 BitCast(Full512 d, Vec512 v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } // ------------------------------ Set // Returns an all-zero vector. template HWY_API Vec512 Zero(Full512 /* tag */) { return Vec512{_mm512_setzero_si512()}; } HWY_API Vec512 Zero(Full512 /* tag */) { return Vec512{_mm512_setzero_ps()}; } HWY_API Vec512 Zero(Full512 /* tag */) { return Vec512{_mm512_setzero_pd()}; } // Returns a vector with all lanes set to "t". HWY_API Vec512 Set(Full512 /* tag */, const uint8_t t) { return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const uint16_t t) { return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const uint32_t t) { return Vec512{_mm512_set1_epi32(static_cast(t))}; } HWY_API Vec512 Set(Full512 /* tag */, const uint64_t t) { return Vec512{ _mm512_set1_epi64(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const int8_t t) { return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const int16_t t) { return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const int32_t t) { return Vec512{_mm512_set1_epi32(t)}; } HWY_API Vec512 Set(Full512 /* tag */, const int64_t t) { return Vec512{ _mm512_set1_epi64(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const float t) { return Vec512{_mm512_set1_ps(t)}; } HWY_API Vec512 Set(Full512 /* tag */, const double t) { return Vec512{_mm512_set1_pd(t)}; } HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") // Returns a vector with uninitialized elements. template HWY_API Vec512 Undefined(Full512 /* tag */) { // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC // generate an XOR instruction. 
return Vec512{_mm512_undefined_epi32()}; } HWY_API Vec512 Undefined(Full512 /* tag */) { return Vec512{_mm512_undefined_ps()}; } HWY_API Vec512 Undefined(Full512 /* tag */) { return Vec512{_mm512_undefined_pd()}; } HWY_DIAGNOSTICS(pop) // ================================================== LOGICAL // ------------------------------ Not template HWY_API Vec512 Not(const Vec512 v) { using TU = MakeUnsigned; const __m512i vu = BitCast(Full512(), v).raw; return BitCast(Full512(), Vec512{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)}); } // ------------------------------ And template HWY_API Vec512 And(const Vec512 a, const Vec512 b) { return Vec512{_mm512_and_si512(a.raw, b.raw)}; } HWY_API Vec512 And(const Vec512 a, const Vec512 b) { return Vec512{_mm512_and_ps(a.raw, b.raw)}; } HWY_API Vec512 And(const Vec512 a, const Vec512 b) { return Vec512{_mm512_and_pd(a.raw, b.raw)}; } // ------------------------------ AndNot // Returns ~not_mask & mask. template HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { return Vec512{_mm512_andnot_si512(not_mask.raw, mask.raw)}; } HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { return Vec512{_mm512_andnot_ps(not_mask.raw, mask.raw)}; } HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { return Vec512{_mm512_andnot_pd(not_mask.raw, mask.raw)}; } // ------------------------------ Or template HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { return Vec512{_mm512_or_si512(a.raw, b.raw)}; } HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { return Vec512{_mm512_or_ps(a.raw, b.raw)}; } HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { return Vec512{_mm512_or_pd(a.raw, b.raw)}; } // ------------------------------ Xor template HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { return Vec512{_mm512_xor_si512(a.raw, b.raw)}; } HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { return Vec512{_mm512_xor_ps(a.raw, b.raw)}; } HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { return Vec512{_mm512_xor_pd(a.raw, b.raw)}; } // ------------------------------ Xor3 template HWY_API Vec512 Xor3(Vec512 x1, Vec512 x2, Vec512 x3) { const Full512 d; const RebindToUnsigned du; using VU = VFromD; const __m512i ret = _mm512_ternarylogic_epi64( BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); return BitCast(d, VU{ret}); } // ------------------------------ Or3 template HWY_API Vec512 Or3(Vec512 o1, Vec512 o2, Vec512 o3) { const Full512 d; const RebindToUnsigned du; using VU = VFromD; const __m512i ret = _mm512_ternarylogic_epi64( BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); return BitCast(d, VU{ret}); } // ------------------------------ OrAnd template HWY_API Vec512 OrAnd(Vec512 o, Vec512 a1, Vec512 a2) { const Full512 d; const RebindToUnsigned du; using VU = VFromD; const __m512i ret = _mm512_ternarylogic_epi64( BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); return BitCast(d, VU{ret}); } // ------------------------------ IfVecThenElse template HWY_API Vec512 IfVecThenElse(Vec512 mask, Vec512 yes, Vec512 no) { const Full512 d; const RebindToUnsigned du; using VU = VFromD; return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, BitCast(du, no).raw, 0xCA)}); } // ------------------------------ Operator overloads (internal-only if float) template HWY_API Vec512 operator&(const Vec512 a, const Vec512 b) { return And(a, b); } template HWY_API Vec512 operator|(const Vec512 a, const Vec512 b) { return Or(a, b); } template HWY_API Vec512 
operator^(const Vec512 a, const Vec512 b) { return Xor(a, b); } // ------------------------------ PopulationCount // 8/16 require BITALG, 32/64 require VPOPCNTDQ. #if HWY_TARGET == HWY_AVX3_DL #ifdef HWY_NATIVE_POPCNT #undef HWY_NATIVE_POPCNT #else #define HWY_NATIVE_POPCNT #endif namespace detail { template HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<1> /* tag */, Vec512 v) { return Vec512{_mm512_popcnt_epi8(v.raw)}; } template HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<2> /* tag */, Vec512 v) { return Vec512{_mm512_popcnt_epi16(v.raw)}; } template HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<4> /* tag */, Vec512 v) { return Vec512{_mm512_popcnt_epi32(v.raw)}; } template HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<8> /* tag */, Vec512 v) { return Vec512{_mm512_popcnt_epi64(v.raw)}; } } // namespace detail template HWY_API Vec512 PopulationCount(Vec512 v) { return detail::PopulationCount(hwy::SizeTag(), v); } #endif // HWY_TARGET == HWY_AVX3_DL // ================================================== SIGN // ------------------------------ CopySign template HWY_API Vec512 CopySign(const Vec512 magn, const Vec512 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); const Full512 d; const auto msb = SignBit(d); const Rebind, decltype(d)> du; // Truth table for msb, magn, sign | bitwise msb ? sign : mag // 0 0 0 | 0 // 0 0 1 | 0 // 0 1 0 | 1 // 0 1 1 | 1 // 1 0 0 | 0 // 1 0 1 | 1 // 1 1 0 | 0 // 1 1 1 | 1 // The lane size does not matter because we are not using predication. const __m512i out = _mm512_ternarylogic_epi32( BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); return BitCast(d, decltype(Zero(du)){out}); } template HWY_API Vec512 CopySignToAbs(const Vec512 abs, const Vec512 sign) { // AVX3 can also handle abs < 0, so no extra action needed. return CopySign(abs, sign); } // ================================================== MASK // ------------------------------ FirstN // Possibilities for constructing a bitmask of N ones: // - kshift* only consider the lowest byte of the shift count, so they would // not correctly handle large n. // - Scalar shifts >= 64 are UB. // - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However, // we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds. #if HWY_ARCH_X86_32 namespace detail { // 32 bit mask is sufficient for lane size >= 2. template HWY_INLINE Mask512 FirstN(size_t n) { Mask512 m; const uint32_t all = ~uint32_t{0}; // BZHI only looks at the lower 8 bits of n! m.raw = static_cast((n > 255) ? all : _bzhi_u32(all, n)); return m; } template HWY_INLINE Mask512 FirstN(size_t n) { const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0}; return Mask512{static_cast<__mmask64>(bits)}; } } // namespace detail #endif // HWY_ARCH_X86_32 template HWY_API Mask512 FirstN(const Full512 /*tag*/, size_t n) { #if HWY_ARCH_X86_64 Mask512 m; const uint64_t all = ~uint64_t{0}; // BZHI only looks at the lower 8 bits of n! m.raw = static_cast((n > 255) ? all : _bzhi_u64(all, n)); return m; #else return detail::FirstN(n); #endif // HWY_ARCH_X86_64 } // ------------------------------ IfThenElse // Returns mask ? b : a. namespace detail { // Templates for signed/unsigned integer of a particular size. 
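// These overloads are selected via hwy::SizeTag<sizeof(T)> by the public
// IfThenElse defined after this namespace. Illustrative sketch of the public
// op (uses FirstN/Set from this header):
//   const Full512<int32_t> d;
//   const auto m = FirstN(d, 5);                          // first 5 lanes true
//   const auto v = IfThenElse(m, Set(d, 1), Set(d, -1));  // 1,1,1,1,1,-1,...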
template HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<1> /* tag */, const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<2> /* tag */, const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<4> /* tag */, const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<8> /* tag */, const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; } } // namespace detail template HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, const Vec512 no) { return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); } HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_ps(no.raw, mask.raw, yes.raw)}; } HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_pd(no.raw, mask.raw, yes.raw)}; } namespace detail { template HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<1> /* tag */, const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_epi8(mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<2> /* tag */, const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_epi16(mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<4> /* tag */, const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_epi32(mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<8> /* tag */, const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_epi64(mask.raw, yes.raw)}; } } // namespace detail template HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); } HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_ps(mask.raw, yes.raw)}; } HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_pd(mask.raw, yes.raw)}; } namespace detail { template HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<1> /* tag */, const Mask512 mask, const Vec512 no) { // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
return Vec512{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<2> /* tag */, const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<4> /* tag */, const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<8> /* tag */, const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; } } // namespace detail template HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); } HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; } HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_API Vec512 IfNegativeThenElse(Vec512 v, Vec512 yes, Vec512 no) { static_assert(IsSigned(), "Only works for signed/float"); // AVX3 MaskFromVec only looks at the MSB return IfThenElse(MaskFromVec(v), yes, no); } template HWY_API Vec512 ZeroIfNegative(const Vec512 v) { // AVX3 MaskFromVec only looks at the MSB return IfThenZeroElse(MaskFromVec(v), v); } // ================================================== ARITHMETIC // ------------------------------ Addition // Unsigned HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi8(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi64(a.raw, b.raw)}; } // Signed HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi8(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi64(a.raw, b.raw)}; } // Float HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_ps(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_pd(a.raw, b.raw)}; } // ------------------------------ Subtraction // Unsigned HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; } // Signed HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi64(a.raw, 
b.raw)}; } // Float HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_ps(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_pd(a.raw, b.raw)}; } // ------------------------------ SumsOf8 HWY_API Vec512 SumsOf8(const Vec512 v) { return Vec512{_mm512_sad_epu8(v.raw, _mm512_setzero_si512())}; } // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. // Unsigned HWY_API Vec512 SaturatedAdd(const Vec512 a, const Vec512 b) { return Vec512{_mm512_adds_epu8(a.raw, b.raw)}; } HWY_API Vec512 SaturatedAdd(const Vec512 a, const Vec512 b) { return Vec512{_mm512_adds_epu16(a.raw, b.raw)}; } // Signed HWY_API Vec512 SaturatedAdd(const Vec512 a, const Vec512 b) { return Vec512{_mm512_adds_epi8(a.raw, b.raw)}; } HWY_API Vec512 SaturatedAdd(const Vec512 a, const Vec512 b) { return Vec512{_mm512_adds_epi16(a.raw, b.raw)}; } // ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. // Unsigned HWY_API Vec512 SaturatedSub(const Vec512 a, const Vec512 b) { return Vec512{_mm512_subs_epu8(a.raw, b.raw)}; } HWY_API Vec512 SaturatedSub(const Vec512 a, const Vec512 b) { return Vec512{_mm512_subs_epu16(a.raw, b.raw)}; } // Signed HWY_API Vec512 SaturatedSub(const Vec512 a, const Vec512 b) { return Vec512{_mm512_subs_epi8(a.raw, b.raw)}; } HWY_API Vec512 SaturatedSub(const Vec512 a, const Vec512 b) { return Vec512{_mm512_subs_epi16(a.raw, b.raw)}; } // ------------------------------ Average // Returns (a + b + 1) / 2 // Unsigned HWY_API Vec512 AverageRound(const Vec512 a, const Vec512 b) { return Vec512{_mm512_avg_epu8(a.raw, b.raw)}; } HWY_API Vec512 AverageRound(const Vec512 a, const Vec512 b) { return Vec512{_mm512_avg_epu16(a.raw, b.raw)}; } // ------------------------------ Abs (Sub) // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. HWY_API Vec512 Abs(const Vec512 v) { #if HWY_COMPILER_MSVC // Workaround for incorrect codegen? (untested due to internal compiler error) const auto zero = Zero(Full512()); return Vec512{_mm512_max_epi8(v.raw, (zero - v).raw)}; #else return Vec512{_mm512_abs_epi8(v.raw)}; #endif } HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_epi16(v.raw)}; } HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_epi32(v.raw)}; } HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_epi64(v.raw)}; } // These aren't native instructions, they also involve AND with constant. HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_ps(v.raw)}; } HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_pd(v.raw)}; } // ------------------------------ ShiftLeft template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi16(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi32(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi64(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi16(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi32(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi64(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { const Full512 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); return kBits == 1 ? 
(v + v) : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); } // ------------------------------ ShiftRight template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srli_epi16(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srli_epi32(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srli_epi64(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { const Full512 d8; // Use raw instead of BitCast to support N=1. const Vec512 shifted{ShiftRight(Vec512{v.raw}).raw}; return shifted & Set(d8, 0xFF >> kBits); } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srai_epi16(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srai_epi32(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srai_epi64(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { const Full512 di; const Full512 du; const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); return (shifted ^ shifted_sign) - shifted_sign; } // ------------------------------ RotateRight template HWY_API Vec512 RotateRight(const Vec512 v) { static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); return Vec512{_mm512_ror_epi32(v.raw, kBits)}; } template HWY_API Vec512 RotateRight(const Vec512 v) { static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); return Vec512{_mm512_ror_epi64(v.raw, kBits)}; } // ------------------------------ ShiftLeftSame HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } template HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { const Full512 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); } // ------------------------------ ShiftRightSame HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { const Full512 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); return shifted & Set(d8, static_cast(0xFF >> bits)); } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return 
Vec512{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { const Full512 di; const Full512 du; const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto shifted_sign = BitCast(di, Set(du, static_cast(0x80 >> bits))); return (shifted ^ shifted_sign) - shifted_sign; } // ------------------------------ Shl HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_sllv_epi16(v.raw, bits.raw)}; } HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_sllv_epi32(v.raw, bits.raw)}; } HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_sllv_epi64(v.raw, bits.raw)}; } // Signed left shift is the same as unsigned. template HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { const Full512 di; const Full512> du; return BitCast(di, BitCast(du, v) << BitCast(du, bits)); } // ------------------------------ Shr HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srlv_epi16(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srlv_epi32(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srlv_epi64(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srav_epi16(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srav_epi32(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srav_epi64(v.raw, bits.raw)}; } // ------------------------------ Minimum // Unsigned HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epu8(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epu16(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epu32(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epu64(a.raw, b.raw)}; } // Signed HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epi8(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epi16(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epi32(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epi64(a.raw, b.raw)}; } // Float HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_ps(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_pd(a.raw, b.raw)}; } // ------------------------------ Maximum // Unsigned HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epu8(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epu16(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epu32(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epu64(a.raw, b.raw)}; } // Signed HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epi8(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return 
Vec512{_mm512_max_epi16(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epi32(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epi64(a.raw, b.raw)}; } // Float HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_ps(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_pd(a.raw, b.raw)}; } // ------------------------------ Integer multiplication // Unsigned HWY_API Vec512 operator*(Vec512 a, Vec512 b) { return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator*(Vec512 a, Vec512 b) { return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator*(Vec512 a, Vec512 b) { return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; } HWY_API Vec256 operator*(Vec256 a, Vec256 b) { return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; } HWY_API Vec128 operator*(Vec128 a, Vec128 b) { return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; } // Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. #ifdef HWY_NATIVE_I64MULLO #undef HWY_NATIVE_I64MULLO #else #define HWY_NATIVE_I64MULLO #endif // Signed HWY_API Vec512 operator*(Vec512 a, Vec512 b) { return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator*(Vec512 a, Vec512 b) { return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator*(Vec512 a, Vec512 b) { return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; } HWY_API Vec256 operator*(Vec256 a, Vec256 b) { return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; } HWY_API Vec128 operator*(Vec128 a, Vec128 b) { return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; } // Returns the upper 16 bits of a * b in each lane. HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { return Vec512{_mm512_mulhi_epu16(a.raw, b.raw)}; } HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { return Vec512{_mm512_mulhi_epi16(a.raw, b.raw)}; } HWY_API Vec512 MulFixedPoint15(Vec512 a, Vec512 b) { return Vec512{_mm512_mulhrs_epi16(a.raw, b.raw)}; } // Multiplies even lanes (0, 2 ..) and places the double-wide result into // even and the upper half into its odd neighbor lane. HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { return Vec512{_mm512_mul_epi32(a.raw, b.raw)}; } HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { return Vec512{_mm512_mul_epu32(a.raw, b.raw)}; } // ------------------------------ Neg (Sub) template HWY_API Vec512 Neg(const Vec512 v) { return Xor(v, SignBit(Full512())); } template HWY_API Vec512 Neg(const Vec512 v) { return Zero(Full512()) - v; } // ------------------------------ Floating-point mul / div HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mul_ps(a.raw, b.raw)}; } HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mul_pd(a.raw, b.raw)}; } HWY_API Vec512 operator/(const Vec512 a, const Vec512 b) { return Vec512{_mm512_div_ps(a.raw, b.raw)}; } HWY_API Vec512 operator/(const Vec512 a, const Vec512 b) { return Vec512{_mm512_div_pd(a.raw, b.raw)}; } // Approximate reciprocal HWY_API Vec512 ApproximateReciprocal(const Vec512 v) { return Vec512{_mm512_rcp14_ps(v.raw)}; } // Absolute value of difference. 
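// Sketch: AbsDiff(a, b) == Abs(a - b) per lane, e.g. for f32 (illustrative
// only, using Set/AbsDiff from this header):
//   const Full512<float> d;
//   AbsDiff(Set(d, 3.0f), Set(d, 5.0f));  // == Set(d, 2.0f) in every lane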
HWY_API Vec512 AbsDiff(const Vec512 a, const Vec512 b) { return Abs(a - b); } // ------------------------------ Floating-point multiply-add variants // Returns mul * x + add HWY_API Vec512 MulAdd(const Vec512 mul, const Vec512 x, const Vec512 add) { return Vec512{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; } HWY_API Vec512 MulAdd(const Vec512 mul, const Vec512 x, const Vec512 add) { return Vec512{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; } // Returns add - mul * x HWY_API Vec512 NegMulAdd(const Vec512 mul, const Vec512 x, const Vec512 add) { return Vec512{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; } HWY_API Vec512 NegMulAdd(const Vec512 mul, const Vec512 x, const Vec512 add) { return Vec512{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; } // Returns mul * x - sub HWY_API Vec512 MulSub(const Vec512 mul, const Vec512 x, const Vec512 sub) { return Vec512{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; } HWY_API Vec512 MulSub(const Vec512 mul, const Vec512 x, const Vec512 sub) { return Vec512{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; } // Returns -mul * x - sub HWY_API Vec512 NegMulSub(const Vec512 mul, const Vec512 x, const Vec512 sub) { return Vec512{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; } HWY_API Vec512 NegMulSub(const Vec512 mul, const Vec512 x, const Vec512 sub) { return Vec512{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; } // ------------------------------ Floating-point square root // Full precision square root HWY_API Vec512 Sqrt(const Vec512 v) { return Vec512{_mm512_sqrt_ps(v.raw)}; } HWY_API Vec512 Sqrt(const Vec512 v) { return Vec512{_mm512_sqrt_pd(v.raw)}; } // Approximate reciprocal square root HWY_API Vec512 ApproximateReciprocalSqrt(const Vec512 v) { return Vec512{_mm512_rsqrt14_ps(v.raw)}; } // ------------------------------ Floating-point rounding // Work around warnings in the intrinsic definitions (passing -1 as a mask). HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") // Toward nearest integer, tie to even HWY_API Vec512 Round(const Vec512 v) { return Vec512{_mm512_roundscale_ps( v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; } HWY_API Vec512 Round(const Vec512 v) { return Vec512{_mm512_roundscale_pd( v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; } // Toward zero, aka truncate HWY_API Vec512 Trunc(const Vec512 v) { return Vec512{ _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; } HWY_API Vec512 Trunc(const Vec512 v) { return Vec512{ _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; } // Toward +infinity, aka ceiling HWY_API Vec512 Ceil(const Vec512 v) { return Vec512{ _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; } HWY_API Vec512 Ceil(const Vec512 v) { return Vec512{ _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; } // Toward -infinity, aka floor HWY_API Vec512 Floor(const Vec512 v) { return Vec512{ _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; } HWY_API Vec512 Floor(const Vec512 v) { return Vec512{ _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; } HWY_DIAGNOSTICS(pop) // ================================================== COMPARE // Comparisons set a mask bit to 1 if the condition is true, else 0. 
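// Illustrative sketch (uses the operators defined below): comparing two
// Vec512<uint8_t> yields a Mask512<uint8_t> whose __mmask64 raw value has one
// bit per byte lane; it can then feed IfThenElse/IfThenElseZero:
//   const Full512<uint8_t> d;
//   const Mask512<uint8_t> gt = Set(d, uint8_t{3}) > Set(d, uint8_t{2});
//   const Vec512<uint8_t> only_gt = IfThenElseZero(gt, Set(d, uint8_t{1}));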
template HWY_API Mask512 RebindMask(Full512 /*tag*/, Mask512 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return Mask512{m.raw}; } namespace detail { template HWY_INLINE Mask512 TestBit(hwy::SizeTag<1> /*tag*/, const Vec512 v, const Vec512 bit) { return Mask512{_mm512_test_epi8_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask512 TestBit(hwy::SizeTag<2> /*tag*/, const Vec512 v, const Vec512 bit) { return Mask512{_mm512_test_epi16_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask512 TestBit(hwy::SizeTag<4> /*tag*/, const Vec512 v, const Vec512 bit) { return Mask512{_mm512_test_epi32_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask512 TestBit(hwy::SizeTag<8> /*tag*/, const Vec512 v, const Vec512 bit) { return Mask512{_mm512_test_epi64_mask(v.raw, bit.raw)}; } } // namespace detail template HWY_API Mask512 TestBit(const Vec512 v, const Vec512 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return detail::TestBit(hwy::SizeTag(), v, bit); } // ------------------------------ Equality template HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpeq_epi64_mask(a.raw, b.raw)}; } HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } // ------------------------------ Inequality template HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpneq_epi8_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpneq_epi16_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpneq_epi32_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpneq_epi64_mask(a.raw, b.raw)}; } HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; } HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; } // ------------------------------ Strict inequality HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epu8_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epu16_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epu32_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epu64_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epi8_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epi16_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epi32_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epi64_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return 
Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; } // ------------------------------ Weak inequality HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; } HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; } // ------------------------------ Reversed comparisons template HWY_API Mask512 operator<(Vec512 a, Vec512 b) { return b > a; } template HWY_API Mask512 operator<=(Vec512 a, Vec512 b) { return b >= a; } // ------------------------------ Mask namespace detail { template HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec512 v) { return Mask512{_mm512_movepi8_mask(v.raw)}; } template HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec512 v) { return Mask512{_mm512_movepi16_mask(v.raw)}; } template HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec512 v) { return Mask512{_mm512_movepi32_mask(v.raw)}; } template HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec512 v) { return Mask512{_mm512_movepi64_mask(v.raw)}; } } // namespace detail template HWY_API Mask512 MaskFromVec(const Vec512 v) { return detail::MaskFromVec(hwy::SizeTag(), v); } // There do not seem to be native floating-point versions of these instructions. HWY_API Mask512 MaskFromVec(const Vec512 v) { return Mask512{MaskFromVec(BitCast(Full512(), v)).raw}; } HWY_API Mask512 MaskFromVec(const Vec512 v) { return Mask512{MaskFromVec(BitCast(Full512(), v)).raw}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi8(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi8(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi16(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi16(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi32(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi32(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_castsi512_ps(_mm512_movm_epi32(v.raw))}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi64(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi64(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_castsi512_pd(_mm512_movm_epi64(v.raw))}; } template HWY_API Vec512 VecFromMask(Full512 /* tag */, const Mask512 v) { return VecFromMask(v); } // ------------------------------ Mask logical namespace detail { template HWY_INLINE Mask512 Not(hwy::SizeTag<1> /*tag*/, const Mask512 m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_knot_mask64(m.raw)}; #else return Mask512{~m.raw}; #endif } template HWY_INLINE Mask512 Not(hwy::SizeTag<2> /*tag*/, const Mask512 m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_knot_mask32(m.raw)}; #else return Mask512{~m.raw}; #endif } template HWY_INLINE Mask512 Not(hwy::SizeTag<4> /*tag*/, const Mask512 m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_knot_mask16(m.raw)}; #else return Mask512{static_cast(~m.raw & 0xFFFF)}; #endif } template HWY_INLINE Mask512 Not(hwy::SizeTag<8> /*tag*/, const Mask512 m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_knot_mask8(m.raw)}; #else return Mask512{static_cast(~m.raw & 0xFF)}; #endif } template HWY_INLINE Mask512 And(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS 
return Mask512{_kand_mask64(a.raw, b.raw)}; #else return Mask512{a.raw & b.raw}; #endif } template HWY_INLINE Mask512 And(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kand_mask32(a.raw, b.raw)}; #else return Mask512{a.raw & b.raw}; #endif } template HWY_INLINE Mask512 And(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kand_mask16(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw & b.raw)}; #endif } template HWY_INLINE Mask512 And(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kand_mask8(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw & b.raw)}; #endif } template HWY_INLINE Mask512 AndNot(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kandn_mask64(a.raw, b.raw)}; #else return Mask512{~a.raw & b.raw}; #endif } template HWY_INLINE Mask512 AndNot(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kandn_mask32(a.raw, b.raw)}; #else return Mask512{~a.raw & b.raw}; #endif } template HWY_INLINE Mask512 AndNot(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kandn_mask16(a.raw, b.raw)}; #else return Mask512{static_cast(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask512 AndNot(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kandn_mask8(a.raw, b.raw)}; #else return Mask512{static_cast(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask512 Or(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kor_mask64(a.raw, b.raw)}; #else return Mask512{a.raw | b.raw}; #endif } template HWY_INLINE Mask512 Or(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kor_mask32(a.raw, b.raw)}; #else return Mask512{a.raw | b.raw}; #endif } template HWY_INLINE Mask512 Or(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kor_mask16(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw | b.raw)}; #endif } template HWY_INLINE Mask512 Or(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kor_mask8(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw | b.raw)}; #endif } template HWY_INLINE Mask512 Xor(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxor_mask64(a.raw, b.raw)}; #else return Mask512{a.raw ^ b.raw}; #endif } template HWY_INLINE Mask512 Xor(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxor_mask32(a.raw, b.raw)}; #else return Mask512{a.raw ^ b.raw}; #endif } template HWY_INLINE Mask512 Xor(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxor_mask16(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask512 Xor(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxor_mask8(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask512 
ExclusiveNeither(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxnor_mask64(a.raw, b.raw)}; #else return Mask512{~(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxnor_mask32(a.raw, b.raw)}; #else return Mask512{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; #endif } template HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxnor_mask16(a.raw, b.raw)}; #else return Mask512{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; #endif } template HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxnor_mask8(a.raw, b.raw)}; #else return Mask512{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; #endif } } // namespace detail template HWY_API Mask512 Not(const Mask512 m) { return detail::Not(hwy::SizeTag(), m); } template HWY_API Mask512 And(const Mask512 a, Mask512 b) { return detail::And(hwy::SizeTag(), a, b); } template HWY_API Mask512 AndNot(const Mask512 a, Mask512 b) { return detail::AndNot(hwy::SizeTag(), a, b); } template HWY_API Mask512 Or(const Mask512 a, Mask512 b) { return detail::Or(hwy::SizeTag(), a, b); } template HWY_API Mask512 Xor(const Mask512 a, Mask512 b) { return detail::Xor(hwy::SizeTag(), a, b); } template HWY_API Mask512 ExclusiveNeither(const Mask512 a, Mask512 b) { return detail::ExclusiveNeither(hwy::SizeTag(), a, b); } // ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) HWY_API Vec512 BroadcastSignBit(const Vec512 v) { return VecFromMask(v < Zero(Full512())); } HWY_API Vec512 BroadcastSignBit(const Vec512 v) { return ShiftRight<15>(v); } HWY_API Vec512 BroadcastSignBit(const Vec512 v) { return ShiftRight<31>(v); } HWY_API Vec512 BroadcastSignBit(const Vec512 v) { return Vec512{_mm512_srai_epi64(v.raw, 63)}; } // ------------------------------ Floating-point classification (Not) HWY_API Mask512 IsNaN(const Vec512 v) { return Mask512{_mm512_fpclass_ps_mask(v.raw, 0x81)}; } HWY_API Mask512 IsNaN(const Vec512 v) { return Mask512{_mm512_fpclass_pd_mask(v.raw, 0x81)}; } HWY_API Mask512 IsInf(const Vec512 v) { return Mask512{_mm512_fpclass_ps_mask(v.raw, 0x18)}; } HWY_API Mask512 IsInf(const Vec512 v) { return Mask512{_mm512_fpclass_pd_mask(v.raw, 0x18)}; } // Returns whether normal/subnormal/zero. fpclass doesn't have a flag for // positive, so we have to check for inf/NaN and negate. 
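// The fpclass immediates above/below select categories by bit: 0x01 QNaN,
// 0x08 +Inf, 0x10 -Inf, 0x80 SNaN. Hence 0x81 = any NaN (IsNaN),
// 0x18 = either infinity (IsInf), and 0x99 = NaN or infinity, whose
// complement is IsFinite.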
HWY_API Mask512 IsFinite(const Vec512 v) { return Not(Mask512{_mm512_fpclass_ps_mask(v.raw, 0x99)}); } HWY_API Mask512 IsFinite(const Vec512 v) { return Not(Mask512{_mm512_fpclass_pd_mask(v.raw, 0x99)}); } // ================================================== MEMORY // ------------------------------ Load template HWY_API Vec512 Load(Full512 /* tag */, const T* HWY_RESTRICT aligned) { return Vec512{_mm512_load_si512(aligned)}; } HWY_API Vec512 Load(Full512 /* tag */, const float* HWY_RESTRICT aligned) { return Vec512{_mm512_load_ps(aligned)}; } HWY_API Vec512 Load(Full512 /* tag */, const double* HWY_RESTRICT aligned) { return Vec512{_mm512_load_pd(aligned)}; } template HWY_API Vec512 LoadU(Full512 /* tag */, const T* HWY_RESTRICT p) { return Vec512{_mm512_loadu_si512(p)}; } HWY_API Vec512 LoadU(Full512 /* tag */, const float* HWY_RESTRICT p) { return Vec512{_mm512_loadu_ps(p)}; } HWY_API Vec512 LoadU(Full512 /* tag */, const double* HWY_RESTRICT p) { return Vec512{_mm512_loadu_pd(p)}; } // ------------------------------ MaskedLoad template HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const T* HWY_RESTRICT p) { return Vec512{_mm512_maskz_loadu_epi8(m.raw, p)}; } template HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const T* HWY_RESTRICT p) { return Vec512{_mm512_maskz_loadu_epi16(m.raw, p)}; } template HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const T* HWY_RESTRICT p) { return Vec512{_mm512_maskz_loadu_epi32(m.raw, p)}; } template HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const T* HWY_RESTRICT p) { return Vec512{_mm512_maskz_loadu_epi64(m.raw, p)}; } HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const float* HWY_RESTRICT p) { return Vec512{_mm512_maskz_loadu_ps(m.raw, p)}; } HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const double* HWY_RESTRICT p) { return Vec512{_mm512_maskz_loadu_pd(m.raw, p)}; } // ------------------------------ LoadDup128 // Loads 128 bit and duplicates into both 128-bit halves. This avoids the // 3-cycle cost of moving data between 128-bit halves and avoids port 5. 
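// Illustrative sketch (uses LoadDup128 as defined below): broadcasting a
// 16-byte pattern, e.g. a 4-entry f32 table, into every 128-bit block:
//   alignas(16) const float tbl[4] = {0.0f, 1.0f, 2.0f, 3.0f};
//   const Vec512<float> v = LoadDup128(Full512<float>(), tbl);
//   // v = {0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3}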
template HWY_API Vec512 LoadDup128(Full512 /* tag */, const T* const HWY_RESTRICT p) { const auto x4 = LoadU(Full128(), p); return Vec512{_mm512_broadcast_i32x4(x4.raw)}; } HWY_API Vec512 LoadDup128(Full512 /* tag */, const float* const HWY_RESTRICT p) { const __m128 x4 = _mm_loadu_ps(p); return Vec512{_mm512_broadcast_f32x4(x4)}; } HWY_API Vec512 LoadDup128(Full512 /* tag */, const double* const HWY_RESTRICT p) { const __m128d x2 = _mm_loadu_pd(p); return Vec512{_mm512_broadcast_f64x2(x2)}; } // ------------------------------ Store template HWY_API void Store(const Vec512 v, Full512 /* tag */, T* HWY_RESTRICT aligned) { _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); } HWY_API void Store(const Vec512 v, Full512 /* tag */, float* HWY_RESTRICT aligned) { _mm512_store_ps(aligned, v.raw); } HWY_API void Store(const Vec512 v, Full512 /* tag */, double* HWY_RESTRICT aligned) { _mm512_store_pd(aligned, v.raw); } template HWY_API void StoreU(const Vec512 v, Full512 /* tag */, T* HWY_RESTRICT p) { _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); } HWY_API void StoreU(const Vec512 v, Full512 /* tag */, float* HWY_RESTRICT p) { _mm512_storeu_ps(p, v.raw); } HWY_API void StoreU(const Vec512 v, Full512, double* HWY_RESTRICT p) { _mm512_storeu_pd(p, v.raw); } // ------------------------------ BlendedStore template HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, T* HWY_RESTRICT p) { _mm512_mask_storeu_epi8(p, m.raw, v.raw); } template HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, T* HWY_RESTRICT p) { _mm512_mask_storeu_epi16(p, m.raw, v.raw); } template HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, T* HWY_RESTRICT p) { _mm512_mask_storeu_epi32(p, m.raw, v.raw); } template HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, T* HWY_RESTRICT p) { _mm512_mask_storeu_epi64(p, m.raw, v.raw); } HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, float* HWY_RESTRICT p) { _mm512_mask_storeu_ps(p, m.raw, v.raw); } HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, double* HWY_RESTRICT p) { _mm512_mask_storeu_pd(p, m.raw, v.raw); } // ------------------------------ Non-temporal stores template HWY_API void Stream(const Vec512 v, Full512 /* tag */, T* HWY_RESTRICT aligned) { _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), v.raw); } HWY_API void Stream(const Vec512 v, Full512 /* tag */, float* HWY_RESTRICT aligned) { _mm512_stream_ps(aligned, v.raw); } HWY_API void Stream(const Vec512 v, Full512, double* HWY_RESTRICT aligned) { _mm512_stream_pd(aligned, v.raw); } // ------------------------------ Scatter // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") namespace detail { template HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec512 v, Full512 /* tag */, T* HWY_RESTRICT base, const Vec512 offset) { _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); } template HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec512 v, Full512 /* tag */, T* HWY_RESTRICT base, const Vec512 index) { _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); } template HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec512 v, Full512 /* tag */, T* HWY_RESTRICT base, const Vec512 offset) { _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); } template HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec512 v, Full512 /* tag */, T* HWY_RESTRICT base, const Vec512 index) { _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); } } // namespace detail template HWY_API void ScatterOffset(Vec512 v, Full512 d, T* HWY_RESTRICT base, const Vec512 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); } template HWY_API void ScatterIndex(Vec512 v, Full512 d, T* HWY_RESTRICT base, const Vec512 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); } HWY_API void ScatterOffset(Vec512 v, Full512 /* tag */, float* HWY_RESTRICT base, const Vec512 offset) { _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); } HWY_API void ScatterIndex(Vec512 v, Full512 /* tag */, float* HWY_RESTRICT base, const Vec512 index) { _mm512_i32scatter_ps(base, index.raw, v.raw, 4); } HWY_API void ScatterOffset(Vec512 v, Full512 /* tag */, double* HWY_RESTRICT base, const Vec512 offset) { _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); } HWY_API void ScatterIndex(Vec512 v, Full512 /* tag */, double* HWY_RESTRICT base, const Vec512 index) { _mm512_i64scatter_pd(base, index.raw, v.raw, 8); } // ------------------------------ Gather namespace detail { template HWY_INLINE Vec512 GatherOffset(hwy::SizeTag<4> /* tag */, Full512 /* tag */, const T* HWY_RESTRICT base, const Vec512 offset) { return Vec512{_mm512_i32gather_epi32(offset.raw, base, 1)}; } template HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<4> /* tag */, Full512 /* tag */, const T* HWY_RESTRICT base, const Vec512 index) { return Vec512{_mm512_i32gather_epi32(index.raw, base, 4)}; } template HWY_INLINE Vec512 GatherOffset(hwy::SizeTag<8> /* tag */, Full512 /* tag */, const T* HWY_RESTRICT base, const Vec512 offset) { return Vec512{_mm512_i64gather_epi64(offset.raw, base, 1)}; } template HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<8> /* tag */, Full512 /* tag */, const T* HWY_RESTRICT base, const Vec512 index) { return Vec512{_mm512_i64gather_epi64(index.raw, base, 8)}; } } // namespace detail template HWY_API Vec512 GatherOffset(Full512 d, const T* HWY_RESTRICT base, const Vec512 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::GatherOffset(hwy::SizeTag(), d, base, offset); } template HWY_API Vec512 GatherIndex(Full512 d, const T* HWY_RESTRICT base, const Vec512 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::GatherIndex(hwy::SizeTag(), d, base, index); } HWY_API Vec512 GatherOffset(Full512 /* tag */, const float* HWY_RESTRICT base, const Vec512 offset) { return Vec512{_mm512_i32gather_ps(offset.raw, base, 1)}; } HWY_API Vec512 GatherIndex(Full512 /* 
tag */, const float* HWY_RESTRICT base, const Vec512 index) { return Vec512{_mm512_i32gather_ps(index.raw, base, 4)}; } HWY_API Vec512 GatherOffset(Full512 /* tag */, const double* HWY_RESTRICT base, const Vec512 offset) { return Vec512{_mm512_i64gather_pd(offset.raw, base, 1)}; } HWY_API Vec512 GatherIndex(Full512 /* tag */, const double* HWY_RESTRICT base, const Vec512 index) { return Vec512{_mm512_i64gather_pd(index.raw, base, 8)}; } HWY_DIAGNOSTICS(pop) // ================================================== SWIZZLE // ------------------------------ LowerHalf template HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_castsi512_si256(v.raw)}; } HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_castps512_ps256(v.raw)}; } HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_castpd512_pd256(v.raw)}; } template HWY_API Vec256 LowerHalf(Vec512 v) { return LowerHalf(Full256(), v); } // ------------------------------ UpperHalf template HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_extracti32x8_epi32(v.raw, 1)}; } HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_extractf32x8_ps(v.raw, 1)}; } HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_extractf64x4_pd(v.raw, 1)}; } // ------------------------------ ExtractLane (Store) template HWY_API T ExtractLane(const Vec512 v, size_t i) { const Full512 d; HWY_DASSERT(i < Lanes(d)); alignas(64) T lanes[64 / sizeof(T)]; Store(v, d, lanes); return lanes[i]; } // ------------------------------ InsertLane (Store) template HWY_API Vec512 InsertLane(const Vec512 v, size_t i, T t) { const Full512 d; HWY_DASSERT(i < Lanes(d)); alignas(64) T lanes[64 / sizeof(T)]; Store(v, d, lanes); lanes[i] = t; return Load(d, lanes); } // ------------------------------ GetLane (LowerHalf) template HWY_API T GetLane(const Vec512 v) { return GetLane(LowerHalf(v)); } // ------------------------------ ZeroExtendVector template HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, Vec256 lo) { #if HWY_HAVE_ZEXT // See definition/comment in x86_256-inl.h. 
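  // The zext cast below conveys that the upper 256 bits are zero, which
  // typically lets the compiler avoid an extra zeroing instruction; the
  // #else path obtains the same result by inserting into a zeroed vector.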
return Vec512{_mm512_zextsi256_si512(lo.raw)}; #else return Vec512{_mm512_inserti32x8(_mm512_setzero_si512(), lo.raw, 0)}; #endif } HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, Vec256 lo) { #if HWY_HAVE_ZEXT return Vec512{_mm512_zextps256_ps512(lo.raw)}; #else return Vec512{_mm512_insertf32x8(_mm512_setzero_ps(), lo.raw, 0)}; #endif } HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, Vec256 lo) { #if HWY_HAVE_ZEXT return Vec512{_mm512_zextpd256_pd512(lo.raw)}; #else return Vec512{_mm512_insertf64x4(_mm512_setzero_pd(), lo.raw, 0)}; #endif } // ------------------------------ Combine template HWY_API Vec512 Combine(Full512 d, Vec256 hi, Vec256 lo) { const auto lo512 = ZeroExtendVector(d, lo); return Vec512{_mm512_inserti32x8(lo512.raw, hi.raw, 1)}; } HWY_API Vec512 Combine(Full512 d, Vec256 hi, Vec256 lo) { const auto lo512 = ZeroExtendVector(d, lo); return Vec512{_mm512_insertf32x8(lo512.raw, hi.raw, 1)}; } HWY_API Vec512 Combine(Full512 d, Vec256 hi, Vec256 lo) { const auto lo512 = ZeroExtendVector(d, lo); return Vec512{_mm512_insertf64x4(lo512.raw, hi.raw, 1)}; } // ------------------------------ ShiftLeftBytes template HWY_API Vec512 ShiftLeftBytes(Full512 /* tag */, const Vec512 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); return Vec512{_mm512_bslli_epi128(v.raw, kBytes)}; } template HWY_API Vec512 ShiftLeftBytes(const Vec512 v) { return ShiftLeftBytes(Full512(), v); } // ------------------------------ ShiftLeftLanes template HWY_API Vec512 ShiftLeftLanes(Full512 d, const Vec512 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec512 ShiftLeftLanes(const Vec512 v) { return ShiftLeftLanes(Full512(), v); } // ------------------------------ ShiftRightBytes template HWY_API Vec512 ShiftRightBytes(Full512 /* tag */, const Vec512 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); return Vec512{_mm512_bsrli_epi128(v.raw, kBytes)}; } // ------------------------------ ShiftRightLanes template HWY_API Vec512 ShiftRightLanes(Full512 d, const Vec512 v) { const Repartition d8; return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ------------------------------ CombineShiftRightBytes template > HWY_API V CombineShiftRightBytes(Full512 d, V hi, V lo) { const Repartition d8; return BitCast(d, Vec512{_mm512_alignr_epi8( BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); } // ------------------------------ Broadcast/splat any lane // Unsigned template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); if (kLane < 4) { const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); return Vec512{_mm512_unpacklo_epi64(lo, lo)}; } else { const __m512i hi = _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); return Vec512{_mm512_unpackhi_epi64(hi, hi)}; } } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); constexpr _MM_PERM_ENUM perm = kLane ? 
_MM_PERM_DCDC : _MM_PERM_BABA; return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; } // Signed template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); if (kLane < 4) { const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); return Vec512{_mm512_unpacklo_epi64(lo, lo)}; } else { const __m512i hi = _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); return Vec512{_mm512_unpackhi_epi64(hi, hi)}; } } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; } // Float template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); return Vec512{_mm512_shuffle_ps(v.raw, v.raw, perm)}; } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); return Vec512{_mm512_shuffle_pd(v.raw, v.raw, perm)}; } // ------------------------------ Hard-coded shuffles // Notation: let Vec512 have lanes 7,6,5,4,3,2,1,0 (0 is // least-significant). Shuffle0321 rotates four-lane blocks one lane to the // right (the previous least-significant lane is now most-significant => // 47650321). These could also be implemented via CombineShiftRightBytes but // the shuffle_abcd notation is more convenient. // Swap 32-bit halves in 64-bit halves. template HWY_API Vec512 Shuffle2301(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 Shuffle2301(const Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; } namespace detail { template HWY_API Vec512 Shuffle2301(const Vec512 a, const Vec512 b) { const Full512 d; const RebindToFloat df; return BitCast( d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, _MM_PERM_CDAB)}); } template HWY_API Vec512 Shuffle1230(const Vec512 a, const Vec512 b) { const Full512 d; const RebindToFloat df; return BitCast( d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, _MM_PERM_BCDA)}); } template HWY_API Vec512 Shuffle3012(const Vec512 a, const Vec512 b) { const Full512 d; const RebindToFloat df; return BitCast( d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, _MM_PERM_DABC)}); } } // namespace detail // Swap 64-bit halves HWY_API Vec512 Shuffle1032(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle1032(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle1032(const Vec512 v) { // Shorter encoding than _mm512_permute_ps. return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle01(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle01(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle01(const Vec512 v) { // Shorter encoding than _mm512_permute_pd. 
return Vec512{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; } // Rotate right 32 bits HWY_API Vec512 Shuffle0321(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; } HWY_API Vec512 Shuffle0321(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; } HWY_API Vec512 Shuffle0321(const Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; } // Rotate left 32 bits HWY_API Vec512 Shuffle2103(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; } HWY_API Vec512 Shuffle2103(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; } HWY_API Vec512 Shuffle2103(const Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; } // Reverse HWY_API Vec512 Shuffle0123(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; } HWY_API Vec512 Shuffle0123(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; } HWY_API Vec512 Shuffle0123(const Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; } // ------------------------------ TableLookupLanes // Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. template struct Indices512 { __m512i raw; }; template HWY_API Indices512 IndicesFromVec(Full512 /* tag */, Vec512 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD const Full512 di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, static_cast(64 / sizeof(T)))))); #endif return Indices512{vec.raw}; } template HWY_API Indices512 SetTableIndices(const Full512 d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } template HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { return Vec512{_mm512_permutexvar_epi32(idx.raw, v.raw)}; } template HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { return Vec512{_mm512_permutexvar_epi64(idx.raw, v.raw)}; } HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { return Vec512{_mm512_permutexvar_ps(idx.raw, v.raw)}; } HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { return Vec512{_mm512_permutexvar_pd(idx.raw, v.raw)}; } // ------------------------------ Reverse template HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { const RebindToSigned di; alignas(64) constexpr int16_t kReverse[32] = { 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; const Vec512 idx = Load(di, kReverse); return BitCast(d, Vec512{ _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); } template HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { alignas(64) constexpr int32_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; return TableLookupLanes(v, SetTableIndices(d, kReverse)); } template HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { alignas(64) constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; return TableLookupLanes(v, SetTableIndices(d, kReverse)); } // ------------------------------ Reverse2 template HWY_API Vec512 Reverse2(Full512 d, const Vec512 v) { const Full512 du32; return BitCast(d, RotateRight<16>(BitCast(du32, v))); } template HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { return Shuffle2301(v); } template HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { return Shuffle01(v); } // ------------------------------ Reverse4 template HWY_API Vec512 Reverse4(Full512 d, const Vec512 v) { const RebindToSigned di; 
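  // Reverse each group of four consecutive 16-bit lanes: 3,2,1,0 then
  // 7,6,5,4 and so on, as encoded by the index table below.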
alignas(64) constexpr int16_t kReverse4[32] = { 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; const Vec512 idx = Load(di, kReverse4); return BitCast(d, Vec512{ _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); } template HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { return Shuffle0123(v); } template HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { return Vec512{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; } HWY_API Vec512 Reverse4(Full512 /* tag */, Vec512 v) { return Vec512{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; } // ------------------------------ Reverse8 template HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { const RebindToSigned di; alignas(64) constexpr int16_t kReverse8[32] = { 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; const Vec512 idx = Load(di, kReverse8); return BitCast(d, Vec512{ _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); } template HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { const RebindToSigned di; alignas(64) constexpr int32_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; const Vec512 idx = Load(di, kReverse8); return BitCast(d, Vec512{ _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); } template HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { return Reverse(d, v); } // ------------------------------ InterleaveLower // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides // the least-significant lane) and "b". To concatenate two half-width integers // into one, use ZipLower/Upper instead (also works with scalar). HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi16(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi16(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_ps(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_pd(a.raw, b.raw)}; } // ------------------------------ InterleaveUpper // All functions inside detail lack the required D parameter. 
namespace detail { HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi8(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi16(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi32(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi64(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi8(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi16(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi32(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi64(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_ps(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_pd(a.raw, b.raw)}; } } // namespace detail template > HWY_API V InterleaveUpper(Full512 /* tag */, V a, V b) { return detail::InterleaveUpper(a, b); } // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. template > HWY_API Vec512 ZipLower(Vec512 a, Vec512 b) { return BitCast(Full512(), InterleaveLower(a, b)); } template > HWY_API Vec512 ZipLower(Full512 /* d */, Vec512 a, Vec512 b) { return BitCast(Full512(), InterleaveLower(a, b)); } template > HWY_API Vec512 ZipUpper(Full512 d, Vec512 a, Vec512 b) { return BitCast(Full512(), InterleaveUpper(d, a, b)); } // ------------------------------ Concat* halves // hiH,hiL loH,loL |-> hiL,loL (= lower halves) template HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; } HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; } HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; } // hiH,hiL loH,loL |-> hiH,loH (= upper halves) template HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; } HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; } HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; } // hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) template HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; } HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; } HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, 
_MM_PERM_BADC)}; } // hiH,hiL loH,loL |-> hiH,loL (= outer halves) template HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks // are efficiently loaded from 32-bit regs. const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); return Vec512{_mm512_mask_blend_epi16(mask, hi.raw, lo.raw)}; } HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); return Vec512{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; } HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); return Vec512{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; } // ------------------------------ ConcatOdd template HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; #if HWY_TARGET == HWY_AVX3_DL alignas(64) constexpr uint8_t kIdx[64] = { 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi8( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask64{0xFFFFFFFFFFFFFFFFull}, BitCast(du, hi).raw)}); #else const RepartitionToWide dw; // Right-shift 8 bits per u16 so we can pack. const Vec512 uH = ShiftRight<8>(BitCast(dw, hi)); const Vec512 uL = ShiftRight<8>(BitCast(dw, lo)); const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. const Full512 du64; alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); #endif } template HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint16_t kIdx[32] = { 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi16( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); } template HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi32( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, BitCast(du, hi).raw)}); } HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; return Vec512{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, hi.raw)}; } template HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi64( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); } HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; return Vec512{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; } // ------------------------------ 
ConcatEven template HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; #if HWY_TARGET == HWY_AVX3_DL alignas(64) constexpr uint8_t kIdx[64] = { 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi8( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask64{0xFFFFFFFFFFFFFFFFull}, BitCast(du, hi).raw)}); #else const RepartitionToWide dw; // Isolate lower 8 bits per u16 so we can pack. const Vec512 mask = Set(dw, 0x00FF); const Vec512 uH = And(BitCast(dw, hi), mask); const Vec512 uL = And(BitCast(dw, lo), mask); const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. const Full512 du64; alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); #endif } template HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint16_t kIdx[32] = { 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi16( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); } template HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi32( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, BitCast(du, hi).raw)}); } HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; return Vec512{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, hi.raw)}; } template HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi64( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); } HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; return Vec512{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; } // ------------------------------ DupEven (InterleaveLower) template HWY_API Vec512 DupEven(Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; } HWY_API Vec512 DupEven(Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; } template HWY_API Vec512 DupEven(const Vec512 v) { return InterleaveLower(Full512(), v, v); } // ------------------------------ DupOdd (InterleaveUpper) template HWY_API Vec512 DupOdd(Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)}; } HWY_API Vec512 DupOdd(Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; } template HWY_API Vec512 DupOdd(const Vec512 v) { return InterleaveUpper(Full512(), v, v); } // ------------------------------ OddEven template HWY_API Vec512 OddEven(const Vec512 a, const Vec512 b) { 
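  // One mask bit per lane: 0x5555... selects the even-indexed lanes (where b
  // is taken). Wider lanes need fewer mask bits, so the constant is shifted
  // right to drop the surplus: by 32 for 2-byte, 48 for 4-byte and 56 for
  // 8-byte lanes.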
constexpr size_t s = sizeof(T); constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; return IfThenElse(Mask512{0x5555555555555555ull >> shift}, b, a); } // ------------------------------ OddEvenBlocks template HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { return Vec512{_mm512_mask_blend_epi64(__mmask8{0x33u}, odd.raw, even.raw)}; } HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { return Vec512{ _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)}; } HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { return Vec512{ _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)}; } // ------------------------------ SwapAdjacentBlocks template HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; } // ------------------------------ ReverseBlocks template HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_ABCD)}; } HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; } HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; } // ------------------------------ TableLookupBytes (ZeroExtendVector) // Both full template HWY_API Vec512 TableLookupBytes(Vec512 bytes, Vec512 indices) { return Vec512{_mm512_shuffle_epi8(bytes.raw, indices.raw)}; } // Partial index vector template HWY_API Vec128 TableLookupBytes(Vec512 bytes, Vec128 from) { const Full512 d512; const Half d256; const Half d128; // First expand to full 128, then 256, then 512. const Vec128 from_full{from.raw}; const auto from_512 = ZeroExtendVector(d512, ZeroExtendVector(d256, from_full)); const auto tbl_full = TableLookupBytes(bytes, from_512); // Shrink to 256, then 128, then partial. return Vec128{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw}; } template HWY_API Vec256 TableLookupBytes(Vec512 bytes, Vec256 from) { const auto from_512 = ZeroExtendVector(Full512(), from); return LowerHalf(Full256(), TableLookupBytes(bytes, from_512)); } // Partial table vector template HWY_API Vec512 TableLookupBytes(Vec128 bytes, Vec512 from) { const Full512 d512; const Half d256; const Half d128; // First expand to full 128, then 256, then 512. const Vec128 bytes_full{bytes.raw}; const auto bytes_512 = ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); return TableLookupBytes(bytes_512, from); } template HWY_API Vec512 TableLookupBytes(Vec256 bytes, Vec512 from) { const auto bytes_512 = ZeroExtendVector(Full512(), bytes); return TableLookupBytes(bytes_512, from); } // Partial both are handled by x86_128/256. // ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) // Unsigned: zero-extend. // Note: these have 3 cycle latency; if inputs are already split across the // 128 bit blocks (in their upper/lower halves), then Zip* would be faster. 
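// Illustrative sketch (not part of this header), assuming `bytes` points to
// at least 32 readable uint8_t:
//   const Full512<uint16_t> d16;
//   const Full256<uint8_t> d8;
//   const Vec256<uint8_t> narrow = LoadU(d8, bytes);
//   const Vec512<uint16_t> wide = PromoteTo(d16, narrow);  // zero-extends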
HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu8_epi16(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec128 v) { return Vec512{_mm512_cvtepu8_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu8_epi16(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec128 v) { return Vec512{_mm512_cvtepu8_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu16_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu16_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu32_epi64(v.raw)}; } // Signed: replicate sign bit. // Note: these have 3 cycle latency; if inputs are already split across the // 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by // signed shift would be faster. HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepi8_epi16(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec128 v) { return Vec512{_mm512_cvtepi8_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepi16_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepi32_epi64(v.raw)}; } // Float HWY_API Vec512 PromoteTo(Full512 /* tag */, const Vec256 v) { return Vec512{_mm512_cvtph_ps(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 df32, const Vec256 v) { const Rebind du16; const RebindToSigned di32; return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtps_pd(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepi32_pd(v.raw)}; } // ------------------------------ Demotions (full -> part w/ narrow lanes) HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; // Compress even u64 lanes into 256 bit. alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto idx64 = Load(Full512(), kLanes); const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; return LowerHalf(even); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; // Compress even u64 lanes into 256 bit. alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto idx64 = Load(Full512(), kLanes); const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; return LowerHalf(even); } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec512 v) { const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; // packus treats the input as signed; we want unsigned. Clear the MSB to get // unsigned saturation to u8. const Vec512 i16{ _mm512_and_si512(u16.raw, _mm512_set1_epi16(0x7FFF))}; const Vec512 u8{_mm512_packus_epi16(i16.raw, i16.raw)}; alignas(16) static constexpr uint32_t kLanes[4] = {0, 4, 8, 12}; const auto idx32 = LoadDup128(Full512(), kLanes); const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; return LowerHalf(LowerHalf(fixed)); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const Vec512 u8{_mm512_packus_epi16(v.raw, v.raw)}; // Compress even u64 lanes into 256 bit. 
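  // packus interleaves its two (here identical) inputs per 128-bit block, so
  // each block's packed bytes land in an even u64 lane; gathering lanes
  // 0,2,4,6 yields the bytes in original order in the lower 256 bits.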
alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto idx64 = Load(Full512(), kLanes); const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; return LowerHalf(even); } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec512 v) { const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; const Vec512 i8{_mm512_packs_epi16(i16.raw, i16.raw)}; alignas(16) static constexpr uint32_t kLanes[16] = {0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12}; const auto idx32 = LoadDup128(Full512(), kLanes); const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; return LowerHalf(LowerHalf(fixed)); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const Vec512 u8{_mm512_packs_epi16(v.raw, v.raw)}; // Compress even u64 lanes into 256 bit. alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto idx64 = Load(Full512(), kLanes); const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; return LowerHalf(even); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { // Work around warnings in the intrinsic definitions (passing -1 as a mask). HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Vec256{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; HWY_DIAGNOSTICS(pop) } HWY_API Vec256 DemoteTo(Full256 dbf16, const Vec512 v) { // TODO(janwas): _mm512_cvtneps_pbh once we have avx512bf16. const Rebind di32; const Rebind du32; // for logical shift right const Rebind du16; const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); return BitCast(dbf16, DemoteTo(du16, bits_in_32)); } HWY_API Vec512 ReorderDemote2To(Full512 dbf16, Vec512 a, Vec512 b) { // TODO(janwas): _mm512_cvtne2ps_pbh once we have avx512bf16. const RebindToUnsigned du16; const Repartition du32; const Vec512 b_in_even = ShiftRight<16>(BitCast(du32, b)); return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); } HWY_API Vec512 ReorderDemote2To(Full512 /*d16*/, Vec512 a, Vec512 b) { return Vec512{_mm512_packs_epi32(a.raw, b.raw)}; } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { return Vec256{_mm512_cvtpd_ps(v.raw)}; } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const auto clamped = detail::ClampF64ToI32Max(Full512(), v); return Vec256{_mm512_cvttpd_epi32(clamped.raw)}; } // For already range-limited input [0, 255]. HWY_API Vec128 U8FromU32(const Vec512 v) { const Full512 d32; // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the // lowest 4 bytes. alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, ~0u}; const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); // Gather the lowest 4 bytes of 4 128-bit blocks. 
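  // The gathered bytes now occupy the first 32-bit lane of each 128-bit
  // block, i.e. u32 lanes 0, 4, 8 and 12.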
alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; const Vec512 bytes{ _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; return LowerHalf(LowerHalf(bytes)); } // ------------------------------ Truncations HWY_API Vec128 TruncateTo(Simd d, const Vec512 v) { #if HWY_TARGET == HWY_AVX3_DL (void)d; const Full512 d8; alignas(16) static constexpr uint8_t k8From64[16] = { 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56}; const Vec512 bytes{ _mm512_permutexvar_epi8(LoadDup128(d8, k8From64).raw, v.raw)}; return LowerHalf(LowerHalf(LowerHalf(bytes))); #else const Full512 d32; alignas(64) constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14}; const Vec512 even{ _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; return TruncateTo(d, LowerHalf(even)); #endif } HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec512 v) { const Full512 d16; alignas(16) static constexpr uint16_t k16From64[8] = { 0, 4, 8, 12, 16, 20, 24, 28}; const Vec512 bytes{ _mm512_permutexvar_epi16(LoadDup128(d16, k16From64).raw, v.raw)}; return LowerHalf(LowerHalf(bytes)); } HWY_API Vec256 TruncateTo(Simd /* tag */, const Vec512 v) { const Full512 d32; alignas(64) constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14}; const Vec512 even{ _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; return LowerHalf(even); } HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec512 v) { #if HWY_TARGET == HWY_AVX3_DL const Full512 d8; alignas(16) static constexpr uint8_t k8From32[16] = { 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; const Vec512 bytes{ _mm512_permutexvar_epi32(LoadDup128(d8, k8From32).raw, v.raw)}; #else const Full512 d32; // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the // lowest 4 bytes. alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, ~0u}; const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); // Gather the lowest 4 bytes of 4 128-bit blocks. 
alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; const Vec512 bytes{ _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; #endif return LowerHalf(LowerHalf(bytes)); } HWY_API Vec256 TruncateTo(Simd /* tag */, const Vec512 v) { const Full512 d16; alignas(64) static constexpr uint16_t k16From32[32] = { 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; const Vec512 bytes{ _mm512_permutexvar_epi16(Load(d16, k16From32).raw, v.raw)}; return LowerHalf(bytes); } HWY_API Vec256 TruncateTo(Simd /* tag */, const Vec512 v) { #if HWY_TARGET == HWY_AVX3_DL const Full512 d8; alignas(64) static constexpr uint8_t k8From16[64] = { 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; const Vec512 bytes{ _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; #else const Full512 d32; alignas(16) static constexpr uint32_t k16From32[4] = { 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u}; const auto quads = TableLookupBytes(v, LoadDup128(d32, k16From32)); alignas(64) static constexpr uint32_t kIndex32[16] = { 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; const Vec512 bytes{ _mm512_permutexvar_epi32(Load(d32, kIndex32).raw, quads.raw)}; #endif return LowerHalf(bytes); } // ------------------------------ Convert integer <=> floating point HWY_API Vec512 ConvertTo(Full512 /* tag */, const Vec512 v) { return Vec512{_mm512_cvtepi32_ps(v.raw)}; } HWY_API Vec512 ConvertTo(Full512 /* tag */, const Vec512 v) { return Vec512{_mm512_cvtepi64_pd(v.raw)}; } HWY_API Vec512 ConvertTo(Full512 /* tag*/, const Vec512 v) { return Vec512{_mm512_cvtepu32_ps(v.raw)}; } HWY_API Vec512 ConvertTo(Full512 /* tag*/, const Vec512 v) { return Vec512{_mm512_cvtepu64_pd(v.raw)}; } // Truncates (rounds toward zero). HWY_API Vec512 ConvertTo(Full512 d, const Vec512 v) { return detail::FixConversionOverflow(d, v, _mm512_cvttps_epi32(v.raw)); } HWY_API Vec512 ConvertTo(Full512 di, const Vec512 v) { return detail::FixConversionOverflow(di, v, _mm512_cvttpd_epi64(v.raw)); } HWY_API Vec512 NearestInt(const Vec512 v) { const Full512 di; return detail::FixConversionOverflow(di, v, _mm512_cvtps_epi32(v.raw)); } // ================================================== CRYPTO #if !defined(HWY_DISABLE_PCLMUL_AES) // Per-target flag to prevent generic_ops-inl.h from defining AESRound. 
#ifdef HWY_NATIVE_AES
#undef HWY_NATIVE_AES
#else
#define HWY_NATIVE_AES
#endif

HWY_API Vec512<uint8_t> AESRound(Vec512<uint8_t> state,
                                 Vec512<uint8_t> round_key) {
#if HWY_TARGET == HWY_AVX3_DL
  return Vec512<uint8_t>{_mm512_aesenc_epi128(state.raw, round_key.raw)};
#else
  const Full512<uint8_t> d;
  const Half<decltype(d)> d2;
  return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)),
                 AESRound(LowerHalf(state), LowerHalf(round_key)));
#endif
}

HWY_API Vec512<uint8_t> AESLastRound(Vec512<uint8_t> state,
                                     Vec512<uint8_t> round_key) {
#if HWY_TARGET == HWY_AVX3_DL
  return Vec512<uint8_t>{_mm512_aesenclast_epi128(state.raw, round_key.raw)};
#else
  const Full512<uint8_t> d;
  const Half<decltype(d)> d2;
  return Combine(
      d, AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)),
      AESLastRound(LowerHalf(state), LowerHalf(round_key)));
#endif
}

HWY_API Vec512<uint64_t> CLMulLower(Vec512<uint64_t> va, Vec512<uint64_t> vb) {
#if HWY_TARGET == HWY_AVX3_DL
  return Vec512<uint64_t>{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)};
#else
  alignas(64) uint64_t a[8];
  alignas(64) uint64_t b[8];
  const Full512<uint64_t> d;
  const Full128<uint64_t> d128;
  Store(va, d, a);
  Store(vb, d, b);
  for (size_t i = 0; i < 8; i += 2) {
    const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i));
    Store(mul, d128, a + i);
  }
  return Load(d, a);
#endif
}

HWY_API Vec512<uint64_t> CLMulUpper(Vec512<uint64_t> va, Vec512<uint64_t> vb) {
#if HWY_TARGET == HWY_AVX3_DL
  return Vec512<uint64_t>{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)};
#else
  alignas(64) uint64_t a[8];
  alignas(64) uint64_t b[8];
  const Full512<uint64_t> d;
  const Full128<uint64_t> d128;
  Store(va, d, a);
  Store(vb, d, b);
  for (size_t i = 0; i < 8; i += 2) {
    const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i));
    Store(mul, d128, a + i);
  }
  return Load(d, a);
#endif
}

#endif  // HWY_DISABLE_PCLMUL_AES

// ================================================== MISC

// Returns a vector with lane i=[0, N) set to "first" + i.
template <typename T, typename T2>
Vec512<T> Iota(const Full512<T> d, const T2 first) {
  HWY_ALIGN T lanes[64 / sizeof(T)];
  for (size_t i = 0; i < 64 / sizeof(T); ++i) {
    lanes[i] =
        AddWithWraparound(hwy::IsFloatTag<T>(), static_cast<T>(first), i);
  }
  return Load(d, lanes);
}

// ------------------------------ Mask testing

// Beware: the suffix indicates the number of mask bits, not lane size!
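// For example, with T = uint32_t there are 16 lanes, so Mask512<uint32_t>::raw
// is an __mmask16 and the *_mask16_* kortest intrinsics below apply, even
// though each lane is 4 bytes wide.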
namespace detail { template HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask64_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask32_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask16_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask8_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } } // namespace detail template HWY_API bool AllFalse(const Full512 /* tag */, const Mask512 mask) { return detail::AllFalse(hwy::SizeTag(), mask); } namespace detail { template HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask64_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFFFFFFFFFFFFFull; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask32_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFFFFFull; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask16_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFull; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask8_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFull; #endif } } // namespace detail template HWY_API bool AllTrue(const Full512 /* tag */, const Mask512 mask) { return detail::AllTrue(hwy::SizeTag(), mask); } // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API Mask512 LoadMaskBits(const Full512 /* tag */, const uint8_t* HWY_RESTRICT bits) { Mask512 mask; CopyBytes<8 / sizeof(T)>(bits, &mask.raw); // N >= 8 (= 512 / 64), so no need to mask invalid bits. return mask; } // `p` points to at least 8 writable bytes. template HWY_API size_t StoreMaskBits(const Full512 /* tag */, const Mask512 mask, uint8_t* bits) { const size_t kNumBytes = 8 / sizeof(T); CopyBytes(&mask.raw, bits); // N >= 8 (= 512 / 64), so no need to mask invalid bits. return kNumBytes; } template HWY_API size_t CountTrue(const Full512 /* tag */, const Mask512 mask) { return PopCount(static_cast(mask.raw)); } template HWY_API size_t FindKnownFirstTrue(const Full512 /* tag */, const Mask512 mask) { return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); } template HWY_API size_t FindKnownFirstTrue(const Full512 /* tag */, const Mask512 mask) { return Num0BitsBelowLS1Bit_Nonzero64(mask.raw); } template HWY_API intptr_t FindFirstTrue(const Full512 d, const Mask512 mask) { return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) : intptr_t{-1}; } // ------------------------------ Compress // Always implement 8-bit here even if we lack VBMI2 because we can do better // than generic_ops (8 at a time) via the native 32-bit compress (16 at a time). 
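// Compress packs the lanes selected by the mask into the lower lanes.
// Illustrative sketch (not part of this header):
//   const Full512<uint32_t> d;
//   const Vec512<uint32_t> v = Iota(d, 0);  // 0,1,2,...,15
//   const Mask512<uint32_t> m = Eq(And(v, Set(d, 1u)), Zero(d));  // even lanes
//   const Vec512<uint32_t> packed = Compress(v, m);
// packed now holds 0,2,4,...,14 in its first eight lanes.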
#ifdef HWY_NATIVE_COMPRESS8 #undef HWY_NATIVE_COMPRESS8 #else #define HWY_NATIVE_COMPRESS8 #endif namespace detail { #if HWY_TARGET == HWY_AVX3_DL // VBMI2 template HWY_INLINE Vec128 NativeCompress(const Vec128 v, const Mask128 mask) { return Vec128{_mm_maskz_compress_epi8(mask.raw, v.raw)}; } HWY_INLINE Vec256 NativeCompress(const Vec256 v, const Mask256 mask) { return Vec256{_mm256_maskz_compress_epi8(mask.raw, v.raw)}; } HWY_INLINE Vec512 NativeCompress(const Vec512 v, const Mask512 mask) { return Vec512{_mm512_maskz_compress_epi8(mask.raw, v.raw)}; } template HWY_INLINE Vec128 NativeCompress(const Vec128 v, const Mask128 mask) { return Vec128{_mm_maskz_compress_epi16(mask.raw, v.raw)}; } HWY_INLINE Vec256 NativeCompress(const Vec256 v, const Mask256 mask) { return Vec256{_mm256_maskz_compress_epi16(mask.raw, v.raw)}; } HWY_INLINE Vec512 NativeCompress(const Vec512 v, const Mask512 mask) { return Vec512{_mm512_maskz_compress_epi16(mask.raw, v.raw)}; } template HWY_INLINE void NativeCompressStore(Vec128 v, Mask128 mask, Simd /* d */, uint8_t* HWY_RESTRICT unaligned) { _mm_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); } HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, Full256 /* d */, uint8_t* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); } HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, Full512 /* d */, uint8_t* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); } template HWY_INLINE void NativeCompressStore(Vec128 v, Mask128 mask, Simd /* d */, uint16_t* HWY_RESTRICT unaligned) { _mm_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); } HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, Full256 /* d */, uint16_t* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); } HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, Full512 /* d */, uint16_t* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); } #endif // HWY_TARGET == HWY_AVX3_DL template HWY_INLINE Vec128 NativeCompress(const Vec128 v, const Mask128 mask) { return Vec128{_mm_maskz_compress_epi32(mask.raw, v.raw)}; } HWY_INLINE Vec256 NativeCompress(Vec256 v, Mask256 mask) { return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; } HWY_INLINE Vec512 NativeCompress(Vec512 v, Mask512 mask) { return Vec512{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; } // We use table-based compress for 64-bit lanes, see CompressIsPartition. 
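// (With the table-based approach the result is a full permutation: lanes not
// selected by the mask follow the selected ones rather than being zeroed.)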
template HWY_INLINE void NativeCompressStore(Vec128 v, Mask128 mask, Simd /* d */, uint32_t* HWY_RESTRICT unaligned) { _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); } HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, Full256 /* d */, uint32_t* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); } HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, Full512 /* d */, uint32_t* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); } template HWY_INLINE void NativeCompressStore(Vec128 v, Mask128 mask, Simd /* d */, uint64_t* HWY_RESTRICT unaligned) { _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); } HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, Full256 /* d */, uint64_t* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); } HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, Full512 /* d */, uint64_t* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); } // For u8x16 and <= u16x16 we can avoid store+load for Compress because there is // only a single compressed vector (u32x16). Other EmuCompress are implemented // after the EmuCompressStore they build upon. template HWY_INLINE Vec128 EmuCompress(Vec128 v, Mask128 mask) { const Simd d; const Rebind d32; const auto v0 = PromoteTo(d32, v); const uint64_t mask_bits{mask.raw}; // Mask type is __mmask16 if v is full 128, else __mmask8. using M32 = MFromD; const M32 m0{static_cast(mask_bits)}; return TruncateTo(d, Compress(v0, m0)); } template HWY_INLINE Vec128 EmuCompress(Vec128 v, Mask128 mask) { const Simd d; const Rebind di32; const RebindToUnsigned du32; const MFromD mask32{static_cast<__mmask8>(mask.raw)}; // DemoteTo is 2 ops, but likely lower latency than TruncateTo on SKX. // Only i32 -> u16 is supported, whereas NativeCompress expects u32. const VFromD v32 = BitCast(du32, PromoteTo(di32, v)); return DemoteTo(d, BitCast(di32, NativeCompress(v32, mask32))); } HWY_INLINE Vec256 EmuCompress(Vec256 v, Mask256 mask) { const Full256 d; const Rebind di32; const RebindToUnsigned du32; const Mask512 mask32{static_cast<__mmask16>(mask.raw)}; const Vec512 v32 = BitCast(du32, PromoteTo(di32, v)); return DemoteTo(d, BitCast(di32, NativeCompress(v32, mask32))); } // See above - small-vector EmuCompressStore are implemented via EmuCompress. template HWY_INLINE void EmuCompressStore(Vec128 v, Mask128 mask, Simd d, T* HWY_RESTRICT unaligned) { StoreU(EmuCompress(v, mask), d, unaligned); } HWY_INLINE void EmuCompressStore(Vec256 v, Mask256 mask, Full256 d, uint16_t* HWY_RESTRICT unaligned) { StoreU(EmuCompress(v, mask), d, unaligned); } // Main emulation logic for wider vector, starting with EmuCompressStore because // it is most convenient to merge pieces using memory (concatenating vectors at // byte offsets is difficult). 
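// The pattern below: split the vector into 128/256-bit pieces, promote each
// piece to 32-bit lanes, run the native 32-bit compress, narrow back, and
// store each piece at an offset advanced by the number of lanes selected in
// the preceding pieces.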
HWY_INLINE void EmuCompressStore(Vec256 v, Mask256 mask, Full256 d, uint8_t* HWY_RESTRICT unaligned) { const uint64_t mask_bits{mask.raw}; const Half dh; const Rebind d32; const Vec512 v0 = PromoteTo(d32, LowerHalf(v)); const Vec512 v1 = PromoteTo(d32, UpperHalf(dh, v)); const Mask512 m0{static_cast<__mmask16>(mask_bits & 0xFFFFu)}; const Mask512 m1{static_cast<__mmask16>(mask_bits >> 16)}; const Vec128 c0 = TruncateTo(dh, NativeCompress(v0, m0)); const Vec128 c1 = TruncateTo(dh, NativeCompress(v1, m1)); uint8_t* HWY_RESTRICT pos = unaligned; StoreU(c0, dh, pos); StoreU(c1, dh, pos + CountTrue(d32, m0)); } HWY_INLINE void EmuCompressStore(Vec512 v, Mask512 mask, Full512 d, uint8_t* HWY_RESTRICT unaligned) { const uint64_t mask_bits{mask.raw}; const Half> dq; const Rebind d32; HWY_ALIGN uint8_t lanes[64]; Store(v, d, lanes); const Vec512 v0 = PromoteTo(d32, LowerHalf(LowerHalf(v))); const Vec512 v1 = PromoteTo(d32, Load(dq, lanes + 16)); const Vec512 v2 = PromoteTo(d32, Load(dq, lanes + 32)); const Vec512 v3 = PromoteTo(d32, Load(dq, lanes + 48)); const Mask512 m0{static_cast<__mmask16>(mask_bits & 0xFFFFu)}; const Mask512 m1{ static_cast((mask_bits >> 16) & 0xFFFFu)}; const Mask512 m2{ static_cast((mask_bits >> 32) & 0xFFFFu)}; const Mask512 m3{static_cast<__mmask16>(mask_bits >> 48)}; const Vec128 c0 = TruncateTo(dq, NativeCompress(v0, m0)); const Vec128 c1 = TruncateTo(dq, NativeCompress(v1, m1)); const Vec128 c2 = TruncateTo(dq, NativeCompress(v2, m2)); const Vec128 c3 = TruncateTo(dq, NativeCompress(v3, m3)); uint8_t* HWY_RESTRICT pos = unaligned; StoreU(c0, dq, pos); pos += CountTrue(d32, m0); StoreU(c1, dq, pos); pos += CountTrue(d32, m1); StoreU(c2, dq, pos); pos += CountTrue(d32, m2); StoreU(c3, dq, pos); } HWY_INLINE void EmuCompressStore(Vec512 v, Mask512 mask, Full512 d, uint16_t* HWY_RESTRICT unaligned) { const Repartition di32; const RebindToUnsigned du32; const Half dh; const Vec512 promoted0 = BitCast(du32, PromoteTo(di32, LowerHalf(dh, v))); const Vec512 promoted1 = BitCast(du32, PromoteTo(di32, UpperHalf(dh, v))); const uint64_t mask_bits{mask.raw}; const uint64_t maskL = mask_bits & 0xFFFF; const uint64_t maskH = mask_bits >> 16; const Mask512 mask0{static_cast<__mmask16>(maskL)}; const Mask512 mask1{static_cast<__mmask16>(maskH)}; const Vec512 compressed0 = NativeCompress(promoted0, mask0); const Vec512 compressed1 = NativeCompress(promoted1, mask1); const Vec256 demoted0 = DemoteTo(dh, BitCast(di32, compressed0)); const Vec256 demoted1 = DemoteTo(dh, BitCast(di32, compressed1)); // Store 256-bit halves StoreU(demoted0, dh, unaligned); StoreU(demoted1, dh, unaligned + PopCount(maskL)); } // Finally, the remaining EmuCompress for wide vectors, using EmuCompressStore. 
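// Note: EmuCompressStore writes whole sub-vectors starting at the compressed
// offsets, so its stores may extend past the number of selected lanes; the
// oversized stack buffers below absorb this before the final Load.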
template // 1 or 2 bytes HWY_INLINE Vec512 EmuCompress(Vec512 v, Mask512 mask) { const Full512 d; HWY_ALIGN T buf[2 * 64 / sizeof(T)]; EmuCompressStore(v, mask, d, buf); return Load(d, buf); } HWY_INLINE Vec256 EmuCompress(Vec256 v, const Mask256 mask) { const Full256 d; HWY_ALIGN uint8_t buf[2 * 32 / sizeof(uint8_t)]; EmuCompressStore(v, mask, d, buf); return Load(d, buf); } } // namespace detail template // 1 or 2 bytes HWY_API V Compress(V v, const M mask) { const DFromV d; const RebindToUnsigned du; const auto mu = RebindMask(du, mask); #if HWY_TARGET == HWY_AVX3_DL // VBMI2 return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); #else return BitCast(d, detail::EmuCompress(BitCast(du, v), mu)); #endif } template HWY_API V Compress(V v, const M mask) { const DFromV d; const RebindToUnsigned du; const auto mu = RebindMask(du, mask); return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); } template HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { // See CompressIsPartition. u64 is faster than u32. alignas(16) constexpr uint64_t packed_array[256] = { // From PrintCompress32x8Tables, without the FirstN extension (there is // no benefit to including them because 64-bit CompressStore is anyway // masked, but also no harm because TableLookupLanes ignores the MSB). 0x76543210, 0x76543210, 0x76543201, 0x76543210, 0x76543102, 0x76543120, 0x76543021, 0x76543210, 0x76542103, 0x76542130, 0x76542031, 0x76542310, 0x76541032, 0x76541320, 0x76540321, 0x76543210, 0x76532104, 0x76532140, 0x76532041, 0x76532410, 0x76531042, 0x76531420, 0x76530421, 0x76534210, 0x76521043, 0x76521430, 0x76520431, 0x76524310, 0x76510432, 0x76514320, 0x76504321, 0x76543210, 0x76432105, 0x76432150, 0x76432051, 0x76432510, 0x76431052, 0x76431520, 0x76430521, 0x76435210, 0x76421053, 0x76421530, 0x76420531, 0x76425310, 0x76410532, 0x76415320, 0x76405321, 0x76453210, 0x76321054, 0x76321540, 0x76320541, 0x76325410, 0x76310542, 0x76315420, 0x76305421, 0x76354210, 0x76210543, 0x76215430, 0x76205431, 0x76254310, 0x76105432, 0x76154320, 0x76054321, 0x76543210, 0x75432106, 0x75432160, 0x75432061, 0x75432610, 0x75431062, 0x75431620, 0x75430621, 0x75436210, 0x75421063, 0x75421630, 0x75420631, 0x75426310, 0x75410632, 0x75416320, 0x75406321, 0x75463210, 0x75321064, 0x75321640, 0x75320641, 0x75326410, 0x75310642, 0x75316420, 0x75306421, 0x75364210, 0x75210643, 0x75216430, 0x75206431, 0x75264310, 0x75106432, 0x75164320, 0x75064321, 0x75643210, 0x74321065, 0x74321650, 0x74320651, 0x74326510, 0x74310652, 0x74316520, 0x74306521, 0x74365210, 0x74210653, 0x74216530, 0x74206531, 0x74265310, 0x74106532, 0x74165320, 0x74065321, 0x74653210, 0x73210654, 0x73216540, 0x73206541, 0x73265410, 0x73106542, 0x73165420, 0x73065421, 0x73654210, 0x72106543, 0x72165430, 0x72065431, 0x72654310, 0x71065432, 0x71654320, 0x70654321, 0x76543210, 0x65432107, 0x65432170, 0x65432071, 0x65432710, 0x65431072, 0x65431720, 0x65430721, 0x65437210, 0x65421073, 0x65421730, 0x65420731, 0x65427310, 0x65410732, 0x65417320, 0x65407321, 0x65473210, 0x65321074, 0x65321740, 0x65320741, 0x65327410, 0x65310742, 0x65317420, 0x65307421, 0x65374210, 0x65210743, 0x65217430, 0x65207431, 0x65274310, 0x65107432, 0x65174320, 0x65074321, 0x65743210, 0x64321075, 0x64321750, 0x64320751, 0x64327510, 0x64310752, 0x64317520, 0x64307521, 0x64375210, 0x64210753, 0x64217530, 0x64207531, 0x64275310, 0x64107532, 0x64175320, 0x64075321, 0x64753210, 0x63210754, 0x63217540, 0x63207541, 0x63275410, 0x63107542, 0x63175420, 0x63075421, 0x63754210, 0x62107543, 0x62175430, 0x62075431, 0x62754310, 
0x61075432, 0x61754320, 0x60754321, 0x67543210, 0x54321076, 0x54321760, 0x54320761, 0x54327610, 0x54310762, 0x54317620, 0x54307621, 0x54376210, 0x54210763, 0x54217630, 0x54207631, 0x54276310, 0x54107632, 0x54176320, 0x54076321, 0x54763210, 0x53210764, 0x53217640, 0x53207641, 0x53276410, 0x53107642, 0x53176420, 0x53076421, 0x53764210, 0x52107643, 0x52176430, 0x52076431, 0x52764310, 0x51076432, 0x51764320, 0x50764321, 0x57643210, 0x43210765, 0x43217650, 0x43207651, 0x43276510, 0x43107652, 0x43176520, 0x43076521, 0x43765210, 0x42107653, 0x42176530, 0x42076531, 0x42765310, 0x41076532, 0x41765320, 0x40765321, 0x47653210, 0x32107654, 0x32176540, 0x32076541, 0x32765410, 0x31076542, 0x31765420, 0x30765421, 0x37654210, 0x21076543, 0x21765430, 0x20765431, 0x27654310, 0x10765432, 0x17654320, 0x07654321, 0x76543210}; // For lane i, shift the i-th 4-bit index down to bits [0, 3) - // _mm512_permutexvar_epi64 will ignore the upper bits. const Full512 d; const RebindToUnsigned du64; const auto packed = Set(du64, packed_array[mask.raw]); alignas(64) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; return TableLookupLanes(v, indices); } // ------------------------------ CompressNot template HWY_API V CompressNot(V v, const M mask) { return Compress(v, Not(mask)); } template HWY_API Vec512 CompressNot(Vec512 v, Mask512 mask) { // See CompressIsPartition. u64 is faster than u32. alignas(16) constexpr uint64_t packed_array[256] = { // From PrintCompressNot32x8Tables, without the FirstN extension (there is // no benefit to including them because 64-bit CompressStore is anyway // masked, but also no harm because TableLookupLanes ignores the MSB). 0x76543210, 0x07654321, 0x17654320, 0x10765432, 0x27654310, 0x20765431, 0x21765430, 0x21076543, 0x37654210, 0x30765421, 0x31765420, 0x31076542, 0x32765410, 0x32076541, 0x32176540, 0x32107654, 0x47653210, 0x40765321, 0x41765320, 0x41076532, 0x42765310, 0x42076531, 0x42176530, 0x42107653, 0x43765210, 0x43076521, 0x43176520, 0x43107652, 0x43276510, 0x43207651, 0x43217650, 0x43210765, 0x57643210, 0x50764321, 0x51764320, 0x51076432, 0x52764310, 0x52076431, 0x52176430, 0x52107643, 0x53764210, 0x53076421, 0x53176420, 0x53107642, 0x53276410, 0x53207641, 0x53217640, 0x53210764, 0x54763210, 0x54076321, 0x54176320, 0x54107632, 0x54276310, 0x54207631, 0x54217630, 0x54210763, 0x54376210, 0x54307621, 0x54317620, 0x54310762, 0x54327610, 0x54320761, 0x54321760, 0x54321076, 0x67543210, 0x60754321, 0x61754320, 0x61075432, 0x62754310, 0x62075431, 0x62175430, 0x62107543, 0x63754210, 0x63075421, 0x63175420, 0x63107542, 0x63275410, 0x63207541, 0x63217540, 0x63210754, 0x64753210, 0x64075321, 0x64175320, 0x64107532, 0x64275310, 0x64207531, 0x64217530, 0x64210753, 0x64375210, 0x64307521, 0x64317520, 0x64310752, 0x64327510, 0x64320751, 0x64321750, 0x64321075, 0x65743210, 0x65074321, 0x65174320, 0x65107432, 0x65274310, 0x65207431, 0x65217430, 0x65210743, 0x65374210, 0x65307421, 0x65317420, 0x65310742, 0x65327410, 0x65320741, 0x65321740, 0x65321074, 0x65473210, 0x65407321, 0x65417320, 0x65410732, 0x65427310, 0x65420731, 0x65421730, 0x65421073, 0x65437210, 0x65430721, 0x65431720, 0x65431072, 0x65432710, 0x65432071, 0x65432170, 0x65432107, 0x76543210, 0x70654321, 0x71654320, 0x71065432, 0x72654310, 0x72065431, 0x72165430, 0x72106543, 0x73654210, 0x73065421, 0x73165420, 0x73106542, 0x73265410, 0x73206541, 0x73216540, 0x73210654, 0x74653210, 0x74065321, 0x74165320, 0x74106532, 0x74265310, 0x74206531, 0x74216530, 
0x74210653, 0x74365210, 0x74306521, 0x74316520, 0x74310652, 0x74326510, 0x74320651, 0x74321650, 0x74321065, 0x75643210, 0x75064321, 0x75164320, 0x75106432, 0x75264310, 0x75206431, 0x75216430, 0x75210643, 0x75364210, 0x75306421, 0x75316420, 0x75310642, 0x75326410, 0x75320641, 0x75321640, 0x75321064, 0x75463210, 0x75406321, 0x75416320, 0x75410632, 0x75426310, 0x75420631, 0x75421630, 0x75421063, 0x75436210, 0x75430621, 0x75431620, 0x75431062, 0x75432610, 0x75432061, 0x75432160, 0x75432106, 0x76543210, 0x76054321, 0x76154320, 0x76105432, 0x76254310, 0x76205431, 0x76215430, 0x76210543, 0x76354210, 0x76305421, 0x76315420, 0x76310542, 0x76325410, 0x76320541, 0x76321540, 0x76321054, 0x76453210, 0x76405321, 0x76415320, 0x76410532, 0x76425310, 0x76420531, 0x76421530, 0x76421053, 0x76435210, 0x76430521, 0x76431520, 0x76431052, 0x76432510, 0x76432051, 0x76432150, 0x76432105, 0x76543210, 0x76504321, 0x76514320, 0x76510432, 0x76524310, 0x76520431, 0x76521430, 0x76521043, 0x76534210, 0x76530421, 0x76531420, 0x76531042, 0x76532410, 0x76532041, 0x76532140, 0x76532104, 0x76543210, 0x76540321, 0x76541320, 0x76541032, 0x76542310, 0x76542031, 0x76542130, 0x76542103, 0x76543210, 0x76543021, 0x76543120, 0x76543102, 0x76543210, 0x76543201, 0x76543210, 0x76543210}; // For lane i, shift the i-th 4-bit index down to bits [0, 3) - // _mm512_permutexvar_epi64 will ignore the upper bits. const Full512 d; const RebindToUnsigned du64; const auto packed = Set(du64, packed_array[mask.raw]); alignas(64) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; return TableLookupLanes(v, indices); } // uint64_t lanes. Only implement for 256 and 512-bit vectors because this is a // no-op for 128-bit. template 16)>* = nullptr> HWY_API V CompressBlocksNot(V v, M mask) { return CompressNot(v, mask); } // ------------------------------ CompressBits template HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { return Compress(v, LoadMaskBits(DFromV(), bits)); } // ------------------------------ CompressStore template // 1 or 2 bytes HWY_API size_t CompressStore(V v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { const RebindToUnsigned du; const auto mu = RebindMask(du, mask); auto pu = reinterpret_cast * HWY_RESTRICT>(unaligned); #if HWY_TARGET == HWY_AVX3_DL // VBMI2 detail::NativeCompressStore(BitCast(du, v), mu, du, pu); #else detail::EmuCompressStore(BitCast(du, v), mu, du, pu); #endif const size_t count = CountTrue(d, mask); detail::MaybeUnpoison(pu, count); return count; } template // 4 or 8 HWY_API size_t CompressStore(V v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { const RebindToUnsigned du; const auto mu = RebindMask(du, mask); using TU = TFromD; TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); detail::NativeCompressStore(BitCast(du, v), mu, du, pu); const size_t count = CountTrue(d, mask); detail::MaybeUnpoison(pu, count); return count; } // Additional overloads to avoid casting to uint32_t (delay?). 
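// Usage sketch (illustrative only; ExampleCompressBelow is not a Highway op):
// stream-compact the u16 lanes of `v` that are below `limit` into `out` and
// return how many were written. On AVX3_DL this takes the native VBMI2
// CompressStore path above; otherwise it falls back to EmuCompressStore.
HWY_INLINE size_t ExampleCompressBelow(Full512<uint16_t> d, Vec512<uint16_t> v,
                                       uint16_t limit,
                                       uint16_t* HWY_RESTRICT out) {
  const Mask512<uint16_t> m = Lt(v, Set(d, limit));
  return CompressStore(v, m, d, out);
}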
HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 /* tag */, float* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); const size_t count = PopCount(uint64_t{mask.raw}); detail::MaybeUnpoison(unaligned, count); return count; } HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 /* tag */, double* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); const size_t count = PopCount(uint64_t{mask.raw}); detail::MaybeUnpoison(unaligned, count); return count; } // ------------------------------ CompressBlendedStore template > HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, T* HWY_RESTRICT unaligned) { // Native CompressStore already does the blending at no extra cost (latency // 11, rthroughput 2 - same as compress plus store). if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) > 2) { return CompressStore(v, m, d, unaligned); } else { const size_t count = CountTrue(d, m); BlendedStore(Compress(v, m), FirstN(d, count), d, unaligned); detail::MaybeUnpoison(unaligned, count); return count; } } // ------------------------------ CompressBitsStore template HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, D d, TFromD* HWY_RESTRICT unaligned) { return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); } // ------------------------------ LoadInterleaved4 // Actually implemented in generic_ops, we just overload LoadTransposedBlocks4. namespace detail { // Type-safe wrapper. template <_MM_PERM_ENUM kPerm, typename T> Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { return Vec512{_mm512_shuffle_i64x2(lo.raw, hi.raw, kPerm)}; } template <_MM_PERM_ENUM kPerm> Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, kPerm)}; } template <_MM_PERM_ENUM kPerm> Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, kPerm)}; } // Input (128-bit blocks): // 3 2 1 0 (<- first block in unaligned) // 7 6 5 4 // b a 9 8 // Output: // 9 6 3 0 (LSB of A) // a 7 4 1 // b 8 5 2 template HWY_API void LoadTransposedBlocks3(Full512 d, const T* HWY_RESTRICT unaligned, Vec512& A, Vec512& B, Vec512& C) { constexpr size_t N = 64 / sizeof(T); const Vec512 v3210 = LoadU(d, unaligned + 0 * N); const Vec512 v7654 = LoadU(d, unaligned + 1 * N); const Vec512 vba98 = LoadU(d, unaligned + 2 * N); const Vec512 v5421 = detail::Shuffle128<_MM_PERM_BACB>(v3210, v7654); const Vec512 va976 = detail::Shuffle128<_MM_PERM_CBDC>(v7654, vba98); A = detail::Shuffle128<_MM_PERM_CADA>(v3210, va976); B = detail::Shuffle128<_MM_PERM_DBCA>(v5421, va976); C = detail::Shuffle128<_MM_PERM_DADB>(v5421, vba98); } // Input (128-bit blocks): // 3 2 1 0 (<- first block in unaligned) // 7 6 5 4 // b a 9 8 // f e d c // Output: // c 8 4 0 (LSB of A) // d 9 5 1 // e a 6 2 // f b 7 3 template HWY_API void LoadTransposedBlocks4(Full512 d, const T* HWY_RESTRICT unaligned, Vec512& A, Vec512& B, Vec512& C, Vec512& D) { constexpr size_t N = 64 / sizeof(T); const Vec512 v3210 = LoadU(d, unaligned + 0 * N); const Vec512 v7654 = LoadU(d, unaligned + 1 * N); const Vec512 vba98 = LoadU(d, unaligned + 2 * N); const Vec512 vfedc = LoadU(d, unaligned + 3 * N); const Vec512 v5410 = detail::Shuffle128<_MM_PERM_BABA>(v3210, v7654); const Vec512 vdc98 = detail::Shuffle128<_MM_PERM_BABA>(vba98, vfedc); const Vec512 v7632 = detail::Shuffle128<_MM_PERM_DCDC>(v3210, v7654); const Vec512 vfeba = detail::Shuffle128<_MM_PERM_DCDC>(vba98, vfedc); A = 
detail::Shuffle128<_MM_PERM_CACA>(v5410, vdc98); B = detail::Shuffle128<_MM_PERM_DBDB>(v5410, vdc98); C = detail::Shuffle128<_MM_PERM_CACA>(v7632, vfeba); D = detail::Shuffle128<_MM_PERM_DBDB>(v7632, vfeba); } } // namespace detail // ------------------------------ StoreInterleaved2 // Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. namespace detail { // Input (128-bit blocks): // 6 4 2 0 (LSB of i) // 7 5 3 1 // Output: // 3 2 1 0 // 7 6 5 4 template HWY_API void StoreTransposedBlocks2(const Vec512 i, const Vec512 j, const Full512 d, T* HWY_RESTRICT unaligned) { constexpr size_t N = 64 / sizeof(T); const auto j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); const auto j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); const auto j1_i1_j0_i0 = detail::Shuffle128<_MM_PERM_DBCA>(j1_j0_i1_i0, j1_j0_i1_i0); const auto j3_i3_j2_i2 = detail::Shuffle128<_MM_PERM_DBCA>(j3_j2_i3_i2, j3_j2_i3_i2); StoreU(j1_i1_j0_i0, d, unaligned + 0 * N); StoreU(j3_i3_j2_i2, d, unaligned + 1 * N); } // Input (128-bit blocks): // 9 6 3 0 (LSB of i) // a 7 4 1 // b 8 5 2 // Output: // 3 2 1 0 // 7 6 5 4 // b a 9 8 template HWY_API void StoreTransposedBlocks3(const Vec512 i, const Vec512 j, const Vec512 k, Full512 d, T* HWY_RESTRICT unaligned) { constexpr size_t N = 64 / sizeof(T); const Vec512 j2_j0_i2_i0 = detail::Shuffle128<_MM_PERM_CACA>(i, j); const Vec512 i3_i1_k2_k0 = detail::Shuffle128<_MM_PERM_DBCA>(k, i); const Vec512 j3_j1_k3_k1 = detail::Shuffle128<_MM_PERM_DBDB>(k, j); const Vec512 out0 = // i1 k0 j0 i0 detail::Shuffle128<_MM_PERM_CACA>(j2_j0_i2_i0, i3_i1_k2_k0); const Vec512 out1 = // j2 i2 k1 j1 detail::Shuffle128<_MM_PERM_DBAC>(j3_j1_k3_k1, j2_j0_i2_i0); const Vec512 out2 = // k3 j3 i3 k2 detail::Shuffle128<_MM_PERM_BDDB>(i3_i1_k2_k0, j3_j1_k3_k1); StoreU(out0, d, unaligned + 0 * N); StoreU(out1, d, unaligned + 1 * N); StoreU(out2, d, unaligned + 2 * N); } // Input (128-bit blocks): // c 8 4 0 (LSB of i) // d 9 5 1 // e a 6 2 // f b 7 3 // Output: // 3 2 1 0 // 7 6 5 4 // b a 9 8 // f e d c template HWY_API void StoreTransposedBlocks4(const Vec512 i, const Vec512 j, const Vec512 k, const Vec512 l, Full512 d, T* HWY_RESTRICT unaligned) { constexpr size_t N = 64 / sizeof(T); const Vec512 j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); const Vec512 l1_l0_k1_k0 = detail::Shuffle128<_MM_PERM_BABA>(k, l); const Vec512 j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); const Vec512 l3_l2_k3_k2 = detail::Shuffle128<_MM_PERM_DCDC>(k, l); const Vec512 out0 = detail::Shuffle128<_MM_PERM_CACA>(j1_j0_i1_i0, l1_l0_k1_k0); const Vec512 out1 = detail::Shuffle128<_MM_PERM_DBDB>(j1_j0_i1_i0, l1_l0_k1_k0); const Vec512 out2 = detail::Shuffle128<_MM_PERM_CACA>(j3_j2_i3_i2, l3_l2_k3_k2); const Vec512 out3 = detail::Shuffle128<_MM_PERM_DBDB>(j3_j2_i3_i2, l3_l2_k3_k2); StoreU(out0, d, unaligned + 0 * N); StoreU(out1, d, unaligned + 1 * N); StoreU(out2, d, unaligned + 2 * N); StoreU(out3, d, unaligned + 3 * N); } } // namespace detail // ------------------------------ MulEven/Odd (Shuffle2301, InterleaveLower) HWY_INLINE Vec512 MulEven(const Vec512 a, const Vec512 b) { const Full512 du64; const RepartitionToNarrow du32; const auto maskL = Set(du64, 0xFFFFFFFFULL); const auto a32 = BitCast(du32, a); const auto b32 = BitCast(du32, b); // Inputs for MulEven: we only need the lower 32 bits const auto aH = Shuffle2301(a32); const auto bH = Shuffle2301(b32); // Knuth double-word multiplication. 
  // We use 32x32 = 64 MulEven and only need the even (lower 64 bits of every
  // 128-bit block) results. See
  // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.txt
  const auto aLbL = MulEven(a32, b32);
  const auto w3 = aLbL & maskL;

  const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL);
  const auto w2 = t2 & maskL;
  const auto w1 = ShiftRight<32>(t2);

  const auto t = MulEven(a32, bH) + w2;
  const auto k = ShiftRight<32>(t);

  const auto mulH = MulEven(aH, bH) + w1 + k;
  const auto mulL = ShiftLeft<32>(t) + w3;
  return InterleaveLower(mulL, mulH);
}

HWY_INLINE Vec512<uint64_t> MulOdd(const Vec512<uint64_t> a,
                                   const Vec512<uint64_t> b) {
  const Full512<uint64_t> du64;
  const RepartitionToNarrow<decltype(du64)> du32;
  const auto maskL = Set(du64, 0xFFFFFFFFULL);
  const auto a32 = BitCast(du32, a);
  const auto b32 = BitCast(du32, b);
  // Inputs for MulEven: we only need bits [95:64] (= upper half of input)
  const auto aH = Shuffle2301(a32);
  const auto bH = Shuffle2301(b32);

  // Same as above, but we're using the odd results (upper 64 bits per block).
  const auto aLbL = MulEven(a32, b32);
  const auto w3 = aLbL & maskL;

  const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL);
  const auto w2 = t2 & maskL;
  const auto w1 = ShiftRight<32>(t2);

  const auto t = MulEven(a32, bH) + w2;
  const auto k = ShiftRight<32>(t);

  const auto mulH = MulEven(aH, bH) + w1 + k;
  const auto mulL = ShiftLeft<32>(t) + w3;
  return InterleaveUpper(du64, mulL, mulH);
}

// ------------------------------ ReorderWidenMulAccumulate

HWY_API Vec512<int32_t> ReorderWidenMulAccumulate(Full512<int32_t> /*d32*/,
                                                  Vec512<int16_t> a,
                                                  Vec512<int16_t> b,
                                                  const Vec512<int32_t> sum0,
                                                  Vec512<int32_t>& /*sum1*/) {
  return sum0 + Vec512<int32_t>{_mm512_madd_epi16(a.raw, b.raw)};
}

HWY_API Vec512<int32_t> RearrangeToOddPlusEven(const Vec512<int32_t> sum0,
                                               Vec512<int32_t> /*sum1*/) {
  return sum0;  // invariant already holds
}

// ------------------------------ Reductions

// Returns the sum in each lane.
HWY_API Vec512<int32_t> SumOfLanes(Full512<int32_t> d, Vec512<int32_t> v) {
  return Set(d, _mm512_reduce_add_epi32(v.raw));
}
HWY_API Vec512<int64_t> SumOfLanes(Full512<int64_t> d, Vec512<int64_t> v) {
  return Set(d, _mm512_reduce_add_epi64(v.raw));
}
HWY_API Vec512<uint32_t> SumOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) {
  return Set(d, static_cast<uint32_t>(_mm512_reduce_add_epi32(v.raw)));
}
HWY_API Vec512<uint64_t> SumOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) {
  return Set(d, static_cast<uint64_t>(_mm512_reduce_add_epi64(v.raw)));
}
HWY_API Vec512<float> SumOfLanes(Full512<float> d, Vec512<float> v) {
  return Set(d, _mm512_reduce_add_ps(v.raw));
}
HWY_API Vec512<double> SumOfLanes(Full512<double> d, Vec512<double> v) {
  return Set(d, _mm512_reduce_add_pd(v.raw));
}
HWY_API Vec512<uint16_t> SumOfLanes(Full512<uint16_t> d, Vec512<uint16_t> v) {
  const RepartitionToWide<decltype(d)> d32;
  const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto sum = SumOfLanes(d32, even + odd);
  // Also broadcast into odd lanes.
  return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum));
}
HWY_API Vec512<int16_t> SumOfLanes(Full512<int16_t> d, Vec512<int16_t> v) {
  const RepartitionToWide<decltype(d)> d32;
  // Sign-extend
  const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v)));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto sum = SumOfLanes(d32, even + odd);
  // Also broadcast into odd lanes.
  return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum));
}
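// Worked example for the 16-bit reductions above (illustrative only;
// ExampleCheckU16SumOfLanes is not part of Highway): with u16 lanes
// alternating {1, 2, 1, 2, ...}, each of the 16 even/odd pairs sums to 3 in a
// 32-bit lane, so every lane of the result holds 3 * 16 = 48.
HWY_INLINE bool ExampleCheckU16SumOfLanes() {
  const Full512<uint16_t> d;
  HWY_ALIGN uint16_t lanes[32];
  for (size_t i = 0; i < 32; ++i) {
    lanes[i] = static_cast<uint16_t>(1 + (i & 1));  // 1, 2, 1, 2, ...
  }
  const Vec512<uint16_t> sums = SumOfLanes(d, Load(d, lanes));
  return GetLane(sums) == 48;
}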
// Returns the minimum in each lane.
HWY_API Vec512<int32_t> MinOfLanes(Full512<int32_t> d, Vec512<int32_t> v) {
  return Set(d, _mm512_reduce_min_epi32(v.raw));
}
HWY_API Vec512<int64_t> MinOfLanes(Full512<int64_t> d, Vec512<int64_t> v) {
  return Set(d, _mm512_reduce_min_epi64(v.raw));
}
HWY_API Vec512<uint32_t> MinOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) {
  return Set(d, _mm512_reduce_min_epu32(v.raw));
}
HWY_API Vec512<uint64_t> MinOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) {
  return Set(d, _mm512_reduce_min_epu64(v.raw));
}
HWY_API Vec512<float> MinOfLanes(Full512<float> d, Vec512<float> v) {
  return Set(d, _mm512_reduce_min_ps(v.raw));
}
HWY_API Vec512<double> MinOfLanes(Full512<double> d, Vec512<double> v) {
  return Set(d, _mm512_reduce_min_pd(v.raw));
}
HWY_API Vec512<uint16_t> MinOfLanes(Full512<uint16_t> d, Vec512<uint16_t> v) {
  const RepartitionToWide<decltype(d)> d32;
  const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto min = MinOfLanes(d32, Min(even, odd));
  // Also broadcast into odd lanes.
  return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min));
}
HWY_API Vec512<int16_t> MinOfLanes(Full512<int16_t> d, Vec512<int16_t> v) {
  const RepartitionToWide<decltype(d)> d32;
  // Sign-extend
  const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v)));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto min = MinOfLanes(d32, Min(even, odd));
  // Also broadcast into odd lanes.
  return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min));
}

// Returns the maximum in each lane.
HWY_API Vec512<int32_t> MaxOfLanes(Full512<int32_t> d, Vec512<int32_t> v) {
  return Set(d, _mm512_reduce_max_epi32(v.raw));
}
HWY_API Vec512<int64_t> MaxOfLanes(Full512<int64_t> d, Vec512<int64_t> v) {
  return Set(d, _mm512_reduce_max_epi64(v.raw));
}
HWY_API Vec512<uint32_t> MaxOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) {
  return Set(d, _mm512_reduce_max_epu32(v.raw));
}
HWY_API Vec512<uint64_t> MaxOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) {
  return Set(d, _mm512_reduce_max_epu64(v.raw));
}
HWY_API Vec512<float> MaxOfLanes(Full512<float> d, Vec512<float> v) {
  return Set(d, _mm512_reduce_max_ps(v.raw));
}
HWY_API Vec512<double> MaxOfLanes(Full512<double> d, Vec512<double> v) {
  return Set(d, _mm512_reduce_max_pd(v.raw));
}
HWY_API Vec512<uint16_t> MaxOfLanes(Full512<uint16_t> d, Vec512<uint16_t> v) {
  const RepartitionToWide<decltype(d)> d32;
  const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto max = MaxOfLanes(d32, Max(even, odd));
  // Also broadcast into odd lanes.
  return OddEven(BitCast(d, ShiftLeft<16>(max)), BitCast(d, max));
}
HWY_API Vec512<int16_t> MaxOfLanes(Full512<int16_t> d, Vec512<int16_t> v) {
  const RepartitionToWide<decltype(d)> d32;
  // Sign-extend
  const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v)));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto max = MaxOfLanes(d32, Max(even, odd));
  // Also broadcast into odd lanes.
  return OddEven(BitCast(d, ShiftLeft<16>(max)), BitCast(d, max));
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h -
// the warning seems to be issued at the call site of intrinsics, i.e. our code.
HWY_DIAGNOSTICS(pop)