// Copyright 2019 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 128-bit ARM64 NEON vectors and operations.
// External include guard in highway.h - see comment there.

// ARM NEON intrinsics are documented at:
// https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]

#include <stddef.h>
#include <stdint.h>

#include "hwy/ops/shared-inl.h"

HWY_BEFORE_NAMESPACE();

// Must come after HWY_BEFORE_NAMESPACE so that the intrinsics are compiled
// with the same target attribute as our code, see #834.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized")
#include <arm_neon.h>  // NOLINT(build/include_order)
HWY_DIAGNOSTICS(pop)  // Must come after arm_neon.h.

namespace hwy {
namespace HWY_NAMESPACE {

namespace detail {  // for code folding and Raw128

// Macros used to define single and double function calls for multiple types
// for full and half vectors. These macros are undefined at the end of the
// file.

// HWY_NEON_BUILD_TPL_* is the template<...> prefix to the function.
#define HWY_NEON_BUILD_TPL_1
#define HWY_NEON_BUILD_TPL_2
#define HWY_NEON_BUILD_TPL_3

// HWY_NEON_BUILD_RET_* is return type; type arg is without _t suffix so we can
// extend it to int32x4x2_t packs.
#define HWY_NEON_BUILD_RET_1(type, size) Vec128<type##_t, size>
#define HWY_NEON_BUILD_RET_2(type, size) Vec128<type##_t, size>
#define HWY_NEON_BUILD_RET_3(type, size) Vec128<type##_t, size>

// HWY_NEON_BUILD_PARAM_* is the list of parameters the function receives.
#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128<type##_t, size> a
#define HWY_NEON_BUILD_PARAM_2(type, size) \
  const Vec128<type##_t, size> a, const Vec128<type##_t, size> b
#define HWY_NEON_BUILD_PARAM_3(type, size)                                \
  const Vec128<type##_t, size> a, const Vec128<type##_t, size> b,        \
      const Vec128<type##_t, size> c

// HWY_NEON_BUILD_ARG_* is the list of arguments passed to the underlying
// function.
#define HWY_NEON_BUILD_ARG_1 a.raw
#define HWY_NEON_BUILD_ARG_2 a.raw, b.raw
#define HWY_NEON_BUILD_ARG_3 a.raw, b.raw, c.raw

// We use HWY_NEON_EVAL(func, ...) to delay the evaluation of func until after
// the __VA_ARGS__ have been expanded. This allows "func" to be a macro on
// itself like with some of the library "functions" such as vshlq_u8. For
// example, HWY_NEON_EVAL(vshlq_u8, MY_PARAMS) where MY_PARAMS is defined as
// "a, b" (without the quotes) will end up expanding "vshlq_u8(a, b)" if
// needed. Directly writing vshlq_u8(MY_PARAMS) would fail since vshlq_u8()
// macro expects two arguments.
#define HWY_NEON_EVAL(func, ...) func(__VA_ARGS__)

// Main macro definition that defines a single function for the given type and
// size of vector, using the underlying (prefix##infix##suffix) function and
// the template, return type, parameters and arguments defined by the "args"
// parameters passed here (see HWY_NEON_BUILD_* macros defined before).
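// For example (illustrative sketch, not generated here): with the two-argument
// BUILD macros above (args=2), an invocation such as
//   HWY_NEON_DEF_FUNCTION(uint8, 16, operator+, vaddq, _, u8, 2)
// expands to roughly
//   HWY_API Vec128<uint8_t, 16> operator+(const Vec128<uint8_t, 16> a,
//                                         const Vec128<uint8_t, 16> b) {
//     return Vec128<uint8_t, 16>(vaddq_u8(a.raw, b.raw));
//   }
// because prefix##infix##suffix concatenates to the NEON intrinsic name and
// HWY_NEON_EVAL forwards the .raw arguments to it.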
#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \
  HWY_CONCAT(HWY_NEON_BUILD_TPL_, args)                                      \
  HWY_API HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size)                  \
      name(HWY_CONCAT(HWY_NEON_BUILD_PARAM_, args)(type, size)) {            \
    return HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size)(                \
        HWY_NEON_EVAL(prefix##infix##suffix, HWY_NEON_BUILD_ARG_##args));    \
  }

// The HWY_NEON_DEF_FUNCTION_* macros define all the variants of a function
// called "name" using the set of neon functions starting with the given
// "prefix" for all the variants of certain types, as specified next to each
// macro. For example, the prefix "vsub" can be used to define the operator-
// using args=2.

// uint8_t
#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \
  HWY_NEON_DEF_FUNCTION(uint8, 8, name, prefix, infix, u8, args)     \
  HWY_NEON_DEF_FUNCTION(uint8, 4, name, prefix, infix, u8, args)     \
  HWY_NEON_DEF_FUNCTION(uint8, 2, name, prefix, infix, u8, args)     \
  HWY_NEON_DEF_FUNCTION(uint8, 1, name, prefix, infix, u8, args)

// int8_t
#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \
  HWY_NEON_DEF_FUNCTION(int8, 8, name, prefix, infix, s8, args)     \
  HWY_NEON_DEF_FUNCTION(int8, 4, name, prefix, infix, s8, args)     \
  HWY_NEON_DEF_FUNCTION(int8, 2, name, prefix, infix, s8, args)     \
  HWY_NEON_DEF_FUNCTION(int8, 1, name, prefix, infix, s8, args)

// uint16_t
#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \
  HWY_NEON_DEF_FUNCTION(uint16, 4, name, prefix, infix, u16, args)    \
  HWY_NEON_DEF_FUNCTION(uint16, 2, name, prefix, infix, u16, args)    \
  HWY_NEON_DEF_FUNCTION(uint16, 1, name, prefix, infix, u16, args)

// int16_t
#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \
  HWY_NEON_DEF_FUNCTION(int16, 4, name, prefix, infix, s16, args)    \
  HWY_NEON_DEF_FUNCTION(int16, 2, name, prefix, infix, s16, args)    \
  HWY_NEON_DEF_FUNCTION(int16, 1, name, prefix, infix, s16, args)

// uint32_t
#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \
  HWY_NEON_DEF_FUNCTION(uint32, 2, name, prefix, infix, u32, args)    \
  HWY_NEON_DEF_FUNCTION(uint32, 1, name, prefix, infix, u32, args)

// int32_t
#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \
  HWY_NEON_DEF_FUNCTION(int32, 2, name, prefix, infix, s32, args)    \
  HWY_NEON_DEF_FUNCTION(int32, 1, name, prefix, infix, s32, args)

// uint64_t
#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \
  HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args)

// int64_t
#define HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \
  HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args)

// float
#define HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(float32, 4, name, prefix##q, infix, f32, args) \
  HWY_NEON_DEF_FUNCTION(float32, 2, name, prefix, infix, f32, args)    \
  HWY_NEON_DEF_FUNCTION(float32, 1, name, prefix, infix, f32, args)

// double
#if HWY_ARCH_ARM_A64
#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args) \
  HWY_NEON_DEF_FUNCTION(float64, 1, name, prefix, infix, f64, args)
#else
#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args)
#endif

// float and double
#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args)         \
  HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args)

// Helper macros to define for more than one type.
// uint8_t, uint16_t and uint32_t
#define HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args)             \
  HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args)            \
  HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args)

// int8_t, int16_t and int32_t
#define HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args)             \
  HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args)            \
  HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args)

// uint8_t, uint16_t, uint32_t and uint64_t
#define HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args)  \
  HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args)

// int8_t, int16_t, int32_t and int64_t
#define HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args)  \
  HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args)

// All int*_t and uint*_t up to 64
#define HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args)             \
  HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args)

// All previous types.
#define HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args)

#define HWY_NEON_DEF_FUNCTION_UIF81632(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args)   \
  HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args)    \
  HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args)

// For eor3q, which is only defined for full vectors.
#define HWY_NEON_DEF_FUNCTION_FULL_UI(name, prefix, infix, args)      \
  HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args)  \
  HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \
  HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \
  HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \
  HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args)   \
  HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args)  \
  HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args)  \
  HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args)

// Emulation of some intrinsics on armv7.
#if HWY_ARCH_ARM_V7
#define vuzp1_s8(x, y) vuzp_s8(x, y).val[0]
#define vuzp1_u8(x, y) vuzp_u8(x, y).val[0]
#define vuzp1_s16(x, y) vuzp_s16(x, y).val[0]
#define vuzp1_u16(x, y) vuzp_u16(x, y).val[0]
#define vuzp1_s32(x, y) vuzp_s32(x, y).val[0]
#define vuzp1_u32(x, y) vuzp_u32(x, y).val[0]
#define vuzp1_f32(x, y) vuzp_f32(x, y).val[0]
#define vuzp1q_s8(x, y) vuzpq_s8(x, y).val[0]
#define vuzp1q_u8(x, y) vuzpq_u8(x, y).val[0]
#define vuzp1q_s16(x, y) vuzpq_s16(x, y).val[0]
#define vuzp1q_u16(x, y) vuzpq_u16(x, y).val[0]
#define vuzp1q_s32(x, y) vuzpq_s32(x, y).val[0]
#define vuzp1q_u32(x, y) vuzpq_u32(x, y).val[0]
#define vuzp1q_f32(x, y) vuzpq_f32(x, y).val[0]
#define vuzp2_s8(x, y) vuzp_s8(x, y).val[1]
#define vuzp2_u8(x, y) vuzp_u8(x, y).val[1]
#define vuzp2_s16(x, y) vuzp_s16(x, y).val[1]
#define vuzp2_u16(x, y) vuzp_u16(x, y).val[1]
#define vuzp2_s32(x, y) vuzp_s32(x, y).val[1]
#define vuzp2_u32(x, y) vuzp_u32(x, y).val[1]
#define vuzp2_f32(x, y) vuzp_f32(x, y).val[1]
#define vuzp2q_s8(x, y) vuzpq_s8(x, y).val[1]
#define vuzp2q_u8(x, y) vuzpq_u8(x, y).val[1]
#define vuzp2q_s16(x, y) vuzpq_s16(x, y).val[1]
#define vuzp2q_u16(x, y) vuzpq_u16(x, y).val[1]
#define vuzp2q_s32(x, y) vuzpq_s32(x, y).val[1]
#define vuzp2q_u32(x, y) vuzpq_u32(x, y).val[1]
#define vuzp2q_f32(x, y) vuzpq_f32(x, y).val[1]
#define vzip1_s8(x, y) vzip_s8(x, y).val[0]
#define vzip1_u8(x, y) vzip_u8(x, y).val[0]
#define vzip1_s16(x, y) vzip_s16(x, y).val[0]
#define vzip1_u16(x, y) vzip_u16(x, y).val[0]
#define vzip1_f32(x, y) vzip_f32(x, y).val[0]
#define vzip1_u32(x, y) vzip_u32(x, y).val[0]
#define vzip1_s32(x, y) vzip_s32(x, y).val[0]
#define vzip1q_s8(x, y) vzipq_s8(x, y).val[0]
#define vzip1q_u8(x, y) vzipq_u8(x, y).val[0]
#define vzip1q_s16(x, y) vzipq_s16(x, y).val[0]
#define vzip1q_u16(x, y) vzipq_u16(x, y).val[0]
#define vzip1q_s32(x, y) vzipq_s32(x, y).val[0]
#define vzip1q_u32(x, y) vzipq_u32(x, y).val[0]
#define vzip1q_f32(x, y) vzipq_f32(x, y).val[0]
#define vzip2_s8(x, y) vzip_s8(x, y).val[1]
#define vzip2_u8(x, y) vzip_u8(x, y).val[1]
#define vzip2_s16(x, y) vzip_s16(x, y).val[1]
#define vzip2_u16(x, y) vzip_u16(x, y).val[1]
#define vzip2_s32(x, y) vzip_s32(x, y).val[1]
#define vzip2_u32(x, y) vzip_u32(x, y).val[1]
#define vzip2_f32(x, y) vzip_f32(x, y).val[1]
#define vzip2q_s8(x, y) vzipq_s8(x, y).val[1]
#define vzip2q_u8(x, y) vzipq_u8(x, y).val[1]
#define vzip2q_s16(x, y) vzipq_s16(x, y).val[1]
#define vzip2q_u16(x, y) vzipq_u16(x, y).val[1]
#define vzip2q_s32(x, y) vzipq_s32(x, y).val[1]
#define vzip2q_u32(x, y) vzipq_u32(x, y).val[1]
#define vzip2q_f32(x, y) vzipq_f32(x, y).val[1]
#endif

// Wrappers over uint8x16x2_t etc. so we can define StoreInterleaved2 overloads
// for all vector types, even those (bfloat16_t) where the underlying vector is
// the same as others (uint16_t).
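// Sketch of the specialization pattern used by the Tuple2/3/4 wrappers that
// follow (the real definitions are below; the pattern keys on both lane type
// and lane count):
//   template <typename T, size_t N>
//   struct Tuple2;                                      // primary template
//   template <>
//   struct Tuple2<uint8_t, 16> { uint8x16x2_t raw; };   // full vector
//   template <size_t N>
//   struct Tuple2<uint8_t, N> { uint8x8x2_t raw; };     // partial vectors
// This keeps bfloat16_t (and float16_t) distinct from uint16_t even though
// they all wrap the same uint16x8x2_t / uint16x4x2_t raw types.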
template struct Tuple2; template struct Tuple3; template struct Tuple4; template <> struct Tuple2 { uint8x16x2_t raw; }; template struct Tuple2 { uint8x8x2_t raw; }; template <> struct Tuple2 { int8x16x2_t raw; }; template struct Tuple2 { int8x8x2_t raw; }; template <> struct Tuple2 { uint16x8x2_t raw; }; template struct Tuple2 { uint16x4x2_t raw; }; template <> struct Tuple2 { int16x8x2_t raw; }; template struct Tuple2 { int16x4x2_t raw; }; template <> struct Tuple2 { uint32x4x2_t raw; }; template struct Tuple2 { uint32x2x2_t raw; }; template <> struct Tuple2 { int32x4x2_t raw; }; template struct Tuple2 { int32x2x2_t raw; }; template <> struct Tuple2 { uint64x2x2_t raw; }; template struct Tuple2 { uint64x1x2_t raw; }; template <> struct Tuple2 { int64x2x2_t raw; }; template struct Tuple2 { int64x1x2_t raw; }; template <> struct Tuple2 { uint16x8x2_t raw; }; template struct Tuple2 { uint16x4x2_t raw; }; template <> struct Tuple2 { uint16x8x2_t raw; }; template struct Tuple2 { uint16x4x2_t raw; }; template <> struct Tuple2 { float32x4x2_t raw; }; template struct Tuple2 { float32x2x2_t raw; }; #if HWY_ARCH_ARM_A64 template <> struct Tuple2 { float64x2x2_t raw; }; template struct Tuple2 { float64x1x2_t raw; }; #endif // HWY_ARCH_ARM_A64 template <> struct Tuple3 { uint8x16x3_t raw; }; template struct Tuple3 { uint8x8x3_t raw; }; template <> struct Tuple3 { int8x16x3_t raw; }; template struct Tuple3 { int8x8x3_t raw; }; template <> struct Tuple3 { uint16x8x3_t raw; }; template struct Tuple3 { uint16x4x3_t raw; }; template <> struct Tuple3 { int16x8x3_t raw; }; template struct Tuple3 { int16x4x3_t raw; }; template <> struct Tuple3 { uint32x4x3_t raw; }; template struct Tuple3 { uint32x2x3_t raw; }; template <> struct Tuple3 { int32x4x3_t raw; }; template struct Tuple3 { int32x2x3_t raw; }; template <> struct Tuple3 { uint64x2x3_t raw; }; template struct Tuple3 { uint64x1x3_t raw; }; template <> struct Tuple3 { int64x2x3_t raw; }; template struct Tuple3 { int64x1x3_t raw; }; template <> struct Tuple3 { uint16x8x3_t raw; }; template struct Tuple3 { uint16x4x3_t raw; }; template <> struct Tuple3 { uint16x8x3_t raw; }; template struct Tuple3 { uint16x4x3_t raw; }; template <> struct Tuple3 { float32x4x3_t raw; }; template struct Tuple3 { float32x2x3_t raw; }; #if HWY_ARCH_ARM_A64 template <> struct Tuple3 { float64x2x3_t raw; }; template struct Tuple3 { float64x1x3_t raw; }; #endif // HWY_ARCH_ARM_A64 template <> struct Tuple4 { uint8x16x4_t raw; }; template struct Tuple4 { uint8x8x4_t raw; }; template <> struct Tuple4 { int8x16x4_t raw; }; template struct Tuple4 { int8x8x4_t raw; }; template <> struct Tuple4 { uint16x8x4_t raw; }; template struct Tuple4 { uint16x4x4_t raw; }; template <> struct Tuple4 { int16x8x4_t raw; }; template struct Tuple4 { int16x4x4_t raw; }; template <> struct Tuple4 { uint32x4x4_t raw; }; template struct Tuple4 { uint32x2x4_t raw; }; template <> struct Tuple4 { int32x4x4_t raw; }; template struct Tuple4 { int32x2x4_t raw; }; template <> struct Tuple4 { uint64x2x4_t raw; }; template struct Tuple4 { uint64x1x4_t raw; }; template <> struct Tuple4 { int64x2x4_t raw; }; template struct Tuple4 { int64x1x4_t raw; }; template <> struct Tuple4 { uint16x8x4_t raw; }; template struct Tuple4 { uint16x4x4_t raw; }; template <> struct Tuple4 { uint16x8x4_t raw; }; template struct Tuple4 { uint16x4x4_t raw; }; template <> struct Tuple4 { float32x4x4_t raw; }; template struct Tuple4 { float32x2x4_t raw; }; #if HWY_ARCH_ARM_A64 template <> struct Tuple4 { float64x2x4_t raw; }; template 
struct Tuple4 { float64x1x4_t raw; }; #endif // HWY_ARCH_ARM_A64 template struct Raw128; // 128 template <> struct Raw128 { using type = uint8x16_t; }; template <> struct Raw128 { using type = uint16x8_t; }; template <> struct Raw128 { using type = uint32x4_t; }; template <> struct Raw128 { using type = uint64x2_t; }; template <> struct Raw128 { using type = int8x16_t; }; template <> struct Raw128 { using type = int16x8_t; }; template <> struct Raw128 { using type = int32x4_t; }; template <> struct Raw128 { using type = int64x2_t; }; template <> struct Raw128 { using type = uint16x8_t; }; template <> struct Raw128 { using type = uint16x8_t; }; template <> struct Raw128 { using type = float32x4_t; }; #if HWY_ARCH_ARM_A64 template <> struct Raw128 { using type = float64x2_t; }; #endif // 64 template <> struct Raw128 { using type = uint8x8_t; }; template <> struct Raw128 { using type = uint16x4_t; }; template <> struct Raw128 { using type = uint32x2_t; }; template <> struct Raw128 { using type = uint64x1_t; }; template <> struct Raw128 { using type = int8x8_t; }; template <> struct Raw128 { using type = int16x4_t; }; template <> struct Raw128 { using type = int32x2_t; }; template <> struct Raw128 { using type = int64x1_t; }; template <> struct Raw128 { using type = uint16x4_t; }; template <> struct Raw128 { using type = uint16x4_t; }; template <> struct Raw128 { using type = float32x2_t; }; #if HWY_ARCH_ARM_A64 template <> struct Raw128 { using type = float64x1_t; }; #endif // 32 (same as 64) template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; // 16 (same as 64) template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; // 8 (same as 64) template <> struct Raw128 : public Raw128 {}; template <> struct Raw128 : public Raw128 {}; } // namespace detail template class Vec128 { using Raw = typename detail::Raw128::type; public: using PrivateT = T; // only for DFromV static constexpr size_t kPrivateN = N; // only for DFromV HWY_INLINE Vec128() {} Vec128(const Vec128&) = default; Vec128& operator=(const Vec128&) = default; HWY_INLINE explicit Vec128(const Raw raw) : raw(raw) {} // Compound assignment. Only usable if there is a corresponding non-member // binary operator overload. For example, only f32 and f64 support division. HWY_INLINE Vec128& operator*=(const Vec128 other) { return *this = (*this * other); } HWY_INLINE Vec128& operator/=(const Vec128 other) { return *this = (*this / other); } HWY_INLINE Vec128& operator+=(const Vec128 other) { return *this = (*this + other); } HWY_INLINE Vec128& operator-=(const Vec128 other) { return *this = (*this - other); } HWY_INLINE Vec128& operator&=(const Vec128 other) { return *this = (*this & other); } HWY_INLINE Vec128& operator|=(const Vec128 other) { return *this = (*this | other); } HWY_INLINE Vec128& operator^=(const Vec128 other) { return *this = (*this ^ other); } Raw raw; }; template using Vec64 = Vec128; template using Vec32 = Vec128; // FF..FF or 0. 
template class Mask128 { // ARM C Language Extensions return and expect unsigned type. using Raw = typename detail::Raw128, N>::type; public: HWY_INLINE Mask128() {} Mask128(const Mask128&) = default; Mask128& operator=(const Mask128&) = default; HWY_INLINE explicit Mask128(const Raw raw) : raw(raw) {} Raw raw; }; template using Mask64 = Mask128; template using DFromV = Simd; template using TFromV = typename V::PrivateT; // ------------------------------ BitCast namespace detail { // Converts from Vec128 to Vec128 using the // vreinterpret*_u8_*() set of functions. #define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 #define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \ Vec128 #define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128 v #define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw // Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined. template HWY_INLINE Vec128 BitCastToByte(Vec128 v) { return v; } HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) // Special cases for [b]float16_t, which have the same Raw as uint16_t. template HWY_INLINE Vec128 BitCastToByte(Vec128 v) { return BitCastToByte(Vec128(v.raw)); } template HWY_INLINE Vec128 BitCastToByte(Vec128 v) { return BitCastToByte(Vec128(v.raw)); } #undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 #undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8 #undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8 #undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 template HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return v; } // 64-bit or less: template HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_s8_u8(v.raw)); } template HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_u16_u8(v.raw)); } template HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_s16_u8(v.raw)); } template HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_u32_u8(v.raw)); } template HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_s32_u8(v.raw)); } template HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(vreinterpret_f32_u8(v.raw)); } HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, Vec128 v) { return Vec64(vreinterpret_u64_u8(v.raw)); } HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, Vec128 v) { return Vec64(vreinterpret_s64_u8(v.raw)); } #if HWY_ARCH_ARM_A64 HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, Vec128 v) { return Vec64(vreinterpret_f64_u8(v.raw)); } #endif // 128-bit full: HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return Vec128(vreinterpretq_s8_u8(v.raw)); } HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return Vec128(vreinterpretq_u16_u8(v.raw)); } HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return Vec128(vreinterpretq_s16_u8(v.raw)); } HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return Vec128(vreinterpretq_u32_u8(v.raw)); } HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return Vec128(vreinterpretq_s32_u8(v.raw)); } HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return 
Vec128(vreinterpretq_f32_u8(v.raw)); } HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return Vec128(vreinterpretq_u64_u8(v.raw)); } HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return Vec128(vreinterpretq_s64_u8(v.raw)); } #if HWY_ARCH_ARM_A64 HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, Vec128 v) { return Vec128(vreinterpretq_f64_u8(v.raw)); } #endif // Special cases for [b]float16_t, which have the same Raw as uint16_t. template HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, Vec128 v) { return Vec128(BitCastFromByte(Simd(), v).raw); } template HWY_INLINE Vec128 BitCastFromByte( Simd /* tag */, Vec128 v) { return Vec128(BitCastFromByte(Simd(), v).raw); } } // namespace detail template HWY_API Vec128 BitCast(Simd d, Vec128 v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } // ------------------------------ Set // Returns a vector with all lanes set to "t". #define HWY_NEON_BUILD_TPL_HWY_SET1 #define HWY_NEON_BUILD_RET_HWY_SET1(type, size) Vec128 #define HWY_NEON_BUILD_PARAM_HWY_SET1(type, size) \ Simd /* tag */, const type##_t t #define HWY_NEON_BUILD_ARG_HWY_SET1 t HWY_NEON_DEF_FUNCTION_ALL_TYPES(Set, vdup, _n_, HWY_SET1) #undef HWY_NEON_BUILD_TPL_HWY_SET1 #undef HWY_NEON_BUILD_RET_HWY_SET1 #undef HWY_NEON_BUILD_PARAM_HWY_SET1 #undef HWY_NEON_BUILD_ARG_HWY_SET1 // Returns an all-zero vector. template HWY_API Vec128 Zero(Simd d) { return Set(d, 0); } template HWY_API Vec128 Zero(Simd /* tag */) { return Vec128(Zero(Simd()).raw); } template using VFromD = decltype(Zero(D())); HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") #if HWY_COMPILER_GCC_ACTUAL HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") #endif // Returns a vector with uninitialized elements. template HWY_API Vec128 Undefined(Simd /*d*/) { typename detail::Raw128::type a; return Vec128(a); } HWY_DIAGNOSTICS(pop) // Returns a vector with lane i=[0, N) set to "first" + i. template Vec128 Iota(const Simd d, const T2 first) { HWY_ALIGN T lanes[16 / sizeof(T)]; for (size_t i = 0; i < 16 / sizeof(T); ++i) { lanes[i] = AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); } return Load(d, lanes); } // ------------------------------ GetLane namespace detail { #define HWY_NEON_BUILD_TPL_HWY_GET template #define HWY_NEON_BUILD_RET_HWY_GET(type, size) type##_t #define HWY_NEON_BUILD_PARAM_HWY_GET(type, size) Vec128 v #define HWY_NEON_BUILD_ARG_HWY_GET v.raw, kLane HWY_NEON_DEF_FUNCTION_ALL_TYPES(GetLane, vget, _lane_, HWY_GET) #undef HWY_NEON_BUILD_TPL_HWY_GET #undef HWY_NEON_BUILD_RET_HWY_GET #undef HWY_NEON_BUILD_PARAM_HWY_GET #undef HWY_NEON_BUILD_ARG_HWY_GET } // namespace detail template HWY_API TFromV GetLane(const V v) { return detail::GetLane<0>(v); } // ------------------------------ ExtractLane // Requires one overload per vector length because GetLane<3> is a compile error // if v is a uint32x2_t. 
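// Usage sketch (all names are from this header): the lane index may be a
// runtime value, e.g.
//   const Full128<uint32_t> d;
//   const Vec128<uint32_t> v = Iota(d, 0);   // lanes {0, 1, 2, 3}
//   const uint32_t x = ExtractLane(v, 2);    // x == 2
// In optimized GCC/Clang builds, a compile-time-constant index is dispatched
// to detail::GetLane<kLane>; otherwise the vector is stored to a stack buffer
// and the requested lane is read back from memory.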
template HWY_API T ExtractLane(const Vec128 v, size_t i) { HWY_DASSERT(i == 0); (void)i; return detail::GetLane<0>(v); } template HWY_API T ExtractLane(const Vec128 v, size_t i) { #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang if (__builtin_constant_p(i)) { switch (i) { case 0: return detail::GetLane<0>(v); case 1: return detail::GetLane<1>(v); } } #endif alignas(16) T lanes[2]; Store(v, DFromV(), lanes); return lanes[i]; } template HWY_API T ExtractLane(const Vec128 v, size_t i) { #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang if (__builtin_constant_p(i)) { switch (i) { case 0: return detail::GetLane<0>(v); case 1: return detail::GetLane<1>(v); case 2: return detail::GetLane<2>(v); case 3: return detail::GetLane<3>(v); } } #endif alignas(16) T lanes[4]; Store(v, DFromV(), lanes); return lanes[i]; } template HWY_API T ExtractLane(const Vec128 v, size_t i) { #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang if (__builtin_constant_p(i)) { switch (i) { case 0: return detail::GetLane<0>(v); case 1: return detail::GetLane<1>(v); case 2: return detail::GetLane<2>(v); case 3: return detail::GetLane<3>(v); case 4: return detail::GetLane<4>(v); case 5: return detail::GetLane<5>(v); case 6: return detail::GetLane<6>(v); case 7: return detail::GetLane<7>(v); } } #endif alignas(16) T lanes[8]; Store(v, DFromV(), lanes); return lanes[i]; } template HWY_API T ExtractLane(const Vec128 v, size_t i) { #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang if (__builtin_constant_p(i)) { switch (i) { case 0: return detail::GetLane<0>(v); case 1: return detail::GetLane<1>(v); case 2: return detail::GetLane<2>(v); case 3: return detail::GetLane<3>(v); case 4: return detail::GetLane<4>(v); case 5: return detail::GetLane<5>(v); case 6: return detail::GetLane<6>(v); case 7: return detail::GetLane<7>(v); case 8: return detail::GetLane<8>(v); case 9: return detail::GetLane<9>(v); case 10: return detail::GetLane<10>(v); case 11: return detail::GetLane<11>(v); case 12: return detail::GetLane<12>(v); case 13: return detail::GetLane<13>(v); case 14: return detail::GetLane<14>(v); case 15: return detail::GetLane<15>(v); } } #endif alignas(16) T lanes[16]; Store(v, DFromV(), lanes); return lanes[i]; } // ------------------------------ InsertLane namespace detail { #define HWY_NEON_BUILD_TPL_HWY_INSERT template #define HWY_NEON_BUILD_RET_HWY_INSERT(type, size) Vec128 #define HWY_NEON_BUILD_PARAM_HWY_INSERT(type, size) \ Vec128 v, type##_t t #define HWY_NEON_BUILD_ARG_HWY_INSERT t, v.raw, kLane HWY_NEON_DEF_FUNCTION_ALL_TYPES(InsertLane, vset, _lane_, HWY_INSERT) #undef HWY_NEON_BUILD_TPL_HWY_INSERT #undef HWY_NEON_BUILD_RET_HWY_INSERT #undef HWY_NEON_BUILD_PARAM_HWY_INSERT #undef HWY_NEON_BUILD_ARG_HWY_INSERT } // namespace detail // Requires one overload per vector length because InsertLane<3> may be a // compile error. 
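// Usage sketch: InsertLane also accepts a runtime index and returns a copy
// with that lane replaced (the input vector itself is unchanged), e.g.
//   const Full128<int16_t> d;
//   Vec128<int16_t> v = Zero(d);
//   v = InsertLane(v, 5, int16_t{42});  // lane 5 now holds 42, others stay 0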
template HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { HWY_DASSERT(i == 0); (void)i; return Set(DFromV(), t); } template HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang if (__builtin_constant_p(i)) { switch (i) { case 0: return detail::InsertLane<0>(v, t); case 1: return detail::InsertLane<1>(v, t); } } #endif const DFromV d; alignas(16) T lanes[2]; Store(v, d, lanes); lanes[i] = t; return Load(d, lanes); } template HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang if (__builtin_constant_p(i)) { switch (i) { case 0: return detail::InsertLane<0>(v, t); case 1: return detail::InsertLane<1>(v, t); case 2: return detail::InsertLane<2>(v, t); case 3: return detail::InsertLane<3>(v, t); } } #endif const DFromV d; alignas(16) T lanes[4]; Store(v, d, lanes); lanes[i] = t; return Load(d, lanes); } template HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang if (__builtin_constant_p(i)) { switch (i) { case 0: return detail::InsertLane<0>(v, t); case 1: return detail::InsertLane<1>(v, t); case 2: return detail::InsertLane<2>(v, t); case 3: return detail::InsertLane<3>(v, t); case 4: return detail::InsertLane<4>(v, t); case 5: return detail::InsertLane<5>(v, t); case 6: return detail::InsertLane<6>(v, t); case 7: return detail::InsertLane<7>(v, t); } } #endif const DFromV d; alignas(16) T lanes[8]; Store(v, d, lanes); lanes[i] = t; return Load(d, lanes); } template HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { #if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang if (__builtin_constant_p(i)) { switch (i) { case 0: return detail::InsertLane<0>(v, t); case 1: return detail::InsertLane<1>(v, t); case 2: return detail::InsertLane<2>(v, t); case 3: return detail::InsertLane<3>(v, t); case 4: return detail::InsertLane<4>(v, t); case 5: return detail::InsertLane<5>(v, t); case 6: return detail::InsertLane<6>(v, t); case 7: return detail::InsertLane<7>(v, t); case 8: return detail::InsertLane<8>(v, t); case 9: return detail::InsertLane<9>(v, t); case 10: return detail::InsertLane<10>(v, t); case 11: return detail::InsertLane<11>(v, t); case 12: return detail::InsertLane<12>(v, t); case 13: return detail::InsertLane<13>(v, t); case 14: return detail::InsertLane<14>(v, t); case 15: return detail::InsertLane<15>(v, t); } } #endif const DFromV d; alignas(16) T lanes[16]; Store(v, d, lanes); lanes[i] = t; return Load(d, lanes); } // ================================================== ARITHMETIC // ------------------------------ Addition HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2) // ------------------------------ Subtraction HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator-, vsub, _, 2) // ------------------------------ SumsOf8 HWY_API Vec128 SumsOf8(const Vec128 v) { return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v.raw)))); } HWY_API Vec64 SumsOf8(const Vec64 v) { return Vec64(vpaddl_u32(vpaddl_u16(vpaddl_u8(v.raw)))); } // ------------------------------ SaturatedAdd // Only defined for uint8_t, uint16_t and their signed versions, as in other // architectures. // Returns a + b clamped to the destination range. 
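// For example (sketch): with uint8_t lanes,
//   const Full128<uint8_t> d;
//   const auto sum = SaturatedAdd(Set(d, uint8_t{200}), Set(d, uint8_t{100}));
//   // every lane holds 255
// whereas operator+ would wrap around to 44; signed int8_t lanes clamp to
// [-128, 127] instead.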
HWY_NEON_DEF_FUNCTION_INT_8(SaturatedAdd, vqadd, _, 2) HWY_NEON_DEF_FUNCTION_INT_16(SaturatedAdd, vqadd, _, 2) HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedAdd, vqadd, _, 2) HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedAdd, vqadd, _, 2) // ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. HWY_NEON_DEF_FUNCTION_INT_8(SaturatedSub, vqsub, _, 2) HWY_NEON_DEF_FUNCTION_INT_16(SaturatedSub, vqsub, _, 2) HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedSub, vqsub, _, 2) HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedSub, vqsub, _, 2) // Not part of API, used in implementation. namespace detail { HWY_NEON_DEF_FUNCTION_UINT_32(SaturatedSub, vqsub, _, 2) HWY_NEON_DEF_FUNCTION_UINT_64(SaturatedSub, vqsub, _, 2) HWY_NEON_DEF_FUNCTION_INT_32(SaturatedSub, vqsub, _, 2) HWY_NEON_DEF_FUNCTION_INT_64(SaturatedSub, vqsub, _, 2) } // namespace detail // ------------------------------ Average // Returns (a + b + 1) / 2 HWY_NEON_DEF_FUNCTION_UINT_8(AverageRound, vrhadd, _, 2) HWY_NEON_DEF_FUNCTION_UINT_16(AverageRound, vrhadd, _, 2) // ------------------------------ Neg HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Neg, vneg, _, 1) HWY_NEON_DEF_FUNCTION_INT_8_16_32(Neg, vneg, _, 1) // i64 implemented below HWY_API Vec64 Neg(const Vec64 v) { #if HWY_ARCH_ARM_A64 return Vec64(vneg_s64(v.raw)); #else return Zero(Full64()) - v; #endif } HWY_API Vec128 Neg(const Vec128 v) { #if HWY_ARCH_ARM_A64 return Vec128(vnegq_s64(v.raw)); #else return Zero(Full128()) - v; #endif } // ------------------------------ ShiftLeft // Customize HWY_NEON_DEF_FUNCTION to special-case count=0 (not supported). #pragma push_macro("HWY_NEON_DEF_FUNCTION") #undef HWY_NEON_DEF_FUNCTION #define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ template \ HWY_API Vec128 name(const Vec128 v) { \ return kBits == 0 ? v \ : Vec128(HWY_NEON_EVAL( \ prefix##infix##suffix, v.raw, HWY_MAX(1, kBits))); \ } HWY_NEON_DEF_FUNCTION_INTS_UINTS(ShiftLeft, vshl, _n_, ignored) HWY_NEON_DEF_FUNCTION_UINTS(ShiftRight, vshr, _n_, ignored) HWY_NEON_DEF_FUNCTION_INTS(ShiftRight, vshr, _n_, ignored) #pragma pop_macro("HWY_NEON_DEF_FUNCTION") // ------------------------------ RotateRight (ShiftRight, Or) template HWY_API Vec128 RotateRight(const Vec128 v) { static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); if (kBits == 0) return v; return Or(ShiftRight(v), ShiftLeft(v)); } template HWY_API Vec128 RotateRight(const Vec128 v) { static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); if (kBits == 0) return v; return Or(ShiftRight(v), ShiftLeft(v)); } // NOTE: vxarq_u64 can be applied to uint64_t, but we do not yet have a // mechanism for checking for extensions to ARMv8. 
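// Usage sketch for RotateRight above: kBits is a compile-time constant and
// the rotation is synthesized from two shifts plus Or, e.g.
//   const Full128<uint32_t> d;
//   const auto v = Set(d, 0x12345678u);
//   const auto r = RotateRight<8>(v);  // each lane becomes 0x78123456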
// ------------------------------ Shl HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_u8(v.raw, vreinterpretq_s8_u8(bits.raw))); } template HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshl_u8(v.raw, vreinterpret_s8_u8(bits.raw))); } HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_u16(v.raw, vreinterpretq_s16_u16(bits.raw))); } template HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshl_u16(v.raw, vreinterpret_s16_u16(bits.raw))); } HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_u32(v.raw, vreinterpretq_s32_u32(bits.raw))); } template HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshl_u32(v.raw, vreinterpret_s32_u32(bits.raw))); } HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_u64(v.raw, vreinterpretq_s64_u64(bits.raw))); } HWY_API Vec64 operator<<(const Vec64 v, const Vec64 bits) { return Vec64(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); } HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s8(v.raw, bits.raw)); } template HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshl_s8(v.raw, bits.raw)); } HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s16(v.raw, bits.raw)); } template HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshl_s16(v.raw, bits.raw)); } HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s32(v.raw, bits.raw)); } template HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshl_s32(v.raw, bits.raw)); } HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s64(v.raw, bits.raw)); } HWY_API Vec64 operator<<(const Vec64 v, const Vec64 bits) { return Vec64(vshl_s64(v.raw, bits.raw)); } // ------------------------------ Shr (Neg) HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { const int8x16_t neg_bits = Neg(BitCast(Full128(), bits)).raw; return Vec128(vshlq_u8(v.raw, neg_bits)); } template HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { const int8x8_t neg_bits = Neg(BitCast(Simd(), bits)).raw; return Vec128(vshl_u8(v.raw, neg_bits)); } HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { const int16x8_t neg_bits = Neg(BitCast(Full128(), bits)).raw; return Vec128(vshlq_u16(v.raw, neg_bits)); } template HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { const int16x4_t neg_bits = Neg(BitCast(Simd(), bits)).raw; return Vec128(vshl_u16(v.raw, neg_bits)); } HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { const int32x4_t neg_bits = Neg(BitCast(Full128(), bits)).raw; return Vec128(vshlq_u32(v.raw, neg_bits)); } template HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { const int32x2_t neg_bits = Neg(BitCast(Simd(), bits)).raw; return Vec128(vshl_u32(v.raw, neg_bits)); } HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { const int64x2_t neg_bits = Neg(BitCast(Full128(), bits)).raw; return Vec128(vshlq_u64(v.raw, neg_bits)); } HWY_API Vec64 operator>>(const Vec64 v, const Vec64 bits) { const int64x1_t neg_bits = Neg(BitCast(Full64(), bits)).raw; return Vec64(vshl_u64(v.raw, neg_bits)); } HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s8(v.raw, Neg(bits).raw)); } template HWY_API Vec128 operator>>(const Vec128 v, const 
Vec128 bits) { return Vec128(vshl_s8(v.raw, Neg(bits).raw)); } HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s16(v.raw, Neg(bits).raw)); } template HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { return Vec128(vshl_s16(v.raw, Neg(bits).raw)); } HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s32(v.raw, Neg(bits).raw)); } template HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { return Vec128(vshl_s32(v.raw, Neg(bits).raw)); } HWY_API Vec128 operator>>(const Vec128 v, const Vec128 bits) { return Vec128(vshlq_s64(v.raw, Neg(bits).raw)); } HWY_API Vec64 operator>>(const Vec64 v, const Vec64 bits) { return Vec64(vshl_s64(v.raw, Neg(bits).raw)); } // ------------------------------ ShiftLeftSame (Shl) template HWY_API Vec128 ShiftLeftSame(const Vec128 v, int bits) { return v << Set(Simd(), static_cast(bits)); } template HWY_API Vec128 ShiftRightSame(const Vec128 v, int bits) { return v >> Set(Simd(), static_cast(bits)); } // ------------------------------ Integer multiplication // Unsigned HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128(vmulq_u16(a.raw, b.raw)); } HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128(vmulq_u32(a.raw, b.raw)); } template HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128(vmul_u16(a.raw, b.raw)); } template HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128(vmul_u32(a.raw, b.raw)); } // Signed HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128(vmulq_s16(a.raw, b.raw)); } HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128(vmulq_s32(a.raw, b.raw)); } template HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128(vmul_s16(a.raw, b.raw)); } template HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { return Vec128(vmul_s32(a.raw, b.raw)); } // Returns the upper 16 bits of a * b in each lane. 
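// Worked example (sketch): for int16_t lanes a = b = 16384 (0x4000), the full
// 32-bit product is 0x10000000, so MulHigh returns 4096 (0x1000) per lane:
//   const Full128<int16_t> d;
//   const auto hi = MulHigh(Set(d, int16_t{16384}), Set(d, int16_t{16384}));
//   // every lane holds int16_t{4096}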
HWY_API Vec128 MulHigh(const Vec128 a, const Vec128 b) { int32x4_t rlo = vmull_s16(vget_low_s16(a.raw), vget_low_s16(b.raw)); #if HWY_ARCH_ARM_A64 int32x4_t rhi = vmull_high_s16(a.raw, b.raw); #else int32x4_t rhi = vmull_s16(vget_high_s16(a.raw), vget_high_s16(b.raw)); #endif return Vec128( vuzp2q_s16(vreinterpretq_s16_s32(rlo), vreinterpretq_s16_s32(rhi))); } HWY_API Vec128 MulHigh(const Vec128 a, const Vec128 b) { uint32x4_t rlo = vmull_u16(vget_low_u16(a.raw), vget_low_u16(b.raw)); #if HWY_ARCH_ARM_A64 uint32x4_t rhi = vmull_high_u16(a.raw, b.raw); #else uint32x4_t rhi = vmull_u16(vget_high_u16(a.raw), vget_high_u16(b.raw)); #endif return Vec128( vuzp2q_u16(vreinterpretq_u16_u32(rlo), vreinterpretq_u16_u32(rhi))); } template HWY_API Vec128 MulHigh(const Vec128 a, const Vec128 b) { int16x8_t hi_lo = vreinterpretq_s16_s32(vmull_s16(a.raw, b.raw)); return Vec128(vget_low_s16(vuzp2q_s16(hi_lo, hi_lo))); } template HWY_API Vec128 MulHigh(const Vec128 a, const Vec128 b) { uint16x8_t hi_lo = vreinterpretq_u16_u32(vmull_u16(a.raw, b.raw)); return Vec128(vget_low_u16(vuzp2q_u16(hi_lo, hi_lo))); } HWY_API Vec128 MulFixedPoint15(Vec128 a, Vec128 b) { return Vec128(vqrdmulhq_s16(a.raw, b.raw)); } template HWY_API Vec128 MulFixedPoint15(Vec128 a, Vec128 b) { return Vec128(vqrdmulh_s16(a.raw, b.raw)); } // ------------------------------ Floating-point mul / div HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2) // Approximate reciprocal HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { return Vec128(vrecpeq_f32(v.raw)); } template HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { return Vec128(vrecpe_f32(v.raw)); } #if HWY_ARCH_ARM_A64 HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator/, vdiv, _, 2) #else // Not defined on armv7: approximate namespace detail { HWY_INLINE Vec128 ReciprocalNewtonRaphsonStep( const Vec128 recip, const Vec128 divisor) { return Vec128(vrecpsq_f32(recip.raw, divisor.raw)); } template HWY_INLINE Vec128 ReciprocalNewtonRaphsonStep( const Vec128 recip, Vec128 divisor) { return Vec128(vrecps_f32(recip.raw, divisor.raw)); } } // namespace detail template HWY_API Vec128 operator/(const Vec128 a, const Vec128 b) { auto x = ApproximateReciprocal(b); x *= detail::ReciprocalNewtonRaphsonStep(x, b); x *= detail::ReciprocalNewtonRaphsonStep(x, b); x *= detail::ReciprocalNewtonRaphsonStep(x, b); return a * x; } #endif // ------------------------------ Absolute value of difference. HWY_API Vec128 AbsDiff(const Vec128 a, const Vec128 b) { return Vec128(vabdq_f32(a.raw, b.raw)); } template HWY_API Vec128 AbsDiff(const Vec128 a, const Vec128 b) { return Vec128(vabd_f32(a.raw, b.raw)); } // ------------------------------ Floating-point multiply-add variants // Returns add + mul * x #if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 template HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { return Vec128(vfma_f32(add.raw, mul.raw, x.raw)); } HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { return Vec128(vfmaq_f32(add.raw, mul.raw, x.raw)); } #else // Emulate FMA for floats. 
template HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { return mul * x + add; } #endif #if HWY_ARCH_ARM_A64 HWY_API Vec64 MulAdd(const Vec64 mul, const Vec64 x, const Vec64 add) { return Vec64(vfma_f64(add.raw, mul.raw, x.raw)); } HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { return Vec128(vfmaq_f64(add.raw, mul.raw, x.raw)); } #endif // Returns add - mul * x #if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 template HWY_API Vec128 NegMulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { return Vec128(vfms_f32(add.raw, mul.raw, x.raw)); } HWY_API Vec128 NegMulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { return Vec128(vfmsq_f32(add.raw, mul.raw, x.raw)); } #else // Emulate FMA for floats. template HWY_API Vec128 NegMulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { return add - mul * x; } #endif #if HWY_ARCH_ARM_A64 HWY_API Vec64 NegMulAdd(const Vec64 mul, const Vec64 x, const Vec64 add) { return Vec64(vfms_f64(add.raw, mul.raw, x.raw)); } HWY_API Vec128 NegMulAdd(const Vec128 mul, const Vec128 x, const Vec128 add) { return Vec128(vfmsq_f64(add.raw, mul.raw, x.raw)); } #endif // Returns mul * x - sub template HWY_API Vec128 MulSub(const Vec128 mul, const Vec128 x, const Vec128 sub) { return MulAdd(mul, x, Neg(sub)); } // Returns -mul * x - sub template HWY_API Vec128 NegMulSub(const Vec128 mul, const Vec128 x, const Vec128 sub) { return Neg(MulAdd(mul, x, sub)); } #if HWY_ARCH_ARM_A64 template HWY_API Vec128 MulSub(const Vec128 mul, const Vec128 x, const Vec128 sub) { return MulAdd(mul, x, Neg(sub)); } template HWY_API Vec128 NegMulSub(const Vec128 mul, const Vec128 x, const Vec128 sub) { return Neg(MulAdd(mul, x, sub)); } #endif // ------------------------------ Floating-point square root (IfThenZeroElse) // Approximate reciprocal square root HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { return Vec128(vrsqrteq_f32(v.raw)); } template HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { return Vec128(vrsqrte_f32(v.raw)); } // Full precision square root #if HWY_ARCH_ARM_A64 HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Sqrt, vsqrt, _, 1) #else namespace detail { HWY_INLINE Vec128 ReciprocalSqrtStep(const Vec128 root, const Vec128 recip) { return Vec128(vrsqrtsq_f32(root.raw, recip.raw)); } template HWY_INLINE Vec128 ReciprocalSqrtStep(const Vec128 root, Vec128 recip) { return Vec128(vrsqrts_f32(root.raw, recip.raw)); } } // namespace detail // Not defined on armv7: approximate template HWY_API Vec128 Sqrt(const Vec128 v) { auto recip = ApproximateReciprocalSqrt(v); recip *= detail::ReciprocalSqrtStep(v * recip, recip); recip *= detail::ReciprocalSqrtStep(v * recip, recip); recip *= detail::ReciprocalSqrtStep(v * recip, recip); const auto root = v * recip; return IfThenZeroElse(v == Zero(Simd()), root); } #endif // ================================================== LOGICAL // ------------------------------ Not // There is no 64-bit vmvn, so cast instead of using HWY_NEON_DEF_FUNCTION. template HWY_API Vec128 Not(const Vec128 v) { const Full128 d; const Repartition d8; return BitCast(d, Vec128(vmvnq_u8(BitCast(d8, v).raw))); } template HWY_API Vec128 Not(const Vec128 v) { const Simd d; const Repartition d8; using V8 = decltype(Zero(d8)); return BitCast(d, V8(vmvn_u8(BitCast(d8, v).raw))); } // ------------------------------ And HWY_NEON_DEF_FUNCTION_INTS_UINTS(And, vand, _, 2) // Uses the u32/64 defined above. 
template HWY_API Vec128 And(const Vec128 a, const Vec128 b) { const DFromV d; const RebindToUnsigned du; return BitCast(d, BitCast(du, a) & BitCast(du, b)); } // ------------------------------ AndNot namespace detail { // reversed_andnot returns a & ~b. HWY_NEON_DEF_FUNCTION_INTS_UINTS(reversed_andnot, vbic, _, 2) } // namespace detail // Returns ~not_mask & mask. template HWY_API Vec128 AndNot(const Vec128 not_mask, const Vec128 mask) { return detail::reversed_andnot(mask, not_mask); } // Uses the u32/64 defined above. template HWY_API Vec128 AndNot(const Vec128 not_mask, const Vec128 mask) { const DFromV d; const RebindToUnsigned du; VFromD ret = detail::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask)); return BitCast(d, ret); } // ------------------------------ Or HWY_NEON_DEF_FUNCTION_INTS_UINTS(Or, vorr, _, 2) // Uses the u32/64 defined above. template HWY_API Vec128 Or(const Vec128 a, const Vec128 b) { const DFromV d; const RebindToUnsigned du; return BitCast(d, BitCast(du, a) | BitCast(du, b)); } // ------------------------------ Xor HWY_NEON_DEF_FUNCTION_INTS_UINTS(Xor, veor, _, 2) // Uses the u32/64 defined above. template HWY_API Vec128 Xor(const Vec128 a, const Vec128 b) { const DFromV d; const RebindToUnsigned du; return BitCast(d, BitCast(du, a) ^ BitCast(du, b)); } // ------------------------------ Xor3 #if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SHA3) HWY_NEON_DEF_FUNCTION_FULL_UI(Xor3, veor3, _, 3) // Half vectors are not natively supported. Two Xor are likely more efficient // than Combine to 128-bit. template HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { return Xor(x1, Xor(x2, x3)); } template HWY_API Vec128 Xor3(const Vec128 x1, const Vec128 x2, const Vec128 x3) { const DFromV d; const RebindToUnsigned du; return BitCast(d, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); } #else template HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { return Xor(x1, Xor(x2, x3)); } #endif // ------------------------------ Or3 template HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { return Or(o1, Or(o2, o3)); } // ------------------------------ OrAnd template HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { return Or(o, And(a1, a2)); } // ------------------------------ IfVecThenElse template HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, Vec128 no) { return IfThenElse(MaskFromVec(mask), yes, no); } // ------------------------------ Operator overloads (internal-only if float) template HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { return And(a, b); } template HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { return Or(a, b); } template HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { return Xor(a, b); } // ------------------------------ PopulationCount #ifdef HWY_NATIVE_POPCNT #undef HWY_NATIVE_POPCNT #else #define HWY_NATIVE_POPCNT #endif namespace detail { template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, Vec128 v) { const Full128 d8; return Vec128(vcntq_u8(BitCast(d8, v).raw)); } template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, Vec128 v) { const Simd d8; return Vec128(vcnt_u8(BitCast(d8, v).raw)); } // ARM lacks popcount for lane sizes > 1, so take pairwise sums of the bytes. 
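// Worked sketch of the widening chain used below: vcnt_u8 gives per-byte bit
// counts, and each vpaddl step sums adjacent pairs into lanes twice as wide
// (u8 -> u16 -> u32 -> u64). For example,
//   const Full128<uint64_t> d;
//   const auto n = PopulationCount(Set(d, 0x0F0F0F0F0F0F0F0Full));
//   // each u64 lane holds 32 (eight bytes, four bits set in each)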
template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, Vec128 v) { const Full128 d8; const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); return Vec128(vpaddlq_u8(bytes)); } template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, Vec128 v) { const Repartition> d8; const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); return Vec128(vpaddl_u8(bytes)); } template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, Vec128 v) { const Full128 d8; const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); return Vec128(vpaddlq_u16(vpaddlq_u8(bytes))); } template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, Vec128 v) { const Repartition> d8; const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); return Vec128(vpaddl_u16(vpaddl_u8(bytes))); } template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, Vec128 v) { const Full128 d8; const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(bytes)))); } template HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, Vec128 v) { const Repartition> d8; const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); return Vec128(vpaddl_u32(vpaddl_u16(vpaddl_u8(bytes)))); } } // namespace detail template HWY_API Vec128 PopulationCount(Vec128 v) { return detail::PopulationCount(hwy::SizeTag(), v); } // ================================================== SIGN // ------------------------------ Abs // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabsq_s8(v.raw)); } HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabsq_s16(v.raw)); } HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabsq_s32(v.raw)); } // i64 is implemented after BroadcastSignBit. HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabsq_f32(v.raw)); } template HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabs_s8(v.raw)); } template HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabs_s16(v.raw)); } template HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabs_s32(v.raw)); } template HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabs_f32(v.raw)); } #if HWY_ARCH_ARM_A64 HWY_API Vec128 Abs(const Vec128 v) { return Vec128(vabsq_f64(v.raw)); } HWY_API Vec64 Abs(const Vec64 v) { return Vec64(vabs_f64(v.raw)); } #endif // ------------------------------ CopySign template HWY_API Vec128 CopySign(const Vec128 magn, const Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); const auto msb = SignBit(Simd()); return Or(AndNot(msb, magn), And(msb, sign)); } template HWY_API Vec128 CopySignToAbs(const Vec128 abs, const Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); return Or(abs, And(SignBit(Simd()), sign)); } // ------------------------------ BroadcastSignBit template HWY_API Vec128 BroadcastSignBit(const Vec128 v) { return ShiftRight(v); } // ================================================== MASK // ------------------------------ To/from vector // Mask and Vec have the same representation (true = FF..FF). 
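// Consequence (sketch): MaskFromVec/VecFromMask below are effectively free
// bit casts, so a mask can be built with ordinary vector ops and then used to
// blend, e.g. clamping negative lanes of an illustrative Vec128<int32_t> v:
//   const Full128<int32_t> d;
//   const auto m = MaskFromVec(BroadcastSignBit(v));  // true where v < 0
//   const auto r = IfThenElse(m, Zero(d), v);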
template HWY_API Mask128 MaskFromVec(const Vec128 v) { const Simd, N, 0> du; return Mask128(BitCast(du, v).raw); } template HWY_API Vec128 VecFromMask(Simd d, const Mask128 v) { return BitCast(d, Vec128, N>(v.raw)); } // ------------------------------ RebindMask template HWY_API Mask128 RebindMask(Simd dto, Mask128 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return MaskFromVec(BitCast(dto, VecFromMask(Simd(), m))); } // ------------------------------ IfThenElse(mask, yes, no) = mask ? b : a. #define HWY_NEON_BUILD_TPL_HWY_IF #define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128 #define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ const Mask128 mask, const Vec128 yes, \ const Vec128 no #define HWY_NEON_BUILD_ARG_HWY_IF mask.raw, yes.raw, no.raw HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) #undef HWY_NEON_BUILD_TPL_HWY_IF #undef HWY_NEON_BUILD_RET_HWY_IF #undef HWY_NEON_BUILD_PARAM_HWY_IF #undef HWY_NEON_BUILD_ARG_HWY_IF // mask ? yes : 0 template HWY_API Vec128 IfThenElseZero(const Mask128 mask, const Vec128 yes) { return yes & VecFromMask(Simd(), mask); } // mask ? 0 : no template HWY_API Vec128 IfThenZeroElse(const Mask128 mask, const Vec128 no) { return AndNot(VecFromMask(Simd(), mask), no); } template HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, Vec128 no) { static_assert(IsSigned(), "Only works for signed/float"); const Simd d; const RebindToSigned di; Mask128 m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); return IfThenElse(m, yes, no); } template HWY_API Vec128 ZeroIfNegative(Vec128 v) { const auto zero = Zero(Simd()); return Max(zero, v); } // ------------------------------ Mask logical template HWY_API Mask128 Not(const Mask128 m) { return MaskFromVec(Not(VecFromMask(Simd(), m))); } template HWY_API Mask128 And(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Or(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); } // ================================================== COMPARE // Comparisons fill a lane with 1-bits if the condition is true, else 0. 
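// For example (sketch): with int32_t lanes v = {1, -2, 3, -4},
//   v < Zero(d)   yields mask lanes {0, 0xFFFFFFFF, 0, 0xFFFFFFFF},
// which the IfThenElse/IfThenZeroElse helpers above consume directly.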
// ------------------------------ Shuffle2301 (for i64 compares) // Swap 32-bit halves in 64-bits HWY_API Vec64 Shuffle2301(const Vec64 v) { return Vec64(vrev64_u32(v.raw)); } HWY_API Vec64 Shuffle2301(const Vec64 v) { return Vec64(vrev64_s32(v.raw)); } HWY_API Vec64 Shuffle2301(const Vec64 v) { return Vec64(vrev64_f32(v.raw)); } HWY_API Vec128 Shuffle2301(const Vec128 v) { return Vec128(vrev64q_u32(v.raw)); } HWY_API Vec128 Shuffle2301(const Vec128 v) { return Vec128(vrev64q_s32(v.raw)); } HWY_API Vec128 Shuffle2301(const Vec128 v) { return Vec128(vrev64q_f32(v.raw)); } #define HWY_NEON_BUILD_TPL_HWY_COMPARE #define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128 #define HWY_NEON_BUILD_PARAM_HWY_COMPARE(type, size) \ const Vec128 a, const Vec128 b #define HWY_NEON_BUILD_ARG_HWY_COMPARE a.raw, b.raw // ------------------------------ Equality HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator==, vceq, _, HWY_COMPARE) #if HWY_ARCH_ARM_A64 HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator==, vceq, _, HWY_COMPARE) #else // No 64-bit comparisons on armv7: emulate them below, after Shuffle2301. HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator==, vceq, _, HWY_COMPARE) HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator==, vceq, _, HWY_COMPARE) #endif // ------------------------------ Strict inequality (signed, float) #if HWY_ARCH_ARM_A64 HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator<, vclt, _, HWY_COMPARE) #else HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator<, vclt, _, HWY_COMPARE) HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<, vclt, _, HWY_COMPARE) #endif HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<, vclt, _, HWY_COMPARE) // ------------------------------ Weak inequality (float) HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<=, vcle, _, HWY_COMPARE) #undef HWY_NEON_BUILD_TPL_HWY_COMPARE #undef HWY_NEON_BUILD_RET_HWY_COMPARE #undef HWY_NEON_BUILD_PARAM_HWY_COMPARE #undef HWY_NEON_BUILD_ARG_HWY_COMPARE // ------------------------------ ARMv7 i64 compare (Shuffle2301, Eq) #if HWY_ARCH_ARM_V7 template HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { const Simd d32; const Simd d64; const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); const auto cmp64 = cmp32 & Shuffle2301(cmp32); return MaskFromVec(BitCast(d64, cmp64)); } template HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { const Simd d32; const Simd d64; const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); const auto cmp64 = cmp32 & Shuffle2301(cmp32); return MaskFromVec(BitCast(d64, cmp64)); } HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { const int64x2_t sub = vqsubq_s64(a.raw, b.raw); return MaskFromVec(BroadcastSignBit(Vec128(sub))); } HWY_API Mask128 operator<(const Vec64 a, const Vec64 b) { const int64x1_t sub = vqsub_s64(a.raw, b.raw); return MaskFromVec(BroadcastSignBit(Vec64(sub))); } template HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { const DFromV du; const RebindToSigned di; const Vec128 msb = AndNot(a, b) | AndNot(a ^ b, a - b); return MaskFromVec(BitCast(du, BroadcastSignBit(BitCast(di, msb)))); } #endif // ------------------------------ operator!= (operator==) // Customize HWY_NEON_DEF_FUNCTION to call 2 functions. #pragma push_macro("HWY_NEON_DEF_FUNCTION") #undef HWY_NEON_DEF_FUNCTION // This cannot have _any_ template argument (in x86_128 we can at least have N // as an argument), otherwise it is not more specialized than rewritten // operator== in C++20, leading to compile errors. 
#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ HWY_API Mask128 name(Vec128 a, \ Vec128 b) { \ return Not(a == b); \ } HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator!=, ignored, ignored, ignored) #pragma pop_macro("HWY_NEON_DEF_FUNCTION") // ------------------------------ Reversed comparisons template HWY_API Mask128 operator>(Vec128 a, Vec128 b) { return operator<(b, a); } template HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { return operator<=(b, a); } // ------------------------------ FirstN (Iota, Lt) template HWY_API Mask128 FirstN(const Simd d, size_t num) { const RebindToSigned di; // Signed comparisons are cheaper. return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); } // ------------------------------ TestBit (Eq) #define HWY_NEON_BUILD_TPL_HWY_TESTBIT #define HWY_NEON_BUILD_RET_HWY_TESTBIT(type, size) Mask128 #define HWY_NEON_BUILD_PARAM_HWY_TESTBIT(type, size) \ Vec128 v, Vec128 bit #define HWY_NEON_BUILD_ARG_HWY_TESTBIT v.raw, bit.raw #if HWY_ARCH_ARM_A64 HWY_NEON_DEF_FUNCTION_INTS_UINTS(TestBit, vtst, _, HWY_TESTBIT) #else // No 64-bit versions on armv7 HWY_NEON_DEF_FUNCTION_UINT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) HWY_NEON_DEF_FUNCTION_INT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) template HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { return (v & bit) == bit; } template HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { return (v & bit) == bit; } #endif #undef HWY_NEON_BUILD_TPL_HWY_TESTBIT #undef HWY_NEON_BUILD_RET_HWY_TESTBIT #undef HWY_NEON_BUILD_PARAM_HWY_TESTBIT #undef HWY_NEON_BUILD_ARG_HWY_TESTBIT // ------------------------------ Abs i64 (IfThenElse, BroadcastSignBit) HWY_API Vec128 Abs(const Vec128 v) { #if HWY_ARCH_ARM_A64 return Vec128(vabsq_s64(v.raw)); #else const auto zero = Zero(Full128()); return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); #endif } HWY_API Vec64 Abs(const Vec64 v) { #if HWY_ARCH_ARM_A64 return Vec64(vabs_s64(v.raw)); #else const auto zero = Zero(Full64()); return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); #endif } // ------------------------------ Min (IfThenElse, BroadcastSignBit) // Unsigned HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Min, vmin, _, 2) template HWY_API Vec128 Min(const Vec128 a, const Vec128 b) { #if HWY_ARCH_ARM_A64 return IfThenElse(b < a, b, a); #else const DFromV du; const RebindToSigned di; return BitCast(du, BitCast(di, a) - BitCast(di, detail::SaturatedSub(a, b))); #endif } // Signed HWY_NEON_DEF_FUNCTION_INT_8_16_32(Min, vmin, _, 2) template HWY_API Vec128 Min(const Vec128 a, const Vec128 b) { #if HWY_ARCH_ARM_A64 return IfThenElse(b < a, b, a); #else const Vec128 sign = detail::SaturatedSub(a, b); return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), a, b); #endif } // Float: IEEE minimumNumber on v8, otherwise NaN if any is NaN. 
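// Note: vminnm/vmaxnm (AArch64) implement IEEE-754 minNum/maxNum, so if exactly
// one operand is NaN, the numeric operand is returned. The ARMv7 fallback uses
// vmin/vmax, which propagate NaN, hence the behavioral difference noted above.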
#if HWY_ARCH_ARM_A64 HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vminnm, _, 2) #else HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vmin, _, 2) #endif // ------------------------------ Max (IfThenElse, BroadcastSignBit) // Unsigned (no u64) HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Max, vmax, _, 2) template HWY_API Vec128 Max(const Vec128 a, const Vec128 b) { #if HWY_ARCH_ARM_A64 return IfThenElse(b < a, a, b); #else const DFromV du; const RebindToSigned di; return BitCast(du, BitCast(di, b) + BitCast(di, detail::SaturatedSub(a, b))); #endif } // Signed (no i64) HWY_NEON_DEF_FUNCTION_INT_8_16_32(Max, vmax, _, 2) template HWY_API Vec128 Max(const Vec128 a, const Vec128 b) { #if HWY_ARCH_ARM_A64 return IfThenElse(b < a, a, b); #else const Vec128 sign = detail::SaturatedSub(a, b); return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), b, a); #endif } // Float: IEEE maximumNumber on v8, otherwise NaN if any is NaN. #if HWY_ARCH_ARM_A64 HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmaxnm, _, 2) #else HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmax, _, 2) #endif // ================================================== MEMORY // ------------------------------ Load 128 HWY_API Vec128 LoadU(Full128 /* tag */, const uint8_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_u8(unaligned)); } HWY_API Vec128 LoadU(Full128 /* tag */, const uint16_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_u16(unaligned)); } HWY_API Vec128 LoadU(Full128 /* tag */, const uint32_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_u32(unaligned)); } HWY_API Vec128 LoadU(Full128 /* tag */, const uint64_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_u64(unaligned)); } HWY_API Vec128 LoadU(Full128 /* tag */, const int8_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_s8(unaligned)); } HWY_API Vec128 LoadU(Full128 /* tag */, const int16_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_s16(unaligned)); } HWY_API Vec128 LoadU(Full128 /* tag */, const int32_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_s32(unaligned)); } HWY_API Vec128 LoadU(Full128 /* tag */, const int64_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_s64(unaligned)); } HWY_API Vec128 LoadU(Full128 /* tag */, const float* HWY_RESTRICT unaligned) { return Vec128(vld1q_f32(unaligned)); } #if HWY_ARCH_ARM_A64 HWY_API Vec128 LoadU(Full128 /* tag */, const double* HWY_RESTRICT unaligned) { return Vec128(vld1q_f64(unaligned)); } #endif // ------------------------------ Load 64 HWY_API Vec64 LoadU(Full64 /* tag */, const uint8_t* HWY_RESTRICT p) { return Vec64(vld1_u8(p)); } HWY_API Vec64 LoadU(Full64 /* tag */, const uint16_t* HWY_RESTRICT p) { return Vec64(vld1_u16(p)); } HWY_API Vec64 LoadU(Full64 /* tag */, const uint32_t* HWY_RESTRICT p) { return Vec64(vld1_u32(p)); } HWY_API Vec64 LoadU(Full64 /* tag */, const uint64_t* HWY_RESTRICT p) { return Vec64(vld1_u64(p)); } HWY_API Vec64 LoadU(Full64 /* tag */, const int8_t* HWY_RESTRICT p) { return Vec64(vld1_s8(p)); } HWY_API Vec64 LoadU(Full64 /* tag */, const int16_t* HWY_RESTRICT p) { return Vec64(vld1_s16(p)); } HWY_API Vec64 LoadU(Full64 /* tag */, const int32_t* HWY_RESTRICT p) { return Vec64(vld1_s32(p)); } HWY_API Vec64 LoadU(Full64 /* tag */, const int64_t* HWY_RESTRICT p) { return Vec64(vld1_s64(p)); } HWY_API Vec64 LoadU(Full64 /* tag */, const float* HWY_RESTRICT p) { return Vec64(vld1_f32(p)); } #if HWY_ARCH_ARM_A64 HWY_API Vec64 LoadU(Full64 /* tag */, const double* HWY_RESTRICT p) { return Vec64(vld1_f64(p)); } #endif // ------------------------------ Load 32 // Actual 32-bit broadcast load - used to implement the other lane types // 
because reinterpret_cast of the pointer leads to incorrect codegen on GCC. HWY_API Vec32 LoadU(Full32 /*tag*/, const uint32_t* HWY_RESTRICT p) { return Vec32(vld1_dup_u32(p)); } HWY_API Vec32 LoadU(Full32 /*tag*/, const int32_t* HWY_RESTRICT p) { return Vec32(vld1_dup_s32(p)); } HWY_API Vec32 LoadU(Full32 /*tag*/, const float* HWY_RESTRICT p) { return Vec32(vld1_dup_f32(p)); } template // 1 or 2 bytes HWY_API Vec32 LoadU(Full32 d, const T* HWY_RESTRICT p) { const Repartition d32; uint32_t buf; CopyBytes<4>(p, &buf); return BitCast(d, LoadU(d32, &buf)); } // ------------------------------ Load 16 // Actual 16-bit broadcast load - used to implement the other lane types // because reinterpret_cast of the pointer leads to incorrect codegen on GCC. HWY_API Vec128 LoadU(Simd /*tag*/, const uint16_t* HWY_RESTRICT p) { return Vec128(vld1_dup_u16(p)); } HWY_API Vec128 LoadU(Simd /*tag*/, const int16_t* HWY_RESTRICT p) { return Vec128(vld1_dup_s16(p)); } template HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { const Repartition d16; uint16_t buf; CopyBytes<2>(p, &buf); return BitCast(d, LoadU(d16, &buf)); } // ------------------------------ Load 8 HWY_API Vec128 LoadU(Simd, const uint8_t* HWY_RESTRICT p) { return Vec128(vld1_dup_u8(p)); } HWY_API Vec128 LoadU(Simd, const int8_t* HWY_RESTRICT p) { return Vec128(vld1_dup_s8(p)); } // [b]float16_t use the same Raw as uint16_t, so forward to that. template HWY_API Vec128 LoadU(Simd d, const float16_t* HWY_RESTRICT p) { const RebindToUnsigned du16; const auto pu16 = reinterpret_cast(p); return Vec128(LoadU(du16, pu16).raw); } template HWY_API Vec128 LoadU(Simd d, const bfloat16_t* HWY_RESTRICT p) { const RebindToUnsigned du16; const auto pu16 = reinterpret_cast(p); return Vec128(LoadU(du16, pu16).raw); } // On ARM, Load is the same as LoadU. template HWY_API Vec128 Load(Simd d, const T* HWY_RESTRICT p) { return LoadU(d, p); } template HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } // 128-bit SIMD => nothing to duplicate, same as an unaligned load. 
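// LoadDup128 exists for parity with wider targets, where it broadcasts one
// 128-bit block to all blocks of the vector; NEON vectors are at most 128 bits
// (a single block), so it reduces to LoadU.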
template HWY_API Vec128 LoadDup128(Simd d, const T* const HWY_RESTRICT p) { return LoadU(d, p); } // ------------------------------ Store 128 HWY_API void StoreU(const Vec128 v, Full128 /* tag */, uint8_t* HWY_RESTRICT unaligned) { vst1q_u8(unaligned, v.raw); } HWY_API void StoreU(const Vec128 v, Full128 /* tag */, uint16_t* HWY_RESTRICT unaligned) { vst1q_u16(unaligned, v.raw); } HWY_API void StoreU(const Vec128 v, Full128 /* tag */, uint32_t* HWY_RESTRICT unaligned) { vst1q_u32(unaligned, v.raw); } HWY_API void StoreU(const Vec128 v, Full128 /* tag */, uint64_t* HWY_RESTRICT unaligned) { vst1q_u64(unaligned, v.raw); } HWY_API void StoreU(const Vec128 v, Full128 /* tag */, int8_t* HWY_RESTRICT unaligned) { vst1q_s8(unaligned, v.raw); } HWY_API void StoreU(const Vec128 v, Full128 /* tag */, int16_t* HWY_RESTRICT unaligned) { vst1q_s16(unaligned, v.raw); } HWY_API void StoreU(const Vec128 v, Full128 /* tag */, int32_t* HWY_RESTRICT unaligned) { vst1q_s32(unaligned, v.raw); } HWY_API void StoreU(const Vec128 v, Full128 /* tag */, int64_t* HWY_RESTRICT unaligned) { vst1q_s64(unaligned, v.raw); } HWY_API void StoreU(const Vec128 v, Full128 /* tag */, float* HWY_RESTRICT unaligned) { vst1q_f32(unaligned, v.raw); } #if HWY_ARCH_ARM_A64 HWY_API void StoreU(const Vec128 v, Full128 /* tag */, double* HWY_RESTRICT unaligned) { vst1q_f64(unaligned, v.raw); } #endif // ------------------------------ Store 64 HWY_API void StoreU(const Vec64 v, Full64 /* tag */, uint8_t* HWY_RESTRICT p) { vst1_u8(p, v.raw); } HWY_API void StoreU(const Vec64 v, Full64 /* tag */, uint16_t* HWY_RESTRICT p) { vst1_u16(p, v.raw); } HWY_API void StoreU(const Vec64 v, Full64 /* tag */, uint32_t* HWY_RESTRICT p) { vst1_u32(p, v.raw); } HWY_API void StoreU(const Vec64 v, Full64 /* tag */, uint64_t* HWY_RESTRICT p) { vst1_u64(p, v.raw); } HWY_API void StoreU(const Vec64 v, Full64 /* tag */, int8_t* HWY_RESTRICT p) { vst1_s8(p, v.raw); } HWY_API void StoreU(const Vec64 v, Full64 /* tag */, int16_t* HWY_RESTRICT p) { vst1_s16(p, v.raw); } HWY_API void StoreU(const Vec64 v, Full64 /* tag */, int32_t* HWY_RESTRICT p) { vst1_s32(p, v.raw); } HWY_API void StoreU(const Vec64 v, Full64 /* tag */, int64_t* HWY_RESTRICT p) { vst1_s64(p, v.raw); } HWY_API void StoreU(const Vec64 v, Full64 /* tag */, float* HWY_RESTRICT p) { vst1_f32(p, v.raw); } #if HWY_ARCH_ARM_A64 HWY_API void StoreU(const Vec64 v, Full64 /* tag */, double* HWY_RESTRICT p) { vst1_f64(p, v.raw); } #endif // ------------------------------ Store 32 HWY_API void StoreU(const Vec32 v, Full32, uint32_t* HWY_RESTRICT p) { vst1_lane_u32(p, v.raw, 0); } HWY_API void StoreU(const Vec32 v, Full32, int32_t* HWY_RESTRICT p) { vst1_lane_s32(p, v.raw, 0); } HWY_API void StoreU(const Vec32 v, Full32, float* HWY_RESTRICT p) { vst1_lane_f32(p, v.raw, 0); } template // 1 or 2 bytes HWY_API void StoreU(const Vec32 v, Full32 d, T* HWY_RESTRICT p) { const Repartition d32; const uint32_t buf = GetLane(BitCast(d32, v)); CopyBytes<4>(&buf, p); } // ------------------------------ Store 16 HWY_API void StoreU(const Vec128 v, Simd, uint16_t* HWY_RESTRICT p) { vst1_lane_u16(p, v.raw, 0); } HWY_API void StoreU(const Vec128 v, Simd, int16_t* HWY_RESTRICT p) { vst1_lane_s16(p, v.raw, 0); } template HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { const Repartition d16; const uint16_t buf = GetLane(BitCast(d16, v)); CopyBytes<2>(&buf, p); } // ------------------------------ Store 8 HWY_API void StoreU(const Vec128 v, Simd, uint8_t* HWY_RESTRICT p) { vst1_lane_u8(p, v.raw, 0); } HWY_API 
void StoreU(const Vec128 v, Simd, int8_t* HWY_RESTRICT p) { vst1_lane_s8(p, v.raw, 0); } // [b]float16_t use the same Raw as uint16_t, so forward to that. template HWY_API void StoreU(Vec128 v, Simd d, float16_t* HWY_RESTRICT p) { const RebindToUnsigned du16; const auto pu16 = reinterpret_cast(p); return StoreU(Vec128(v.raw), du16, pu16); } template HWY_API void StoreU(Vec128 v, Simd d, bfloat16_t* HWY_RESTRICT p) { const RebindToUnsigned du16; const auto pu16 = reinterpret_cast(p); return StoreU(Vec128(v.raw), du16, pu16); } HWY_DIAGNOSTICS(push) #if HWY_COMPILER_GCC_ACTUAL HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") #endif // On ARM, Store is the same as StoreU. template HWY_API void Store(Vec128 v, Simd d, T* HWY_RESTRICT aligned) { StoreU(v, d, aligned); } HWY_DIAGNOSTICS(pop) template HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, T* HWY_RESTRICT p) { // Treat as unsigned so that we correctly support float16. const RebindToUnsigned du; const auto blended = IfThenElse(RebindMask(du, m), BitCast(du, v), BitCast(du, LoadU(d, p))); StoreU(BitCast(d, blended), d, p); } // ------------------------------ Non-temporal stores // Same as aligned stores on non-x86. template HWY_API void Stream(const Vec128 v, Simd d, T* HWY_RESTRICT aligned) { Store(v, d, aligned); } // ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) // Unsigned: zero-extend to full vector. HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec64 v) { return Vec128(vmovl_u8(v.raw)); } HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec32 v) { uint16x8_t a = vmovl_u8(v.raw); return Vec128(vmovl_u16(vget_low_u16(a))); } HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec64 v) { return Vec128(vmovl_u16(v.raw)); } HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec64 v) { return Vec128(vmovl_u32(v.raw)); } HWY_API Vec128 PromoteTo(Full128 d, const Vec64 v) { return BitCast(d, Vec128(vmovl_u8(v.raw))); } HWY_API Vec128 PromoteTo(Full128 d, const Vec32 v) { uint16x8_t a = vmovl_u8(v.raw); return BitCast(d, Vec128(vmovl_u16(vget_low_u16(a)))); } HWY_API Vec128 PromoteTo(Full128 d, const Vec64 v) { return BitCast(d, Vec128(vmovl_u16(v.raw))); } // Unsigned: zero-extend to half vector. template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_u16(vmovl_u8(v.raw))); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { uint16x8_t a = vmovl_u8(v.raw); return Vec128(vget_low_u32(vmovl_u16(vget_low_u16(a)))); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_u32(vmovl_u16(v.raw))); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_u64(vmovl_u32(v.raw))); } template HWY_API Vec128 PromoteTo(Simd d, const Vec128 v) { return BitCast(d, Vec128(vget_low_u16(vmovl_u8(v.raw)))); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { uint16x8_t a = vmovl_u8(v.raw); uint32x4_t b = vmovl_u16(vget_low_u16(a)); return Vec128(vget_low_s32(vreinterpretq_s32_u32(b))); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { uint32x4_t a = vmovl_u16(v.raw); return Vec128(vget_low_s32(vreinterpretq_s32_u32(a))); } // Signed: replicate sign bit to full vector. 
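// For example, the int8 -> int16 PromoteTo below maps to vmovl_s8, which
// sign-extends each lane: an input lane of -1 (0xFF) becomes -1 (0xFFFF)
// rather than 255, unlike the zero-extending unsigned variants above.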
HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec64 v) { return Vec128(vmovl_s8(v.raw)); } HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec32 v) { int16x8_t a = vmovl_s8(v.raw); return Vec128(vmovl_s16(vget_low_s16(a))); } HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec64 v) { return Vec128(vmovl_s16(v.raw)); } HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec64 v) { return Vec128(vmovl_s32(v.raw)); } // Signed: replicate sign bit to half vector. template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_s16(vmovl_s8(v.raw))); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { int16x8_t a = vmovl_s8(v.raw); int32x4_t b = vmovl_s16(vget_low_s16(a)); return Vec128(vget_low_s32(b)); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_s32(vmovl_s16(v.raw))); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vget_low_s64(vmovl_s32(v.raw))); } #if __ARM_FP & 2 HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec128 v) { const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); return Vec128(f32); } template HWY_API Vec128 PromoteTo(Simd /* tag */, const Vec128 v) { const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); return Vec128(vget_low_f32(f32)); } #else template HWY_API Vec128 PromoteTo(Simd df32, const Vec128 v) { const RebindToSigned di32; const RebindToUnsigned du32; // Expand to u32 so we can shift. const auto bits16 = PromoteTo(du32, Vec128{v.raw}); const auto sign = ShiftRight<15>(bits16); const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); const auto mantissa = bits16 & Set(du32, 0x3FF); const auto subnormal = BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * Set(df32, 1.0f / 16384 / 1024)); const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); return BitCast(df32, ShiftLeft<31>(sign) | bits32); } #endif #if HWY_ARCH_ARM_A64 HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec64 v) { return Vec128(vcvt_f64_f32(v.raw)); } HWY_API Vec64 PromoteTo(Full64 /* tag */, const Vec32 v) { return Vec64(vget_low_f64(vcvt_f64_f32(v.raw))); } HWY_API Vec128 PromoteTo(Full128 /* tag */, const Vec64 v) { const int64x2_t i64 = vmovl_s32(v.raw); return Vec128(vcvtq_f64_s64(i64)); } HWY_API Vec64 PromoteTo(Full64 /* tag */, const Vec32 v) { const int64x1_t i64 = vget_low_s64(vmovl_s32(v.raw)); return Vec64(vcvt_f64_s64(i64)); } #endif // ------------------------------ Demotions (full -> part w/ narrow lanes) // From full vector to half or quarter HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { return Vec64(vqmovun_s32(v.raw)); } HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { return Vec64(vqmovn_s32(v.raw)); } HWY_API Vec32 DemoteTo(Full32 /* tag */, const Vec128 v) { const uint16x4_t a = vqmovun_s32(v.raw); return Vec32(vqmovn_u16(vcombine_u16(a, a))); } HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { return Vec64(vqmovun_s16(v.raw)); } HWY_API Vec32 DemoteTo(Full32 /* tag */, const Vec128 v) { const int16x4_t a = vqmovn_s32(v.raw); return Vec32(vqmovn_s16(vcombine_s16(a, a))); } HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { return Vec64(vqmovn_s16(v.raw)); } // From half vector to partial half template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return 
Vec128(vqmovun_s32(vcombine_s32(v.raw, v.raw))); } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vqmovn_s32(vcombine_s32(v.raw, v.raw))); } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const uint16x4_t a = vqmovun_s32(vcombine_s32(v.raw, v.raw)); return Vec128(vqmovn_u16(vcombine_u16(a, a))); } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vqmovun_s16(vcombine_s16(v.raw, v.raw))); } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const int16x4_t a = vqmovn_s32(vcombine_s32(v.raw, v.raw)); return Vec128(vqmovn_s16(vcombine_s16(a, a))); } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { return Vec128(vqmovn_s16(vcombine_s16(v.raw, v.raw))); } #if __ARM_FP & 2 HWY_API Vec128 DemoteTo(Full64 /* tag */, const Vec128 v) { return Vec128{vreinterpret_u16_f16(vcvt_f16_f32(v.raw))}; } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { const float16x4_t f16 = vcvt_f16_f32(vcombine_f32(v.raw, v.raw)); return Vec128(vreinterpret_u16_f16(f16)); } #else template HWY_API Vec128 DemoteTo(Simd df16, const Vec128 v) { const RebindToUnsigned du16; const Rebind du; const RebindToSigned di; const auto bits32 = BitCast(du, v); const auto sign = ShiftRight<31>(bits32); const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); const auto k15 = Set(di, 15); const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); const auto is_tiny = exp < Set(di, -24); const auto is_subnormal = exp < Set(di, -14); const auto biased_exp16 = BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + (mantissa32 >> (Set(du, 13) + sub_exp)); const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, ShiftRight<13>(mantissa32)); // <1024 const auto sign16 = ShiftLeft<15>(sign); const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); return Vec128(DemoteTo(du16, bits16).raw); } #endif template HWY_API Vec128 DemoteTo(Simd dbf16, const Vec128 v) { const Rebind di32; const Rebind du32; // for logical shift right const Rebind du16; const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); return BitCast(dbf16, DemoteTo(du16, bits_in_32)); } #if HWY_ARCH_ARM_A64 HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { return Vec64(vcvt_f32_f64(v.raw)); } HWY_API Vec32 DemoteTo(Full32 /* tag */, const Vec64 v) { return Vec32(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); } HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { const int64x2_t i64 = vcvtq_s64_f64(v.raw); return Vec64(vqmovn_s64(i64)); } HWY_API Vec32 DemoteTo(Full32 /* tag */, const Vec64 v) { const int64x1_t i64 = vcvt_s64_f64(v.raw); // There is no i64x1 -> i32x1 narrow, so expand to int64x2_t first. const int64x2_t i64x2 = vcombine_s64(i64, i64); return Vec32(vqmovn_s64(i64x2)); } #endif HWY_API Vec32 U8FromU32(const Vec128 v) { const uint8x16_t org_v = detail::BitCastToByte(v).raw; const uint8x16_t w = vuzp1q_u8(org_v, org_v); return Vec32(vget_low_u8(vuzp1q_u8(w, w))); } template HWY_API Vec128 U8FromU32(const Vec128 v) { const uint8x8_t org_v = detail::BitCastToByte(v).raw; const uint8x8_t w = vuzp1_u8(org_v, org_v); return Vec128(vuzp1_u8(w, w)); } // In the following DemoteTo functions, |b| is purposely undefined. 
// The value a needs to be extended to 128 bits so that vqmovn can be // used and |b| is undefined so that no extra overhead is introduced. HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { Vec128 a = DemoteTo(Simd(), v); Vec128 b; uint16x8_t c = vcombine_u16(a.raw, b.raw); return Vec128(vqmovn_u16(c)); } template HWY_API Vec128 DemoteTo(Simd /* tag */, const Vec128 v) { Vec128 a = DemoteTo(Simd(), v); Vec128 b; int16x8_t c = vcombine_s16(a.raw, b.raw); return Vec128(vqmovn_s16(c)); } HWY_DIAGNOSTICS(pop) // ------------------------------ Convert integer <=> floating-point HWY_API Vec128 ConvertTo(Full128 /* tag */, const Vec128 v) { return Vec128(vcvtq_f32_s32(v.raw)); } template HWY_API Vec128 ConvertTo(Simd /* tag */, const Vec128 v) { return Vec128(vcvt_f32_s32(v.raw)); } HWY_API Vec128 ConvertTo(Full128 /* tag */, const Vec128 v) { return Vec128(vcvtq_f32_u32(v.raw)); } template HWY_API Vec128 ConvertTo(Simd /* tag */, const Vec128 v) { return Vec128(vcvt_f32_u32(v.raw)); } // Truncates (rounds toward zero). HWY_API Vec128 ConvertTo(Full128 /* tag */, const Vec128 v) { return Vec128(vcvtq_s32_f32(v.raw)); } template HWY_API Vec128 ConvertTo(Simd /* tag */, const Vec128 v) { return Vec128(vcvt_s32_f32(v.raw)); } #if HWY_ARCH_ARM_A64 HWY_API Vec128 ConvertTo(Full128 /* tag */, const Vec128 v) { return Vec128(vcvtq_f64_s64(v.raw)); } HWY_API Vec64 ConvertTo(Full64 /* tag */, const Vec64 v) { return Vec64(vcvt_f64_s64(v.raw)); } HWY_API Vec128 ConvertTo(Full128 /* tag */, const Vec128 v) { return Vec128(vcvtq_f64_u64(v.raw)); } HWY_API Vec64 ConvertTo(Full64 /* tag */, const Vec64 v) { return Vec64(vcvt_f64_u64(v.raw)); } // Truncates (rounds toward zero). HWY_API Vec128 ConvertTo(Full128 /* tag */, const Vec128 v) { return Vec128(vcvtq_s64_f64(v.raw)); } HWY_API Vec64 ConvertTo(Full64 /* tag */, const Vec64 v) { return Vec64(vcvt_s64_f64(v.raw)); } #endif // ------------------------------ Round (IfThenElse, mask, logical) #if HWY_ARCH_ARM_A64 // Toward nearest integer HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Round, vrndn, _, 1) // Toward zero, aka truncate HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Trunc, vrnd, _, 1) // Toward +infinity, aka ceiling HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Ceil, vrndp, _, 1) // Toward -infinity, aka floor HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Floor, vrndm, _, 1) #else // ------------------------------ Trunc // ARMv7 only supports truncation to integer. We can either convert back to // float (3 floating-point and 2 logic operations) or manipulate the binary32 // representation, clearing the lowest 23-exp mantissa bits. This requires 9 // integer operations and 3 constants, which is likely more expensive. namespace detail { // The original value is already the desired result if NaN or the magnitude is // large (i.e. the value is already an integer). template HWY_INLINE Mask128 UseInt(const Vec128 v) { return Abs(v) < Set(Simd(), MantissaEnd()); } } // namespace detail template HWY_API Vec128 Trunc(const Vec128 v) { const DFromV df; const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 const auto int_f = ConvertTo(df, integer); return IfThenElse(detail::UseInt(v), int_f, v); } template HWY_API Vec128 Round(const Vec128 v) { const DFromV df; // ARMv7 also lacks a native NearestInt, but we can instead rely on rounding // (we assume the current mode is nearest-even) after addition with a large // value such that no mantissa bits remain. 
We may need a compiler flag for // precise floating-point to prevent this from being "optimized" out. const auto max = Set(df, MantissaEnd()); const auto large = CopySignToAbs(max, v); const auto added = large + v; const auto rounded = added - large; // Keep original if NaN or the magnitude is large (already an int). return IfThenElse(Abs(v) < max, rounded, v); } template HWY_API Vec128 Ceil(const Vec128 v) { const DFromV df; const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 const auto int_f = ConvertTo(df, integer); // Truncating a positive non-integer ends up smaller; if so, add 1. const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); return IfThenElse(detail::UseInt(v), int_f - neg1, v); } template HWY_API Vec128 Floor(const Vec128 v) { const DFromV df; const RebindToSigned di; const auto integer = ConvertTo(di, v); // round toward 0 const auto int_f = ConvertTo(df, integer); // Truncating a negative non-integer ends up larger; if so, subtract 1. const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); return IfThenElse(detail::UseInt(v), int_f + neg1, v); } #endif // ------------------------------ NearestInt (Round) #if HWY_ARCH_ARM_A64 HWY_API Vec128 NearestInt(const Vec128 v) { return Vec128(vcvtnq_s32_f32(v.raw)); } template HWY_API Vec128 NearestInt(const Vec128 v) { return Vec128(vcvtn_s32_f32(v.raw)); } #else template HWY_API Vec128 NearestInt(const Vec128 v) { const RebindToSigned> di; return ConvertTo(di, Round(v)); } #endif // ------------------------------ Floating-point classification template HWY_API Mask128 IsNaN(const Vec128 v) { return v != v; } template HWY_API Mask128 IsInf(const Vec128 v) { const Simd d; const RebindToSigned di; const VFromD vi = BitCast(di, v); // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); } // Returns whether normal/subnormal/zero. template HWY_API Mask128 IsFinite(const Vec128 v) { const Simd d; const RebindToUnsigned du; const RebindToSigned di; // cheaper than unsigned comparison const VFromD vu = BitCast(du, v); // 'Shift left' to clear the sign bit, then right so we can compare with the // max exponent (cannot compare with MaxExponentTimes2 directly because it is // negative and non-negative floats would be greater). 
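// For f32, for example: Add(vu, vu) shifts the bits left by one, discarding
// the sign, so the exponent field lands in bits [31:24]; the logical right
// shift by MantissaBits<T>() + 1 (= 24) then leaves only the biased exponent,
// which is finite iff it is strictly less than MaxExponentField().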
const VFromD exp = BitCast(di, ShiftRight() + 1>(Add(vu, vu))); return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); } // ================================================== SWIZZLE // ------------------------------ LowerHalf // <= 64 bit: just return different type template HWY_API Vec128 LowerHalf(const Vec128 v) { return Vec128(v.raw); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_u8(v.raw)); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_u16(v.raw)); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_u32(v.raw)); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_u64(v.raw)); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_s8(v.raw)); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_s16(v.raw)); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_s32(v.raw)); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_s64(v.raw)); } HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_f32(v.raw)); } #if HWY_ARCH_ARM_A64 HWY_API Vec64 LowerHalf(const Vec128 v) { return Vec64(vget_low_f64(v.raw)); } #endif HWY_API Vec64 LowerHalf(const Vec128 v) { const Full128 du; const Full64 dbh; return BitCast(dbh, LowerHalf(BitCast(du, v))); } template HWY_API Vec128 LowerHalf(Simd /* tag */, Vec128 v) { return LowerHalf(v); } // ------------------------------ CombineShiftRightBytes // 128-bit template > HWY_API V128 CombineShiftRightBytes(Full128 d, V128 hi, V128 lo) { static_assert(0 < kBytes && kBytes < 16, "kBytes must be in [1, 15]"); const Repartition d8; uint8x16_t v8 = vextq_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); return BitCast(d, Vec128(v8)); } // 64-bit template HWY_API Vec64 CombineShiftRightBytes(Full64 d, Vec64 hi, Vec64 lo) { static_assert(0 < kBytes && kBytes < 8, "kBytes must be in [1, 7]"); const Repartition d8; uint8x8_t v8 = vext_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); return BitCast(d, VFromD(v8)); } // <= 32-bit defined after ShiftLeftBytes. // ------------------------------ Shift vector by constant #bytes namespace detail { // Partially specialize because kBytes = 0 and >= size are compile errors; // callers replace the latter with 0xFF for easier specialization. template struct ShiftLeftBytesT { // Full template HWY_INLINE Vec128 operator()(const Vec128 v) { const Full128 d; return CombineShiftRightBytes<16 - kBytes>(d, v, Zero(d)); } // Partial template HWY_INLINE Vec128 operator()(const Vec128 v) { // Expand to 64-bit so we only use the native EXT instruction. const Full64 d64; const auto zero64 = Zero(d64); const decltype(zero64) v64(v.raw); return Vec128( CombineShiftRightBytes<8 - kBytes>(d64, v64, zero64).raw); } }; template <> struct ShiftLeftBytesT<0> { template HWY_INLINE Vec128 operator()(const Vec128 v) { return v; } }; template <> struct ShiftLeftBytesT<0xFF> { template HWY_INLINE Vec128 operator()(const Vec128 /* v */) { return Zero(Simd()); } }; template struct ShiftRightBytesT { template HWY_INLINE Vec128 operator()(Vec128 v) { const Simd d; // For < 64-bit vectors, zero undefined lanes so we shift in zeros. if (N * sizeof(T) < 8) { constexpr size_t kReg = N * sizeof(T) == 16 ? 
16 : 8; const Simd dreg; v = Vec128( IfThenElseZero(FirstN(dreg, N), VFromD(v.raw)).raw); } return CombineShiftRightBytes(d, Zero(d), v); } }; template <> struct ShiftRightBytesT<0> { template HWY_INLINE Vec128 operator()(const Vec128 v) { return v; } }; template <> struct ShiftRightBytesT<0xFF> { template HWY_INLINE Vec128 operator()(const Vec128 /* v */) { return Zero(Simd()); } }; } // namespace detail template HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { return detail::ShiftLeftBytesT < kBytes >= N * sizeof(T) ? 0xFF : kBytes > ()(v); } template HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { return ShiftLeftBytes(Simd(), v); } template HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { return ShiftLeftLanes(Simd(), v); } // 0x01..0F, kBytes = 1 => 0x0001..0E template HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { return detail::ShiftRightBytesT < kBytes >= N * sizeof(T) ? 0xFF : kBytes > ()(v); } template HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { const Repartition d8; return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // Calls ShiftLeftBytes template HWY_API Vec128 CombineShiftRightBytes(Simd d, Vec128 hi, Vec128 lo) { constexpr size_t kSize = N * sizeof(T); static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); const Repartition d8; const Full64 d_full8; const Repartition d_full; using V64 = VFromD; const V64 hi64(BitCast(d8, hi).raw); // Move into most-significant bytes const V64 lo64 = ShiftLeftBytes<8 - kSize>(V64(BitCast(d8, lo).raw)); const V64 r = CombineShiftRightBytes<8 - kSize + kBytes>(d_full8, hi64, lo64); // After casting to full 64-bit vector of correct type, shrink to 32-bit return Vec128(BitCast(d_full, r).raw); } // ------------------------------ UpperHalf (ShiftRightBytes) // Full input HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_u8(v.raw)); } HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_u16(v.raw)); } HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_u32(v.raw)); } HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_u64(v.raw)); } HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_s8(v.raw)); } HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_s16(v.raw)); } HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_s32(v.raw)); } HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_s64(v.raw)); } HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_f32(v.raw)); } #if HWY_ARCH_ARM_A64 HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { return Vec64(vget_high_f64(v.raw)); } #endif HWY_API Vec64 UpperHalf(Full64 dbh, const Vec128 v) { const RebindToUnsigned duh; const Twice du; return BitCast(dbh, UpperHalf(duh, BitCast(du, v))); } // Partial template HWY_API Vec128 UpperHalf(Half> /* tag */, Vec128 v) { const DFromV d; const RebindToUnsigned du; const auto vu = BitCast(du, v); const auto upper = BitCast(d, ShiftRightBytes(du, vu)); return Vec128(upper.raw); } // ------------------------------ Broadcast/splat any lane #if HWY_ARCH_ARM_A64 // Unsigned template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); return 
Vec128(vdupq_laneq_u16(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_u16(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec128(vdupq_laneq_u32(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_u32(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_laneq_u64(v.raw, kLane)); } // Vec64 is defined below. // Signed template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); return Vec128(vdupq_laneq_s16(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_s16(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec128(vdupq_laneq_s32(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_s32(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_laneq_s64(v.raw, kLane)); } // Vec64 is defined below. // Float template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec128(vdupq_laneq_f32(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_f32(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_laneq_f64(v.raw, kLane)); } template HWY_API Vec64 Broadcast(const Vec64 v) { static_assert(0 <= kLane && kLane < 1, "Invalid lane"); return v; } #else // No vdupq_laneq_* on armv7: use vgetq_lane_* + vdupq_n_*. // Unsigned template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); return Vec128(vdupq_n_u16(vgetq_lane_u16(v.raw, kLane))); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_u16(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec128(vdupq_n_u32(vgetq_lane_u32(v.raw, kLane))); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_u32(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_n_u64(vgetq_lane_u64(v.raw, kLane))); } // Vec64 is defined below. 
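// The vgetq_lane_* + vdupq_n_* sequence above is the ARMv7 substitute for the
// missing vdupq_laneq_* intrinsics; compilers typically lower it to a single
// lane-indexed VDUP, though that is a codegen expectation rather than a
// guarantee.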
// Signed template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); return Vec128(vdupq_n_s16(vgetq_lane_s16(v.raw, kLane))); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_s16(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec128(vdupq_n_s32(vgetq_lane_s32(v.raw, kLane))); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_s32(v.raw, kLane)); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec128(vdupq_n_s64(vgetq_lane_s64(v.raw, kLane))); } // Vec64 is defined below. // Float template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec128(vdupq_n_f32(vgetq_lane_f32(v.raw, kLane))); } template HWY_API Vec128 Broadcast(const Vec128 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec128(vdup_lane_f32(v.raw, kLane)); } #endif template HWY_API Vec64 Broadcast(const Vec64 v) { static_assert(0 <= kLane && kLane < 1, "Invalid lane"); return v; } template HWY_API Vec64 Broadcast(const Vec64 v) { static_assert(0 <= kLane && kLane < 1, "Invalid lane"); return v; } // ------------------------------ TableLookupLanes // Returned by SetTableIndices for use by TableLookupLanes. template struct Indices128 { typename detail::Raw128::type raw; }; template HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD const Rebind di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, static_cast(N))))); #endif const Repartition d8; using V8 = VFromD; const Repartition d16; // Broadcast each lane index to all bytes of T and shift to bytes static_assert(sizeof(T) == 4 || sizeof(T) == 8, ""); if (sizeof(T) == 4) { alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; const V8 lane_indices = TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); const V8 byte_indices = BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; const V8 sum = Add(byte_indices, Load(d8, kByteOffsets)); return Indices128{BitCast(d, sum).raw}; } else { alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; const V8 lane_indices = TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); const V8 byte_indices = BitCast(d8, ShiftLeft<3>(BitCast(d16, lane_indices))); alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; const V8 sum = Add(byte_indices, Load(d8, kByteOffsets)); return Indices128{BitCast(d, sum).raw}; } } template HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } template HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { const DFromV d; const RebindToSigned di; return BitCast( d, TableLookupBytes(BitCast(di, v), BitCast(di, Vec128{idx.raw}))); } // ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) // Single lane: no change template HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { return 
v; } // Two lanes: shuffle template HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { return Vec128(Shuffle2301(v)); } template HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { return Shuffle01(v); } // Four lanes: shuffle template HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { return Shuffle0123(v); } // 16-bit template HWY_API Vec128 Reverse(Simd d, const Vec128 v) { const RepartitionToWide> du32; return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); } // ------------------------------ Reverse2 template HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { const RebindToUnsigned du; return BitCast(d, Vec128(vrev32_u16(BitCast(du, v).raw))); } template HWY_API Vec128 Reverse2(Full128 d, const Vec128 v) { const RebindToUnsigned du; return BitCast(d, Vec128(vrev32q_u16(BitCast(du, v).raw))); } template HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { const RebindToUnsigned du; return BitCast(d, Vec128(vrev64_u32(BitCast(du, v).raw))); } template HWY_API Vec128 Reverse2(Full128 d, const Vec128 v) { const RebindToUnsigned du; return BitCast(d, Vec128(vrev64q_u32(BitCast(du, v).raw))); } template HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { return Shuffle01(v); } // ------------------------------ Reverse4 template HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { const RebindToUnsigned du; return BitCast(d, Vec128(vrev64_u16(BitCast(du, v).raw))); } template HWY_API Vec128 Reverse4(Full128 d, const Vec128 v) { const RebindToUnsigned du; return BitCast(d, Vec128(vrev64q_u16(BitCast(du, v).raw))); } template HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { return Shuffle0123(v); } template HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128) { HWY_ASSERT(0); // don't have 8 u64 lanes } // ------------------------------ Reverse8 template HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { return Reverse(d, v); } template HWY_API Vec128 Reverse8(Simd, const Vec128) { HWY_ASSERT(0); // don't have 8 lanes unless 16-bit } // ------------------------------ Other shuffles (TableLookupBytes) // Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). // Shuffle0321 rotates one lane to the right (the previous least-significant // lane is now most-significant). These could also be implemented via // CombineShiftRightBytes but the shuffle_abcd notation is more convenient. // Swap 64-bit halves template HWY_API Vec128 Shuffle1032(const Vec128 v) { return CombineShiftRightBytes<8>(Full128(), v, v); } template HWY_API Vec128 Shuffle01(const Vec128 v) { return CombineShiftRightBytes<8>(Full128(), v, v); } // Rotate right 32 bits template HWY_API Vec128 Shuffle0321(const Vec128 v) { return CombineShiftRightBytes<4>(Full128(), v, v); } // Rotate left 32 bits template HWY_API Vec128 Shuffle2103(const Vec128 v) { return CombineShiftRightBytes<12>(Full128(), v, v); } // Reverse template HWY_API Vec128 Shuffle0123(const Vec128 v) { return Shuffle2301(Shuffle1032(v)); } // ------------------------------ InterleaveLower // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides // the least-significant lane) and "b". To concatenate two half-width integers // into one, use ZipLower/Upper instead (also works with scalar). HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveLower, vzip1, _, 2) HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveLower, vzip1, _, 2) #if HWY_ARCH_ARM_A64 // N=1 makes no sense (in that case, there would be no upper/lower). 
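// For 64-bit lanes, only the full 128-bit vectors are interleaved:
// InterleaveLower([a1 a0], [b1 b0]) = [b0 a0], writing lane 0 rightmost.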
HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { return Vec128(vzip1q_u64(a.raw, b.raw)); } HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { return Vec128(vzip1q_s64(a.raw, b.raw)); } HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { return Vec128(vzip1q_f64(a.raw, b.raw)); } #else // ARMv7 emulation. HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { return CombineShiftRightBytes<8>(Full128(), b, Shuffle01(a)); } HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { return CombineShiftRightBytes<8>(Full128(), b, Shuffle01(a)); } #endif // Floats HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { return Vec128(vzip1q_f32(a.raw, b.raw)); } template HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { return Vec128(vzip1_f32(a.raw, b.raw)); } // < 64 bit parts template HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { return Vec128(InterleaveLower(Vec64(a.raw), Vec64(b.raw)).raw); } // Additional overload for the optional Simd<> tag. template > HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { return InterleaveLower(a, b); } // ------------------------------ InterleaveUpper (UpperHalf) // All functions inside detail lack the required D parameter. namespace detail { HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveUpper, vzip2, _, 2) HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveUpper, vzip2, _, 2) #if HWY_ARCH_ARM_A64 // N=1 makes no sense (in that case, there would be no upper/lower). HWY_API Vec128 InterleaveUpper(const Vec128 a, const Vec128 b) { return Vec128(vzip2q_u64(a.raw, b.raw)); } HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { return Vec128(vzip2q_s64(a.raw, b.raw)); } HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { return Vec128(vzip2q_f64(a.raw, b.raw)); } #else // ARMv7 emulation. HWY_API Vec128 InterleaveUpper(const Vec128 a, const Vec128 b) { return CombineShiftRightBytes<8>(Full128(), Shuffle01(b), a); } HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { return CombineShiftRightBytes<8>(Full128(), Shuffle01(b), a); } #endif HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { return Vec128(vzip2q_f32(a.raw, b.raw)); } HWY_API Vec64 InterleaveUpper(const Vec64 a, const Vec64 b) { return Vec64(vzip2_f32(a.raw, b.raw)); } } // namespace detail // Full register template > HWY_API V InterleaveUpper(Simd /* tag */, V a, V b) { return detail::InterleaveUpper(a, b); } // Partial template > HWY_API V InterleaveUpper(Simd d, V a, V b) { const Half d2; return InterleaveLower(d, V(UpperHalf(d2, a).raw), V(UpperHalf(d2, b).raw)); } // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. 
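// Illustrative example (little-endian, the usual case for NEON): for u8 inputs
// a = [... a1 a0] and b = [... b1 b0], ZipLower yields u16 lanes whose lane 0
// is (b0 << 8) | a0, i.e. the interleaved bytes reinterpreted as the wider type.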
template >> HWY_API VFromD ZipLower(V a, V b) { return BitCast(DW(), InterleaveLower(a, b)); } template , class DW = RepartitionToWide> HWY_API VFromD ZipLower(DW dw, V a, V b) { return BitCast(dw, InterleaveLower(D(), a, b)); } template , class DW = RepartitionToWide> HWY_API VFromD ZipUpper(DW dw, V a, V b) { return BitCast(dw, InterleaveUpper(D(), a, b)); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) template HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, Vec128 a, Vec128 b, const Vec128 sum0, Vec128& sum1) { const Rebind du32; using VU32 = VFromD; const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 // Avoid ZipLower/Upper so this also works on big-endian systems. const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); const VU32 ao = And(BitCast(du32, a), odd); const VU32 be = ShiftLeft<16>(BitCast(du32, b)); const VU32 bo = And(BitCast(du32, b), odd); sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); } HWY_API Vec128 ReorderWidenMulAccumulate(Full128 /*d32*/, Vec128 a, Vec128 b, const Vec128 sum0, Vec128& sum1) { #if HWY_ARCH_ARM_A64 sum1 = Vec128(vmlal_high_s16(sum1.raw, a.raw, b.raw)); #else const Full64 dh; sum1 = Vec128( vmlal_s16(sum1.raw, UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); #endif return Vec128( vmlal_s16(sum0.raw, LowerHalf(a).raw, LowerHalf(b).raw)); } HWY_API Vec64 ReorderWidenMulAccumulate(Full64 d32, Vec64 a, Vec64 b, const Vec64 sum0, Vec64& sum1) { // vmlal writes into the upper half, which the caller cannot use, so // split into two halves. const Vec128 mul_3210(vmull_s16(a.raw, b.raw)); const Vec64 mul_32 = UpperHalf(d32, mul_3210); sum1 += mul_32; return sum0 + LowerHalf(mul_3210); } HWY_API Vec32 ReorderWidenMulAccumulate(Full32 d32, Vec32 a, Vec32 b, const Vec32 sum0, Vec32& sum1) { const Vec128 mul_xx10(vmull_s16(a.raw, b.raw)); const Vec64 mul_10(LowerHalf(mul_xx10)); const Vec32 mul0 = LowerHalf(d32, mul_10); const Vec32 mul1 = UpperHalf(d32, mul_10); sum1 += mul1; return sum0 + mul0; } // ================================================== COMBINE // ------------------------------ Combine (InterleaveLower) // Full result HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_u8(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_u16(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_u32(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_u64(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_s8(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_s16(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_s32(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_s64(lo.raw, hi.raw)); } HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_f32(lo.raw, hi.raw)); } #if HWY_ARCH_ARM_A64 HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, Vec64 lo) { return Vec128(vcombine_f64(lo.raw, hi.raw)); } #endif // < 64bit input, <= 64 bit result template HWY_API Vec128 Combine(Simd d, Vec128 hi, Vec128 lo) { // First double N (only lower halves will be used). 
const Vec128 hi2(hi.raw); const Vec128 lo2(lo.raw); // Repartition to two unsigned lanes (each the size of the valid input). const Simd, 2, 0> du; return BitCast(d, InterleaveLower(BitCast(du, lo2), BitCast(du, hi2))); } // ------------------------------ RearrangeToOddPlusEven (Combine) template HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, const Vec128 sum1) { return Add(sum0, sum1); } HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, const Vec128 sum1) { // vmlal_s16 multiplied the lower half into sum0 and upper into sum1. #if HWY_ARCH_ARM_A64 // pairwise sum is available and what we want return Vec128(vpaddq_s32(sum0.raw, sum1.raw)); #else const Full128 d; const Half d64; const Vec64 hi( vpadd_s32(LowerHalf(d64, sum1).raw, UpperHalf(d64, sum1).raw)); const Vec64 lo( vpadd_s32(LowerHalf(d64, sum0).raw, UpperHalf(d64, sum0).raw)); return Combine(Full128(), hi, lo); #endif } HWY_API Vec64 RearrangeToOddPlusEven(const Vec64 sum0, const Vec64 sum1) { // vmlal_s16 multiplied the lower half into sum0 and upper into sum1. return Vec64(vpadd_s32(sum0.raw, sum1.raw)); } HWY_API Vec32 RearrangeToOddPlusEven(const Vec32 sum0, const Vec32 sum1) { // Only one widened sum per register, so add them for sum of odd and even. return sum0 + sum1; } // ------------------------------ ZeroExtendVector (Combine) template HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { return Combine(d, Zero(Half()), lo); } // ------------------------------ ConcatLowerLower // 64 or 128-bit input: just interleave template HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, Vec128 lo) { // Treat half-width input as a single lane and interleave them. const Repartition, decltype(d)> du; return BitCast(d, InterleaveLower(BitCast(du, lo), BitCast(du, hi))); } namespace detail { #if HWY_ARCH_ARM_A64 HWY_NEON_DEF_FUNCTION_UIF81632(InterleaveEven, vtrn1, _, 2) HWY_NEON_DEF_FUNCTION_UIF81632(InterleaveOdd, vtrn2, _, 2) #else // vtrn returns a struct with even and odd result. #define HWY_NEON_BUILD_TPL_HWY_TRN #define HWY_NEON_BUILD_RET_HWY_TRN(type, size) type##x##size##x2_t // Pass raw args so we can accept uint16x2 args, for which there is no // corresponding uint16x2x2 return type. #define HWY_NEON_BUILD_PARAM_HWY_TRN(TYPE, size) \ Raw128::type a, Raw128::type b #define HWY_NEON_BUILD_ARG_HWY_TRN a, b // Cannot use UINT8 etc. type macros because the x2_t tuples are only defined // for full and half vectors. 
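// The returned x2_t tuple holds both transposed halves: .val[0] contains the
// even-indexed lanes (equivalent to AArch64 vtrn1 / InterleaveEven above) and
// .val[1] the odd-indexed lanes (vtrn2 / InterleaveOdd).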
HWY_NEON_DEF_FUNCTION(uint8, 16, InterleaveEvenOdd, vtrnq, _, u8, HWY_TRN) HWY_NEON_DEF_FUNCTION(uint8, 8, InterleaveEvenOdd, vtrn, _, u8, HWY_TRN) HWY_NEON_DEF_FUNCTION(uint16, 8, InterleaveEvenOdd, vtrnq, _, u16, HWY_TRN) HWY_NEON_DEF_FUNCTION(uint16, 4, InterleaveEvenOdd, vtrn, _, u16, HWY_TRN) HWY_NEON_DEF_FUNCTION(uint32, 4, InterleaveEvenOdd, vtrnq, _, u32, HWY_TRN) HWY_NEON_DEF_FUNCTION(uint32, 2, InterleaveEvenOdd, vtrn, _, u32, HWY_TRN) HWY_NEON_DEF_FUNCTION(int8, 16, InterleaveEvenOdd, vtrnq, _, s8, HWY_TRN) HWY_NEON_DEF_FUNCTION(int8, 8, InterleaveEvenOdd, vtrn, _, s8, HWY_TRN) HWY_NEON_DEF_FUNCTION(int16, 8, InterleaveEvenOdd, vtrnq, _, s16, HWY_TRN) HWY_NEON_DEF_FUNCTION(int16, 4, InterleaveEvenOdd, vtrn, _, s16, HWY_TRN) HWY_NEON_DEF_FUNCTION(int32, 4, InterleaveEvenOdd, vtrnq, _, s32, HWY_TRN) HWY_NEON_DEF_FUNCTION(int32, 2, InterleaveEvenOdd, vtrn, _, s32, HWY_TRN) HWY_NEON_DEF_FUNCTION(float32, 4, InterleaveEvenOdd, vtrnq, _, f32, HWY_TRN) HWY_NEON_DEF_FUNCTION(float32, 2, InterleaveEvenOdd, vtrn, _, f32, HWY_TRN) #endif } // namespace detail // <= 32-bit input/output template HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, Vec128 lo) { // Treat half-width input as two lanes and take every second one. const Repartition, decltype(d)> du; #if HWY_ARCH_ARM_A64 return BitCast(d, detail::InterleaveEven(BitCast(du, lo), BitCast(du, hi))); #else using VU = VFromD; return BitCast( d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) .val[0])); #endif } // ------------------------------ ConcatUpperUpper // 64 or 128-bit input: just interleave template HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, Vec128 lo) { // Treat half-width input as a single lane and interleave them. const Repartition, decltype(d)> du; return BitCast(d, InterleaveUpper(du, BitCast(du, lo), BitCast(du, hi))); } // <= 32-bit input/output template HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, Vec128 lo) { // Treat half-width input as two lanes and take every second one. const Repartition, decltype(d)> du; #if HWY_ARCH_ARM_A64 return BitCast(d, detail::InterleaveOdd(BitCast(du, lo), BitCast(du, hi))); #else using VU = VFromD; return BitCast( d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) .val[1])); #endif } // ------------------------------ ConcatLowerUpper (ShiftLeftBytes) // 64 or 128-bit input: extract from concatenated template HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, Vec128 lo) { return CombineShiftRightBytes(d, hi, lo); } // <= 32-bit input/output template HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, Vec128 lo) { constexpr size_t kSize = N * sizeof(T); const Repartition d8; const Full64 d8x8; const Full64 d64; using V8x8 = VFromD; const V8x8 hi8x8(BitCast(d8, hi).raw); // Move into most-significant bytes const V8x8 lo8x8 = ShiftLeftBytes<8 - kSize>(V8x8(BitCast(d8, lo).raw)); const V8x8 r = CombineShiftRightBytes<8 - kSize / 2>(d8x8, hi8x8, lo8x8); // Back to original lane type, then shrink N. return Vec128(BitCast(d64, r).raw); } // ------------------------------ ConcatUpperLower // Works for all N. template HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, Vec128 lo) { return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); } // ------------------------------ ConcatOdd (InterleaveUpper) namespace detail { // There is no vuzpq_u64. 
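// vuzp1/vuzp2 deinterleave two vectors: vuzp1 gathers the even-indexed lanes
// (first operand's evens in the lower half, second's in the upper half) and
// vuzp2 the odd-indexed lanes, which is exactly ConcatEven/ConcatOdd for one
// 128-bit block; 64-bit lanes instead fall back to InterleaveLower/Upper below.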
HWY_NEON_DEF_FUNCTION_UIF81632(ConcatEven, vuzp1, _, 2) HWY_NEON_DEF_FUNCTION_UIF81632(ConcatOdd, vuzp2, _, 2) } // namespace detail // Full/half vector template = 8>* = nullptr> HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, Vec128 lo) { return detail::ConcatOdd(lo, hi); } // 8-bit x4 template HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, Vec128 lo) { const Twice d2; const Repartition dw2; const VFromD hi2(hi.raw); const VFromD lo2(lo.raw); const VFromD Hx1Lx1 = BitCast(dw2, ConcatOdd(d2, hi2, lo2)); // Compact into two pairs of u8, skipping the invalid x lanes. Could also use // vcopy_lane_u16, but that's A64-only. return Vec128(BitCast(d2, ConcatEven(dw2, Hx1Lx1, Hx1Lx1)).raw); } // Any type x2 template HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, Vec128 lo) { return InterleaveUpper(d, lo, hi); } // ------------------------------ ConcatEven (InterleaveLower) // Full/half vector template = 8>* = nullptr> HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, Vec128 lo) { return detail::ConcatEven(lo, hi); } // 8-bit x4 template HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, Vec128 lo) { const Twice d2; const Repartition dw2; const VFromD hi2(hi.raw); const VFromD lo2(lo.raw); const VFromD Hx0Lx0 = BitCast(dw2, ConcatEven(d2, hi2, lo2)); // Compact into two pairs of u8, skipping the invalid x lanes. Could also use // vcopy_lane_u16, but that's A64-only. return Vec128(BitCast(d2, ConcatEven(dw2, Hx0Lx0, Hx0Lx0)).raw); } // Any type x2 template HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, Vec128 lo) { return InterleaveLower(d, lo, hi); } // ------------------------------ DupEven (InterleaveLower) template HWY_API Vec128 DupEven(Vec128 v) { #if HWY_ARCH_ARM_A64 return detail::InterleaveEven(v, v); #else return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[0]); #endif } template HWY_API Vec128 DupEven(const Vec128 v) { return InterleaveLower(Simd(), v, v); } // ------------------------------ DupOdd (InterleaveUpper) template HWY_API Vec128 DupOdd(Vec128 v) { #if HWY_ARCH_ARM_A64 return detail::InterleaveOdd(v, v); #else return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[1]); #endif } template HWY_API Vec128 DupOdd(const Vec128 v) { return InterleaveUpper(Simd(), v, v); } // ------------------------------ OddEven (IfThenElse) template HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { const Simd d; const Repartition d8; alignas(16) constexpr uint8_t kBytes[16] = { ((0 / sizeof(T)) & 1) ? 0 : 0xFF, ((1 / sizeof(T)) & 1) ? 0 : 0xFF, ((2 / sizeof(T)) & 1) ? 0 : 0xFF, ((3 / sizeof(T)) & 1) ? 0 : 0xFF, ((4 / sizeof(T)) & 1) ? 0 : 0xFF, ((5 / sizeof(T)) & 1) ? 0 : 0xFF, ((6 / sizeof(T)) & 1) ? 0 : 0xFF, ((7 / sizeof(T)) & 1) ? 0 : 0xFF, ((8 / sizeof(T)) & 1) ? 0 : 0xFF, ((9 / sizeof(T)) & 1) ? 0 : 0xFF, ((10 / sizeof(T)) & 1) ? 0 : 0xFF, ((11 / sizeof(T)) & 1) ? 0 : 0xFF, ((12 / sizeof(T)) & 1) ? 0 : 0xFF, ((13 / sizeof(T)) & 1) ? 0 : 0xFF, ((14 / sizeof(T)) & 1) ? 0 : 0xFF, ((15 / sizeof(T)) & 1) ? 
0 : 0xFF, }; const auto vec = BitCast(d, Load(d8, kBytes)); return IfThenElse(MaskFromVec(vec), b, a); } // ------------------------------ OddEvenBlocks template HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { return even; } // ------------------------------ SwapAdjacentBlocks template HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { return v; } // ------------------------------ ReverseBlocks // Single block: no change template HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { return v; } // ------------------------------ ReorderDemote2To (OddEven) template HWY_API Vec128 ReorderDemote2To( Simd dbf16, Vec128 a, Vec128 b) { const RebindToUnsigned du16; const Repartition du32; const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); } HWY_API Vec128 ReorderDemote2To(Full128 d16, Vec128 a, Vec128 b) { const Vec64 a16(vqmovn_s32(a.raw)); #if HWY_ARCH_ARM_A64 (void)d16; return Vec128(vqmovn_high_s32(a16.raw, b.raw)); #else const Vec64 b16(vqmovn_s32(b.raw)); return Combine(d16, a16, b16); #endif } HWY_API Vec64 ReorderDemote2To(Full64 /*d16*/, Vec64 a, Vec64 b) { const Full128 d32; const Vec128 ab = Combine(d32, a, b); return Vec64(vqmovn_s32(ab.raw)); } HWY_API Vec32 ReorderDemote2To(Full32 /*d16*/, Vec32 a, Vec32 b) { const Full128 d32; const Vec64 ab(vzip1_s32(a.raw, b.raw)); return Vec32(vqmovn_s32(Combine(d32, ab, ab).raw)); } // ================================================== CRYPTO #if defined(__ARM_FEATURE_AES) || \ (HWY_HAVE_RUNTIME_DISPATCH && HWY_ARCH_ARM_A64) // Per-target flag to prevent generic_ops-inl.h from defining AESRound. #ifdef HWY_NATIVE_AES #undef HWY_NATIVE_AES #else #define HWY_NATIVE_AES #endif HWY_API Vec128 AESRound(Vec128 state, Vec128 round_key) { // NOTE: it is important that AESE and AESMC be consecutive instructions so // they can be fused. AESE includes AddRoundKey, which is a different ordering // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual // round key (the compiler will hopefully optimize this for multiple rounds). 
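  // Sketch of the equivalence (informational): AES-NI's AESENC computes
  // MixColumns(ShiftRows(SubBytes(state))) ^ round_key, whereas ARM's AESE
  // computes SubBytes(ShiftRows(state ^ key)). Passing key = 0 to AESE and
  // then applying AESMC (MixColumns) therefore yields the same value as
  // AESENC once the actual round_key is XORed afterwards, as done below.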
return Vec128(vaesmcq_u8(vaeseq_u8(state.raw, vdupq_n_u8(0)))) ^ round_key; } HWY_API Vec128 AESLastRound(Vec128 state, Vec128 round_key) { return Vec128(vaeseq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; } HWY_API Vec128 CLMulLower(Vec128 a, Vec128 b) { return Vec128((uint64x2_t)vmull_p64(GetLane(a), GetLane(b))); } HWY_API Vec128 CLMulUpper(Vec128 a, Vec128 b) { return Vec128( (uint64x2_t)vmull_high_p64((poly64x2_t)a.raw, (poly64x2_t)b.raw)); } #endif // __ARM_FEATURE_AES // ================================================== MISC template HWY_API Vec128 PromoteTo(Simd df32, const Vec128 v) { const Rebind du16; const RebindToSigned di32; return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); } // ------------------------------ Truncations template * = nullptr> HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { const Repartition> d; const auto v1 = BitCast(d, v); return Vec128{v1.raw}; } HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { const Repartition> d; const auto v1 = BitCast(d, v); const auto v2 = detail::ConcatEven(v1, v1); const auto v3 = detail::ConcatEven(v2, v2); const auto v4 = detail::ConcatEven(v3, v3); return LowerHalf(LowerHalf(LowerHalf(v4))); } HWY_API Vec32 TruncateTo(Simd /* tag */, const Vec128 v) { const Repartition> d; const auto v1 = BitCast(d, v); const auto v2 = detail::ConcatEven(v1, v1); const auto v3 = detail::ConcatEven(v2, v2); return LowerHalf(LowerHalf(v3)); } HWY_API Vec64 TruncateTo(Simd /* tag */, const Vec128 v) { const Repartition> d; const auto v1 = BitCast(d, v); const auto v2 = detail::ConcatEven(v1, v1); return LowerHalf(v2); } template = 2>* = nullptr> HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { const Repartition> d; const auto v1 = BitCast(d, v); const auto v2 = detail::ConcatEven(v1, v1); const auto v3 = detail::ConcatEven(v2, v2); return LowerHalf(LowerHalf(v3)); } template = 2>* = nullptr> HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { const Repartition> d; const auto v1 = BitCast(d, v); const auto v2 = detail::ConcatEven(v1, v1); return LowerHalf(v2); } template = 2>* = nullptr> HWY_API Vec128 TruncateTo(Simd /* tag */, const Vec128 v) { const Repartition> d; const auto v1 = BitCast(d, v); const auto v2 = detail::ConcatEven(v1, v1); return LowerHalf(v2); } // ------------------------------ MulEven (ConcatEven) // Multiplies even lanes (0, 2 ..) and places the double-wide result into // even and the upper half into its odd neighbor lane. 
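// Example (illustrative): for 32-bit inputs a = {a0, a1, a2, a3} and
// b = {b0, b1, b2, b3}, MulEven returns the 64-bit products {a0*b0, a2*b2},
// each spanning what were an even lane and its odd neighbor.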
HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { const Full128 d; int32x4_t a_packed = ConcatEven(d, a, a).raw; int32x4_t b_packed = ConcatEven(d, b, b).raw; return Vec128( vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); } HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { const Full128 d; uint32x4_t a_packed = ConcatEven(d, a, a).raw; uint32x4_t b_packed = ConcatEven(d, b, b).raw; return Vec128( vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); } template HWY_API Vec128 MulEven(const Vec128 a, const Vec128 b) { const DFromV d; int32x2_t a_packed = ConcatEven(d, a, a).raw; int32x2_t b_packed = ConcatEven(d, b, b).raw; return Vec128( vget_low_s64(vmull_s32(a_packed, b_packed))); } template HWY_API Vec128 MulEven(const Vec128 a, const Vec128 b) { const DFromV d; uint32x2_t a_packed = ConcatEven(d, a, a).raw; uint32x2_t b_packed = ConcatEven(d, b, b).raw; return Vec128( vget_low_u64(vmull_u32(a_packed, b_packed))); } HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { uint64_t hi; uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 0), vgetq_lane_u64(b.raw, 0), &hi); return Vec128(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); } HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { uint64_t hi; uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 1), vgetq_lane_u64(b.raw, 1), &hi); return Vec128(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); } // ------------------------------ TableLookupBytes (Combine, LowerHalf) // Both full template HWY_API Vec128 TableLookupBytes(const Vec128 bytes, const Vec128 from) { const Full128 d; const Repartition d8; #if HWY_ARCH_ARM_A64 return BitCast(d, Vec128(vqtbl1q_u8(BitCast(d8, bytes).raw, BitCast(d8, from).raw))); #else uint8x16_t table0 = BitCast(d8, bytes).raw; uint8x8x2_t table; table.val[0] = vget_low_u8(table0); table.val[1] = vget_high_u8(table0); uint8x16_t idx = BitCast(d8, from).raw; uint8x8_t low = vtbl2_u8(table, vget_low_u8(idx)); uint8x8_t hi = vtbl2_u8(table, vget_high_u8(idx)); return BitCast(d, Vec128(vcombine_u8(low, hi))); #endif } // Partial index vector template HWY_API Vec128 TableLookupBytes(const Vec128 bytes, const Vec128 from) { const Full128 d_full; const Vec64 from64(from.raw); const auto idx_full = Combine(d_full, from64, from64); const auto out_full = TableLookupBytes(bytes, idx_full); return Vec128(LowerHalf(Half(), out_full).raw); } // Partial table vector template HWY_API Vec128 TableLookupBytes(const Vec128 bytes, const Vec128 from) { const Full128 d_full; return TableLookupBytes(Combine(d_full, bytes, bytes), from); } // Partial both template HWY_API VFromD>> TableLookupBytes( Vec128 bytes, Vec128 from) { const Simd d; const Simd d_idx; const Repartition d_idx8; // uint8x8 const auto bytes8 = BitCast(Repartition(), bytes); const auto from8 = BitCast(d_idx8, from); const VFromD v8(vtbl1_u8(bytes8.raw, from8.raw)); return BitCast(d_idx, v8); } // For all vector widths; ARM anyway zeroes if >= 0x10. 
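// In other words: x86 PSHUFB only zeroes an output byte when the index MSB is
// set, so generic code must handle out-of-range indices separately, but NEON
// TBL already returns 0 for any index outside the table (0x10 and above), so
// TableLookupBytesOr0 can simply forward to TableLookupBytes here.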
template HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { return TableLookupBytes(bytes, from); } // ------------------------------ Scatter (Store) template HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); alignas(16) T lanes[N]; Store(v, d, lanes); alignas(16) Offset offset_lanes[N]; Store(offset, Rebind(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); } } template HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); alignas(16) T lanes[N]; Store(v, d, lanes); alignas(16) Index index_lanes[N]; Store(index, Rebind(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; } } // ------------------------------ Gather (Load/Store) template HWY_API Vec128 GatherOffset(const Simd d, const T* HWY_RESTRICT base, const Vec128 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); alignas(16) Offset offset_lanes[N]; Store(offset, Rebind(), offset_lanes); alignas(16) T lanes[N]; const uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); } return Load(d, lanes); } template HWY_API Vec128 GatherIndex(const Simd d, const T* HWY_RESTRICT base, const Vec128 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); alignas(16) Index index_lanes[N]; Store(index, Rebind(), index_lanes); alignas(16) T lanes[N]; for (size_t i = 0; i < N; ++i) { lanes[i] = base[index_lanes[i]]; } return Load(d, lanes); } // ------------------------------ Reductions namespace detail { // N=1 for any T: no-op template HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, const Vec128 v) { return v; } template HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, const Vec128 v) { return v; } template HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, const Vec128 v) { return v; } // full vectors #if HWY_ARCH_ARM_A64 #define HWY_NEON_BUILD_RET_REDUCTION(type, size) Vec128 #define HWY_NEON_DEF_REDUCTION(type, size, name, prefix, infix, suffix, dup) \ HWY_API HWY_NEON_BUILD_RET_REDUCTION(type, size) \ name(hwy::SizeTag, const Vec128 v) { \ return HWY_NEON_BUILD_RET_REDUCTION( \ type, size)(dup##suffix(HWY_NEON_EVAL(prefix##infix##suffix, v.raw))); \ } #define HWY_NEON_DEF_REDUCTION_CORE_TYPES(name, prefix) \ HWY_NEON_DEF_REDUCTION(uint8, 8, name, prefix, _, u8, vdup_n_) \ HWY_NEON_DEF_REDUCTION(uint8, 16, name, prefix##q, _, u8, vdupq_n_) \ HWY_NEON_DEF_REDUCTION(uint16, 4, name, prefix, _, u16, vdup_n_) \ HWY_NEON_DEF_REDUCTION(uint16, 8, name, prefix##q, _, u16, vdupq_n_) \ HWY_NEON_DEF_REDUCTION(uint32, 2, name, prefix, _, u32, vdup_n_) \ HWY_NEON_DEF_REDUCTION(uint32, 4, name, prefix##q, _, u32, vdupq_n_) \ HWY_NEON_DEF_REDUCTION(int8, 8, name, prefix, _, s8, vdup_n_) \ HWY_NEON_DEF_REDUCTION(int8, 16, name, prefix##q, _, s8, vdupq_n_) \ HWY_NEON_DEF_REDUCTION(int16, 4, name, prefix, _, s16, vdup_n_) \ HWY_NEON_DEF_REDUCTION(int16, 8, name, prefix##q, _, s16, vdupq_n_) \ HWY_NEON_DEF_REDUCTION(int32, 2, name, prefix, _, s32, vdup_n_) \ HWY_NEON_DEF_REDUCTION(int32, 4, name, prefix##q, _, s32, vdupq_n_) \ HWY_NEON_DEF_REDUCTION(float32, 2, name, prefix, _, f32, vdup_n_) \ HWY_NEON_DEF_REDUCTION(float32, 4, name, prefix##q, _, f32, vdupq_n_) \ 
HWY_NEON_DEF_REDUCTION(float64, 2, name, prefix##q, _, f64, vdupq_n_) HWY_NEON_DEF_REDUCTION_CORE_TYPES(MinOfLanes, vminv) HWY_NEON_DEF_REDUCTION_CORE_TYPES(MaxOfLanes, vmaxv) // u64/s64 don't have horizontal min/max for some reason, but do have add. #define HWY_NEON_DEF_REDUCTION_ALL_TYPES(name, prefix) \ HWY_NEON_DEF_REDUCTION_CORE_TYPES(name, prefix) \ HWY_NEON_DEF_REDUCTION(uint64, 2, name, prefix##q, _, u64, vdupq_n_) \ HWY_NEON_DEF_REDUCTION(int64, 2, name, prefix##q, _, s64, vdupq_n_) HWY_NEON_DEF_REDUCTION_ALL_TYPES(SumOfLanes, vaddv) #undef HWY_NEON_DEF_REDUCTION_ALL_TYPES #undef HWY_NEON_DEF_REDUCTION_CORE_TYPES #undef HWY_NEON_DEF_REDUCTION #undef HWY_NEON_BUILD_RET_REDUCTION // Need some fallback implementations for [ui]64x2 and [ui]16x2. #define HWY_IF_SUM_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 1 << 2) #define HWY_IF_MINMAX_REDUCTION(T) \ HWY_IF_LANE_SIZE_ONE_OF(T, (1 << 8) | (1 << 2)) #else // u32/i32/f32: N=2 template HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, const Vec128 v10) { return v10 + Shuffle2301(v10); } template HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, const Vec128 v10) { return Min(v10, Shuffle2301(v10)); } template HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, const Vec128 v10) { return Max(v10, Shuffle2301(v10)); } // ARMv7 version for everything except doubles. HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, const Vec128 v) { uint32x4x2_t v0 = vuzpq_u32(v.raw, v.raw); uint32x4_t c0 = vaddq_u32(v0.val[0], v0.val[1]); uint32x4x2_t v1 = vuzpq_u32(c0, c0); return Vec128(vaddq_u32(v1.val[0], v1.val[1])); } HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, const Vec128 v) { int32x4x2_t v0 = vuzpq_s32(v.raw, v.raw); int32x4_t c0 = vaddq_s32(v0.val[0], v0.val[1]); int32x4x2_t v1 = vuzpq_s32(c0, c0); return Vec128(vaddq_s32(v1.val[0], v1.val[1])); } HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, const Vec128 v) { float32x4x2_t v0 = vuzpq_f32(v.raw, v.raw); float32x4_t c0 = vaddq_f32(v0.val[0], v0.val[1]); float32x4x2_t v1 = vuzpq_f32(c0, c0); return Vec128(vaddq_f32(v1.val[0], v1.val[1])); } HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, const Vec128 v) { return v + Shuffle01(v); } HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, const Vec128 v) { return v + Shuffle01(v); } template HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, const Vec128 v3210) { const Vec128 v1032 = Shuffle1032(v3210); const Vec128 v31_20_31_20 = Min(v3210, v1032); const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); return Min(v20_31_20_31, v31_20_31_20); } template HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, const Vec128 v3210) { const Vec128 v1032 = Shuffle1032(v3210); const Vec128 v31_20_31_20 = Max(v3210, v1032); const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); return Max(v20_31_20_31, v31_20_31_20); } #define HWY_NEON_BUILD_TYPE_T(type, size) type##x##size##_t #define HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) Vec128 #define HWY_NEON_DEF_PAIRWISE_REDUCTION(type, size, name, prefix, suffix) \ HWY_API HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) \ name(hwy::SizeTag, const Vec128 v) { \ HWY_NEON_BUILD_TYPE_T(type, size) tmp = prefix##_##suffix(v.raw, v.raw); \ if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ return HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION( \ type, size)(HWY_NEON_EVAL(vdup##_lane_##suffix, tmp, 0)); \ } #define HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(type, size, half, name, prefix, \ suffix) \ HWY_API 
HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) \ name(hwy::SizeTag, const Vec128 v) { \ HWY_NEON_BUILD_TYPE_T(type, half) tmp; \ tmp = prefix##_##suffix(vget_high_##suffix(v.raw), \ vget_low_##suffix(v.raw)); \ if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ if ((size / 8) > 1) tmp = prefix##_##suffix(tmp, tmp); \ tmp = vdup_lane_##suffix(tmp, 0); \ return HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION( \ type, size)(HWY_NEON_EVAL(vcombine_##suffix, tmp, tmp)); \ } #define HWY_NEON_DEF_PAIRWISE_REDUCTIONS(name, prefix) \ HWY_NEON_DEF_PAIRWISE_REDUCTION(uint16, 4, name, prefix, u16) \ HWY_NEON_DEF_PAIRWISE_REDUCTION(uint8, 8, name, prefix, u8) \ HWY_NEON_DEF_PAIRWISE_REDUCTION(int16, 4, name, prefix, s16) \ HWY_NEON_DEF_PAIRWISE_REDUCTION(int8, 8, name, prefix, s8) \ HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint16, 8, 4, name, prefix, u16) \ HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint8, 16, 8, name, prefix, u8) \ HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int16, 8, 4, name, prefix, s16) \ HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int8, 16, 8, name, prefix, s8) HWY_NEON_DEF_PAIRWISE_REDUCTIONS(SumOfLanes, vpadd) HWY_NEON_DEF_PAIRWISE_REDUCTIONS(MinOfLanes, vpmin) HWY_NEON_DEF_PAIRWISE_REDUCTIONS(MaxOfLanes, vpmax) #undef HWY_NEON_DEF_PAIRWISE_REDUCTIONS #undef HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION #undef HWY_NEON_DEF_PAIRWISE_REDUCTION #undef HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION #undef HWY_NEON_BUILD_TYPE_T template HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { const Simd d; const RepartitionToWide d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); } template HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { const Simd d; const RepartitionToWide d32; // Sign-extend const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); } template HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { const Simd d; const RepartitionToWide d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); } template HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { const Simd d; const RepartitionToWide d32; // Sign-extend const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); } template HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { const Simd d; const RepartitionToWide d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); // Also broadcast into odd lanes. 
return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); } template HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec128 v) { const Simd d; const RepartitionToWide d32; // Sign-extend const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); // Also broadcast into odd lanes. return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); } // Need fallback min/max implementations for [ui]64x2. #define HWY_IF_SUM_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 0) #define HWY_IF_MINMAX_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 1 << 8) #endif // [ui]16/[ui]64: N=2 -- special case for pairs of very small or large lanes template HWY_API Vec128 SumOfLanes(hwy::SizeTag /* tag */, const Vec128 v10) { return v10 + Reverse2(Simd(), v10); } template HWY_API Vec128 MinOfLanes(hwy::SizeTag /* tag */, const Vec128 v10) { return Min(v10, Reverse2(Simd(), v10)); } template HWY_API Vec128 MaxOfLanes(hwy::SizeTag /* tag */, const Vec128 v10) { return Max(v10, Reverse2(Simd(), v10)); } #undef HWY_IF_SUM_REDUCTION #undef HWY_IF_MINMAX_REDUCTION } // namespace detail template HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { return detail::SumOfLanes(hwy::SizeTag(), v); } template HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { return detail::MinOfLanes(hwy::SizeTag(), v); } template HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { return detail::MaxOfLanes(hwy::SizeTag(), v); } // ------------------------------ LoadMaskBits (TestBit) namespace detail { // Helper function to set 64 bits and potentially return a smaller vector. The // overload is required to call the q vs non-q intrinsics. Note that 8-bit // LoadMaskBits only requires 16 bits, but 64 avoids casting. template HWY_INLINE Vec128 Set64(Simd /* tag */, uint64_t mask_bits) { const auto v64 = Vec64(vdup_n_u64(mask_bits)); return Vec128(BitCast(Full64(), v64).raw); } template HWY_INLINE Vec128 Set64(Full128 d, uint64_t mask_bits) { return BitCast(d, Vec128(vdupq_n_u64(mask_bits))); } template HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; // Easier than Set(), which would require an >8-bit type, which would not // compile for T=uint8_t, N=1. const auto vmask_bits = Set64(du, mask_bits); // Replicate bytes 8x such that each byte contains the bit that governs it. 
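  // Example (illustrative): for mask_bits = 0x01A5, byte 0 (0xA5) is copied
  // into lanes 0..7 and byte 1 (0x01) into lanes 8..15 via kRep8 below; the
  // subsequent TestBit against {1, 2, 4, ..., 128} then activates exactly
  // lanes 0, 2, 5, 7 and 8, i.e. the set bits of mask_bits.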
alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}; const auto rep8 = TableLookupBytes(vmask_bits, Load(du, kRep8)); alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128}; return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); } template HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; const auto vmask_bits = Set(du, static_cast(mask_bits)); return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); } template HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; const auto vmask_bits = Set(du, static_cast(mask_bits)); return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); } template HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(16) constexpr uint64_t kBit[8] = {1, 2}; return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); } } // namespace detail // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API Mask128 LoadMaskBits(Simd d, const uint8_t* HWY_RESTRICT bits) { uint64_t mask_bits = 0; CopyBytes<(N + 7) / 8>(bits, &mask_bits); return detail::LoadMaskBits(d, mask_bits); } // ------------------------------ Mask namespace detail { // Returns mask[i]? 0xF : 0 in each nibble. This is more efficient than // BitsFromMask for use in (partial) CountTrue, FindFirstTrue and AllFalse. template HWY_INLINE uint64_t NibblesFromMask(const Full128 d, Mask128 mask) { const Full128 du16; const Vec128 vu16 = BitCast(du16, VecFromMask(d, mask)); const Vec64 nib(vshrn_n_u16(vu16.raw, 4)); return GetLane(BitCast(Full64(), nib)); } template HWY_INLINE uint64_t NibblesFromMask(const Full64 d, Mask64 mask) { // There is no vshrn_n_u16 for uint16x4, so zero-extend. const Twice d2; const Vec128 v128 = ZeroExtendVector(d2, VecFromMask(d, mask)); // No need to mask, upper half is zero thanks to ZeroExtendVector. return NibblesFromMask(d2, MaskFromVec(v128)); } template HWY_INLINE uint64_t NibblesFromMask(Simd /*d*/, Mask128 mask) { const Mask64 mask64(mask.raw); const uint64_t nib = NibblesFromMask(Full64(), mask64); // Clear nibbles from upper half of 64-bits constexpr size_t kBytes = sizeof(T) * N; return nib & ((1ull << (kBytes * 4)) - 1); } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { alignas(16) constexpr uint8_t kSliceLanes[16] = { 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, }; const Full128 du; const Vec128 values = BitCast(du, VecFromMask(Full128(), mask)) & Load(du, kSliceLanes); #if HWY_ARCH_ARM_A64 // Can't vaddv - we need two separate bytes (16 bits). const uint8x8_t x2 = vget_low_u8(vpaddq_u8(values.raw, values.raw)); const uint8x8_t x4 = vpadd_u8(x2, x2); const uint8x8_t x8 = vpadd_u8(x4, x4); return vget_lane_u64(vreinterpret_u64_u8(x8), 0); #else // Don't have vpaddq, so keep doubling lane size. const uint16x8_t x2 = vpaddlq_u8(values.raw); const uint32x4_t x4 = vpaddlq_u16(x2); const uint64x2_t x8 = vpaddlq_u32(x4); return (vgetq_lane_u64(x8, 1) << 8) | vgetq_lane_u64(x8, 0); #endif } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { // Upper lanes of partial loads are undefined. OnlyActive will fix this if // we load all kSliceLanes so the upper lanes do not pollute the valid bits. 
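  // Example (illustrative): a 4-lane mask {FF, 00, FF, FF} ANDed with the
  // kSliceLanes below gives {1, 0, 4, 8}; the horizontal add then produces
  // 0b1101, and OnlyActive later clears any bits at or above lane N.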
alignas(8) constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80}; const Simd d; const RebindToUnsigned du; const Vec128 slice(Load(Full64(), kSliceLanes).raw); const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; #if HWY_ARCH_ARM_A64 return vaddv_u8(values.raw); #else const uint16x4_t x2 = vpaddl_u8(values.raw); const uint32x2_t x4 = vpaddl_u16(x2); const uint64x1_t x8 = vpaddl_u32(x4); return vget_lane_u64(x8, 0); #endif } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, const Mask128 mask) { alignas(16) constexpr uint16_t kSliceLanes[8] = {1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80}; const Full128 d; const Full128 du; const Vec128 values = BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); #if HWY_ARCH_ARM_A64 return vaddvq_u16(values.raw); #else const uint32x4_t x2 = vpaddlq_u16(values.raw); const uint64x2_t x4 = vpaddlq_u32(x2); return vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1); #endif } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, const Mask128 mask) { // Upper lanes of partial loads are undefined. OnlyActive will fix this if // we load all kSliceLanes so the upper lanes do not pollute the valid bits. alignas(8) constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8}; const Simd d; const RebindToUnsigned du; const Vec128 slice(Load(Full64(), kSliceLanes).raw); const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; #if HWY_ARCH_ARM_A64 return vaddv_u16(values.raw); #else const uint32x2_t x2 = vpaddl_u16(values.raw); const uint64x1_t x4 = vpaddl_u32(x2); return vget_lane_u64(x4, 0); #endif } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, const Mask128 mask) { alignas(16) constexpr uint32_t kSliceLanes[4] = {1, 2, 4, 8}; const Full128 d; const Full128 du; const Vec128 values = BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); #if HWY_ARCH_ARM_A64 return vaddvq_u32(values.raw); #else const uint64x2_t x2 = vpaddlq_u32(values.raw); return vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1); #endif } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, const Mask128 mask) { // Upper lanes of partial loads are undefined. OnlyActive will fix this if // we load all kSliceLanes so the upper lanes do not pollute the valid bits. alignas(8) constexpr uint32_t kSliceLanes[2] = {1, 2}; const Simd d; const RebindToUnsigned du; const Vec128 slice(Load(Full64(), kSliceLanes).raw); const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; #if HWY_ARCH_ARM_A64 return vaddv_u32(values.raw); #else const uint64x1_t x2 = vpaddl_u32(values.raw); return vget_lane_u64(x2, 0); #endif } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128 m) { alignas(16) constexpr uint64_t kSliceLanes[2] = {1, 2}; const Full128 d; const Full128 du; const Vec128 values = BitCast(du, VecFromMask(d, m)) & Load(du, kSliceLanes); #if HWY_ARCH_ARM_A64 return vaddvq_u64(values.raw); #else return vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1); #endif } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128 m) { const Full64 d; const Full64 du; const Vec64 values = BitCast(du, VecFromMask(d, m)) & Set(du, 1); return vget_lane_u64(values.raw, 0); } // Returns the lowest N for the BitsFromMask result. template constexpr uint64_t OnlyActive(uint64_t bits) { return ((N * sizeof(T)) >= 8) ? 
bits : (bits & ((1ull << N) - 1)); } template HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); } // Returns number of lanes whose mask is set. // // Masks are either FF..FF or 0. Unfortunately there is no reduce-sub op // ("vsubv"). ANDing with 1 would work but requires a constant. Negating also // changes each lane to 1 (if mask set) or 0. // NOTE: PopCount also operates on vectors, so we still have to do horizontal // sums separately. We specialize CountTrue for full vectors (negating instead // of PopCount because it avoids an extra shift), and use PopCount of // NibblesFromMask for partial vectors. template HWY_INLINE size_t CountTrue(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { const Full128 di; const int8x16_t ones = vnegq_s8(BitCast(di, VecFromMask(Full128(), mask)).raw); #if HWY_ARCH_ARM_A64 return static_cast(vaddvq_s8(ones)); #else const int16x8_t x2 = vpaddlq_s8(ones); const int32x4_t x4 = vpaddlq_s16(x2); const int64x2_t x8 = vpaddlq_s32(x4); return static_cast(vgetq_lane_s64(x8, 0) + vgetq_lane_s64(x8, 1)); #endif } template HWY_INLINE size_t CountTrue(hwy::SizeTag<2> /*tag*/, const Mask128 mask) { const Full128 di; const int16x8_t ones = vnegq_s16(BitCast(di, VecFromMask(Full128(), mask)).raw); #if HWY_ARCH_ARM_A64 return static_cast(vaddvq_s16(ones)); #else const int32x4_t x2 = vpaddlq_s16(ones); const int64x2_t x4 = vpaddlq_s32(x2); return static_cast(vgetq_lane_s64(x4, 0) + vgetq_lane_s64(x4, 1)); #endif } template HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 mask) { const Full128 di; const int32x4_t ones = vnegq_s32(BitCast(di, VecFromMask(Full128(), mask)).raw); #if HWY_ARCH_ARM_A64 return static_cast(vaddvq_s32(ones)); #else const int64x2_t x2 = vpaddlq_s32(ones); return static_cast(vgetq_lane_s64(x2, 0) + vgetq_lane_s64(x2, 1)); #endif } template HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128 mask) { #if HWY_ARCH_ARM_A64 const Full128 di; const int64x2_t ones = vnegq_s64(BitCast(di, VecFromMask(Full128(), mask)).raw); return static_cast(vaddvq_s64(ones)); #else const Full128 du; const auto mask_u = VecFromMask(du, RebindMask(du, mask)); const uint64x2_t ones = vshrq_n_u64(mask_u.raw, 63); return static_cast(vgetq_lane_u64(ones, 0) + vgetq_lane_u64(ones, 1)); #endif } } // namespace detail // Full template HWY_API size_t CountTrue(Full128 /* tag */, const Mask128 mask) { return detail::CountTrue(hwy::SizeTag(), mask); } // Partial template HWY_API size_t CountTrue(Simd d, const Mask128 mask) { constexpr int kDiv = 4 * sizeof(T); return PopCount(detail::NibblesFromMask(d, mask)) / kDiv; } template HWY_API size_t FindKnownFirstTrue(const Simd d, const Mask128 mask) { const uint64_t nib = detail::NibblesFromMask(d, mask); constexpr size_t kDiv = 4 * sizeof(T); return Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv; } template HWY_API intptr_t FindFirstTrue(const Simd d, const Mask128 mask) { const uint64_t nib = detail::NibblesFromMask(d, mask); if (nib == 0) return -1; constexpr int kDiv = 4 * sizeof(T); return static_cast(Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv); } // `p` points to at least 8 writable bytes. 
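// Usage sketch (illustrative; the buffer name is chosen here, not part of
// this file):
//   alignas(8) uint8_t bits[8];
//   const size_t num_bytes = StoreMaskBits(d, mask, bits);
//   // Bit i of bits[0..num_bytes) now equals mask lane i, where
//   // num_bytes = (Lanes(d) + 7) / 8.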
template HWY_API size_t StoreMaskBits(Simd /* tag */, const Mask128 mask, uint8_t* bits) { const uint64_t mask_bits = detail::BitsFromMask(mask); const size_t kNumBytes = (N + 7) / 8; CopyBytes(&mask_bits, bits); return kNumBytes; } template HWY_API bool AllFalse(const Simd d, const Mask128 m) { return detail::NibblesFromMask(d, m) == 0; } // Full template HWY_API bool AllTrue(const Full128 d, const Mask128 m) { return detail::NibblesFromMask(d, m) == ~0ull; } // Partial template HWY_API bool AllTrue(const Simd d, const Mask128 m) { constexpr size_t kBytes = sizeof(T) * N; return detail::NibblesFromMask(d, m) == (1ull << (kBytes * 4)) - 1; } // ------------------------------ Compress template struct CompressIsPartition { enum { value = (sizeof(T) != 1) }; }; namespace detail { // Load 8 bytes, replicate into upper half so ZipLower can use the lower half. HWY_INLINE Vec128 Load8Bytes(Full128 /*d*/, const uint8_t* bytes) { return Vec128(vreinterpretq_u8_u64( vld1q_dup_u64(reinterpret_cast(bytes)))); } // Load 8 bytes and return half-reg with N <= 8 bytes. template HWY_INLINE Vec128 Load8Bytes(Simd d, const uint8_t* bytes) { return Load(d, bytes); } template HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<2> /*tag*/, const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 256); const Simd d; const Repartition d8; const Simd du; // ARM does not provide an equivalent of AVX2 permutevar, so we need byte // indices for VTBL (one vector's worth for each of 256 combinations of // 8 mask bits). Loading them directly would require 4 KiB. We can instead // store lane indices and convert to byte indices (2*lane + 0..1), with the // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles // is likely more costly than the higher cache footprint from storing bytes. 
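  // Worked example (illustrative): mask_bits = 0b110 selects 16-bit lanes 1
  // and 2, and row 6 of the table below is {2, 4, 0, 6, 8, 10, 12, 14}. After
  // ZipLower duplicates each byte and 0x0100 is added, every u16 lane holds
  // the byte pair {2k, 2k+1}, so TableLookupBytes produces
  // {lane1, lane2, lane0, lane3, ...}, i.e. a partition as required.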
alignas(16) constexpr uint8_t table[256 * 8] = { // PrintCompress16x8Tables 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // 8, 10, 
12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // 2, 10, 12, 14, 0, 4, 6, 8, 
/**/ 0, 2, 10, 12, 14, 4, 6, 8, // 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); const Vec128 pairs = ZipLower(byte_idx, byte_idx); return BitCast(d, pairs + Set(du, 0x0100)); } template HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<2> /*tag*/, const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 256); const Simd d; const Repartition d8; const Simd du; // ARM does not provide an equivalent of AVX2 permutevar, so we need byte // indices for VTBL (one vector's worth for each of 256 combinations of // 8 mask bits). Loading them directly would require 4 KiB. We can instead // store lane indices and convert to byte indices (2*lane + 0..1), with the // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles // is likely more costly than the higher cache footprint from storing bytes. 
alignas(16) constexpr uint8_t table[256 * 8] = { // PrintCompressNot16x8Tables 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // 0, 
2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // 0, 4, 6, 8, 2, 10, 12, 
14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); const Vec128 pairs = ZipLower(byte_idx, byte_idx); return BitCast(d, pairs + Set(du, 0x0100)); } template HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<4> /*tag*/, const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 16); // There are only 4 lanes, so we can afford to load the index vector directly. alignas(16) constexpr uint8_t u8_indices[16 * 16] = { // PrintCompress32x4Tables 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const Simd d; const Repartition d8; return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); } template HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<4> /*tag*/, const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 16); // There are only 4 lanes, so we can afford to load the index vector directly. 
alignas(16) constexpr uint8_t u8_indices[16 * 16] = { // PrintCompressNot32x4Tables 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const Simd d; const Repartition d8; return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); } #if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 template HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<8> /*tag*/, const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 4); // There are only 2 lanes, so we can afford to load the index vector directly. alignas(16) constexpr uint8_t u8_indices[64] = { // PrintCompress64x2Tables 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const Simd d; const Repartition d8; return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); } template HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<8> /*tag*/, const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 4); // There are only 2 lanes, so we can afford to load the index vector directly. alignas(16) constexpr uint8_t u8_indices[4 * 16] = { // PrintCompressNot64x2Tables 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const Simd d; const Repartition d8; return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); } #endif // Helper function called by both Compress and CompressStore - avoids a // redundant BitsFromMask in the latter. template HWY_INLINE Vec128 Compress(Vec128 v, const uint64_t mask_bits) { const auto idx = detail::IdxFromBits(hwy::SizeTag(), mask_bits); using D = Simd; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } template HWY_INLINE Vec128 CompressNot(Vec128 v, const uint64_t mask_bits) { const auto idx = detail::IdxFromNotBits(hwy::SizeTag(), mask_bits); using D = Simd; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } } // namespace detail // Single lane: no-op template HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { return v; } // Two lanes: conditional swap template HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
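  // Walkthrough (informational): DupEven broadcasts mask[0] and DupOdd
  // broadcasts mask[1]; AndNot(maskL, maskH) = maskH & ~maskL is all-ones
  // exactly when mask = {0, 1}, so IfVecThenElse picks the swapped Shuffle01
  // result only in that case and leaves v unchanged otherwise.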
template <typename T, HWY_IF_LANE_SIZE(T, 8)>
HWY_API Vec128<T> Compress(Vec128<T> v, const Mask128<T> mask) {
  // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep.
  const Simd<T, 2, 0> d;
  const Vec128<T> m = VecFromMask(d, mask);
  const Vec128<T> maskL = DupEven(m);
  const Vec128<T> maskH = DupOdd(m);
  const Vec128<T> swap = AndNot(maskL, maskH);
  return IfVecThenElse(swap, Shuffle01(v), v);
}

// General case, 2 or 4 byte lanes
template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 8)>
HWY_API Vec128<T, N> Compress(Vec128<T, N> v, const Mask128<T, N> mask) {
  return detail::Compress(v, detail::BitsFromMask(mask));
}

// Single lane: no-op
template <typename T>
HWY_API Vec128<T, 1> CompressNot(Vec128<T, 1> v, Mask128<T, 1> /*m*/) {
  return v;
}

// Two lanes: conditional swap
template <typename T, HWY_IF_LANE_SIZE(T, 8)>
HWY_API Vec128<T> CompressNot(Vec128<T> v, Mask128<T> mask) {
  // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep.
  const Full128<T> d;
  const Vec128<T> m = VecFromMask(d, mask);
  const Vec128<T> maskL = DupEven(m);
  const Vec128<T> maskH = DupOdd(m);
  const Vec128<T> swap = AndNot(maskH, maskL);
  return IfVecThenElse(swap, Shuffle01(v), v);
}

// General case, 2 or 4 byte lanes
template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 8)>
HWY_API Vec128<T, N> CompressNot(Vec128<T, N> v, Mask128<T, N> mask) {
  // For partial vectors, we cannot pull the Not() into the table because
  // BitsFromMask clears the upper bits.
  if (N < 16 / sizeof(T)) {
    return detail::Compress(v, detail::BitsFromMask(Not(mask)));
  }
  return detail::CompressNot(v, detail::BitsFromMask(mask));
}

// ------------------------------ CompressBlocksNot
HWY_API Vec128<uint64_t> CompressBlocksNot(Vec128<uint64_t> v,
                                           Mask128<uint64_t> /* m */) {
  return v;
}

// ------------------------------ CompressBits

template <typename T, size_t N>
HWY_INLINE Vec128<T, N> CompressBits(Vec128<T, N> v,
                                     const uint8_t* HWY_RESTRICT bits) {
  uint64_t mask_bits = 0;
  constexpr size_t kNumBytes = (N + 7) / 8;
  CopyBytes<kNumBytes>(bits, &mask_bits);
  if (N < 8) {
    mask_bits &= (1ull << N) - 1;
  }

  return detail::Compress(v, mask_bits);
}

// ------------------------------ CompressStore
template <typename T, size_t N>
HWY_API size_t CompressStore(Vec128<T, N> v, const Mask128<T, N> mask,
                             Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) {
  const uint64_t mask_bits = detail::BitsFromMask(mask);
  StoreU(detail::Compress(v, mask_bits), d, unaligned);
  return PopCount(mask_bits);
}

// ------------------------------ CompressBlendedStore
template <typename T, size_t N>
HWY_API size_t CompressBlendedStore(Vec128<T, N> v, Mask128<T, N> m,
                                    Simd<T, N, 0> d,
                                    T* HWY_RESTRICT unaligned) {
  const RebindToUnsigned<decltype(d)> du;  // so we can support fp16/bf16
  using TU = TFromD<decltype(du)>;
  const uint64_t mask_bits = detail::BitsFromMask(m);
  const size_t count = PopCount(mask_bits);
  const Mask128<T, N> store_mask = RebindMask(d, FirstN(du, count));
  const Vec128<TU, N> compressed = detail::Compress(BitCast(du, v), mask_bits);
  BlendedStore(BitCast(d, compressed), store_mask, d, unaligned);
  return count;
}

// ------------------------------ CompressBitsStore

template <typename T, size_t N>
HWY_API size_t CompressBitsStore(Vec128<T, N> v,
                                 const uint8_t* HWY_RESTRICT bits,
                                 Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) {
  uint64_t mask_bits = 0;
  constexpr size_t kNumBytes = (N + 7) / 8;
  CopyBytes<kNumBytes>(bits, &mask_bits);
  if (N < 8) {
    mask_bits &= (1ull << N) - 1;
  }

  StoreU(detail::Compress(v, mask_bits), d, unaligned);
  return PopCount(mask_bits);
}
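// Usage sketch (illustrative only, not part of the ops defined here): stream
// compaction of the positive lanes of a float vector. The function and
// pointer names are hypothetical.
HWY_API size_t ExampleKeepPositive(const float* HWY_RESTRICT in,
                                   float* HWY_RESTRICT out) {
  const Full128<float> d;
  const auto v = LoadU(d, in);
  const auto positive = Gt(v, Zero(d));
  // Packs the selected lanes to the front, stores them and returns how many
  // were selected. Note that CompressStore writes a full vector, so `out`
  // must have room for all 4 lanes even if fewer were selected.
  return CompressStore(v, positive, d, out);
}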
// ------------------------------ LoadInterleaved2

// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2.
#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED
#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED
#else
#define HWY_NATIVE_LOAD_STORE_INTERLEAVED
#endif

namespace detail {
#define HWY_NEON_BUILD_TPL_HWY_LOAD_INT
#define HWY_NEON_BUILD_ARG_HWY_LOAD_INT from

#if HWY_ARCH_ARM_A64
#define HWY_IF_LOAD_INT(T, N) HWY_IF_GE64(T, N)
#define HWY_NEON_DEF_FUNCTION_LOAD_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES
#else
// Exclude 64x2 and f64x1, which are only supported on aarch64
#define HWY_IF_LOAD_INT(T, N) \
  hwy::EnableIf<N * sizeof(T) >= 8 && (N == 1 || sizeof(T) < 8)>* = nullptr
#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args)  \
  HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args)     \
  HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args)    \
  HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args)        \
  HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args)  \
  HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args)
#endif  // HWY_ARCH_ARM_A64

// Must return a raw tuple because Tuple2 lacks a ctor, and we cannot use
// brace-initialization in HWY_NEON_DEF_FUNCTION because some functions return
// void.
#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \
  decltype(Tuple2<type##_t, size>().raw)
// Tuple tag arg allows overloading (cannot just overload on return type)
#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \
  const type##_t *from, Tuple2<type##_t, size>

HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved2, vld2, _, HWY_LOAD_INT)
#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT
#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT

#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \
  decltype(Tuple3<type##_t, size>().raw)
#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \
  const type##_t *from, Tuple3<type##_t, size>
HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved3, vld3, _, HWY_LOAD_INT)
#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT
#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT

#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \
  decltype(Tuple4<type##_t, size>().raw)
#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \
  const type##_t *from, Tuple4<type##_t, size>
HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved4, vld4, _, HWY_LOAD_INT)
#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT
#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT

#undef HWY_NEON_DEF_FUNCTION_LOAD_INT
#undef HWY_NEON_BUILD_TPL_HWY_LOAD_INT
#undef HWY_NEON_BUILD_ARG_HWY_LOAD_INT
}  // namespace detail

template <typename T, size_t N, HWY_IF_LOAD_INT(T, N)>
HWY_API void LoadInterleaved2(Simd<T, N, 0> /*tag*/,
                              const T* HWY_RESTRICT unaligned,
                              Vec128<T, N>& v0, Vec128<T, N>& v1) {
  auto raw = detail::LoadInterleaved2(unaligned, detail::Tuple2<T, N>());
  v0 = Vec128<T, N>(raw.val[0]);
  v1 = Vec128<T, N>(raw.val[1]);
}
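// Usage sketch (illustrative): de-interleave 8 stereo int16_t sample pairs
// into separate left/right vectors. The function and buffer names are
// hypothetical; `interleaved` must hold at least 16 samples.
HWY_API void ExampleDeinterleaveStereo(const int16_t* HWY_RESTRICT interleaved,
                                       Vec128<int16_t>& left,
                                       Vec128<int16_t>& right) {
  const Full128<int16_t> d;
  // Lane i of `left` receives interleaved[2 * i], lane i of `right` receives
  // interleaved[2 * i + 1] (a single vld2 on NEON).
  LoadInterleaved2(d, interleaved, left, right);
}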
// <= 32 bits: avoid loading more than N bytes by copying to buffer
template <typename T, size_t N, HWY_IF_LE32(T, N)>
HWY_API void LoadInterleaved2(Simd<T, N, 0> /*tag*/,
                              const T* HWY_RESTRICT unaligned,
                              Vec128<T, N>& v0, Vec128<T, N>& v1) {
  // The smallest vector registers are 64-bits and we want space for two.
  alignas(16) T buf[2 * 8 / sizeof(T)] = {};
  CopyBytes<N * 2 * sizeof(T)>(unaligned, buf);
  auto raw = detail::LoadInterleaved2(buf, detail::Tuple2<T, N>());
  v0 = Vec128<T, N>(raw.val[0]);
  v1 = Vec128<T, N>(raw.val[1]);
}

#if HWY_ARCH_ARM_V7
// 64x2: split into two 64x1
template <typename T, HWY_IF_LANE_SIZE(T, 8)>
HWY_API void LoadInterleaved2(Full128<T> d, const T* HWY_RESTRICT unaligned,
                              Vec128<T>& v0, Vec128<T>& v1) {
  const Half<decltype(d)> dh;
  VFromD<decltype(dh)> v00, v10, v01, v11;
  LoadInterleaved2(dh, unaligned, v00, v10);
  LoadInterleaved2(dh, unaligned + 2, v01, v11);
  v0 = Combine(d, v01, v00);
  v1 = Combine(d, v11, v10);
}
#endif  // HWY_ARCH_ARM_V7

// ------------------------------ LoadInterleaved3

template <typename T, size_t N, HWY_IF_LOAD_INT(T, N)>
HWY_API void LoadInterleaved3(Simd<T, N, 0> /*tag*/,
                              const T* HWY_RESTRICT unaligned,
                              Vec128<T, N>& v0, Vec128<T, N>& v1,
                              Vec128<T, N>& v2) {
  auto raw = detail::LoadInterleaved3(unaligned, detail::Tuple3<T, N>());
  v0 = Vec128<T, N>(raw.val[0]);
  v1 = Vec128<T, N>(raw.val[1]);
  v2 = Vec128<T, N>(raw.val[2]);
}

// <= 32 bits: avoid loading more than N bytes by copying to buffer
template <typename T, size_t N, HWY_IF_LE32(T, N)>
HWY_API void LoadInterleaved3(Simd<T, N, 0> /*tag*/,
                              const T* HWY_RESTRICT unaligned,
                              Vec128<T, N>& v0, Vec128<T, N>& v1,
                              Vec128<T, N>& v2) {
  // The smallest vector registers are 64-bits and we want space for three.
  alignas(16) T buf[3 * 8 / sizeof(T)] = {};
  CopyBytes<N * 3 * sizeof(T)>(unaligned, buf);
  auto raw = detail::LoadInterleaved3(buf, detail::Tuple3<T, N>());
  v0 = Vec128<T, N>(raw.val[0]);
  v1 = Vec128<T, N>(raw.val[1]);
  v2 = Vec128<T, N>(raw.val[2]);
}

#if HWY_ARCH_ARM_V7
// 64x2: split into two 64x1
template <typename T, HWY_IF_LANE_SIZE(T, 8)>
HWY_API void LoadInterleaved3(Full128<T> d, const T* HWY_RESTRICT unaligned,
                              Vec128<T>& v0, Vec128<T>& v1, Vec128<T>& v2) {
  const Half<decltype(d)> dh;
  VFromD<decltype(dh)> v00, v10, v20, v01, v11, v21;
  LoadInterleaved3(dh, unaligned, v00, v10, v20);
  LoadInterleaved3(dh, unaligned + 3, v01, v11, v21);
  v0 = Combine(d, v01, v00);
  v1 = Combine(d, v11, v10);
  v2 = Combine(d, v21, v20);
}
#endif  // HWY_ARCH_ARM_V7

// ------------------------------ LoadInterleaved4

template <typename T, size_t N, HWY_IF_LOAD_INT(T, N)>
HWY_API void LoadInterleaved4(Simd<T, N, 0> /*tag*/,
                              const T* HWY_RESTRICT unaligned,
                              Vec128<T, N>& v0, Vec128<T, N>& v1,
                              Vec128<T, N>& v2, Vec128<T, N>& v3) {
  auto raw = detail::LoadInterleaved4(unaligned, detail::Tuple4<T, N>());
  v0 = Vec128<T, N>(raw.val[0]);
  v1 = Vec128<T, N>(raw.val[1]);
  v2 = Vec128<T, N>(raw.val[2]);
  v3 = Vec128<T, N>(raw.val[3]);
}

// <= 32 bits: avoid loading more than N bytes by copying to buffer
template <typename T, size_t N, HWY_IF_LE32(T, N)>
HWY_API void LoadInterleaved4(Simd<T, N, 0> /*tag*/,
                              const T* HWY_RESTRICT unaligned,
                              Vec128<T, N>& v0, Vec128<T, N>& v1,
                              Vec128<T, N>& v2, Vec128<T, N>& v3) {
  alignas(16) T buf[4 * 8 / sizeof(T)] = {};
  CopyBytes<N * 4 * sizeof(T)>(unaligned, buf);
  auto raw = detail::LoadInterleaved4(buf, detail::Tuple4<T, N>());
  v0 = Vec128<T, N>(raw.val[0]);
  v1 = Vec128<T, N>(raw.val[1]);
  v2 = Vec128<T, N>(raw.val[2]);
  v3 = Vec128<T, N>(raw.val[3]);
}

#if HWY_ARCH_ARM_V7
// 64x2: split into two 64x1
template <typename T, HWY_IF_LANE_SIZE(T, 8)>
HWY_API void LoadInterleaved4(Full128<T> d, const T* HWY_RESTRICT unaligned,
                              Vec128<T>& v0, Vec128<T>& v1, Vec128<T>& v2,
                              Vec128<T>& v3) {
  const Half<decltype(d)> dh;
  VFromD<decltype(dh)> v00, v10, v20, v30, v01, v11, v21, v31;
  LoadInterleaved4(dh, unaligned, v00, v10, v20, v30);
  LoadInterleaved4(dh, unaligned + 4, v01, v11, v21, v31);
  v0 = Combine(d, v01, v00);
  v1 = Combine(d, v11, v10);
  v2 = Combine(d, v21, v20);
  v3 = Combine(d, v31, v30);
}
#endif  // HWY_ARCH_ARM_V7

#undef HWY_IF_LOAD_INT
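// Usage sketch (illustrative): split 16 packed RGB pixels (48 bytes) into
// planar R, G, B vectors. The function and pointer names are hypothetical.
HWY_API void ExampleLoadPlanarRGB(const uint8_t* HWY_RESTRICT rgb,
                                  Vec128<uint8_t>& r, Vec128<uint8_t>& g,
                                  Vec128<uint8_t>& b) {
  const Full128<uint8_t> d;
  // r[i] = rgb[3*i], g[i] = rgb[3*i+1], b[i] = rgb[3*i+2] (vld3 on NEON).
  LoadInterleaved3(d, rgb, r, g, b);
}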
// ------------------------------ StoreInterleaved2

namespace detail {
#define HWY_NEON_BUILD_TPL_HWY_STORE_INT
#define HWY_NEON_BUILD_RET_HWY_STORE_INT(type, size) void
#define HWY_NEON_BUILD_ARG_HWY_STORE_INT to, tup.raw

#if HWY_ARCH_ARM_A64
#define HWY_IF_STORE_INT(T, N) HWY_IF_GE64(T, N)
#define HWY_NEON_DEF_FUNCTION_STORE_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES
#else
// Exclude 64x2 and f64x1, which are only supported on aarch64
#define HWY_IF_STORE_INT(T, N) \
  hwy::EnableIf<N * sizeof(T) >= 8 && (N == 1 || sizeof(T) < 8)>* = nullptr
#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \
  HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args)     \
  HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args)    \
  HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args)        \
  HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args)  \
  HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args)
#endif  // HWY_ARCH_ARM_A64

#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \
  Tuple2<type##_t, size> tup, type##_t *to
HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved2, vst2, _, HWY_STORE_INT)
#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT

#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \
  Tuple3<type##_t, size> tup, type##_t *to
HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved3, vst3, _, HWY_STORE_INT)
#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT

#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \
  Tuple4<type##_t, size> tup, type##_t *to
HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved4, vst4, _, HWY_STORE_INT)
#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT

#undef HWY_NEON_DEF_FUNCTION_STORE_INT
#undef HWY_NEON_BUILD_TPL_HWY_STORE_INT
#undef HWY_NEON_BUILD_RET_HWY_STORE_INT
#undef HWY_NEON_BUILD_ARG_HWY_STORE_INT
}  // namespace detail

template <typename T, size_t N, HWY_IF_STORE_INT(T, N)>
HWY_API void StoreInterleaved2(const Vec128<T, N> v0, const Vec128<T, N> v1,
                               Simd<T, N, 0> /*tag*/,
                               T* HWY_RESTRICT unaligned) {
  detail::Tuple2<T, N> tup = {{{v0.raw, v1.raw}}};
  detail::StoreInterleaved2(tup, unaligned);
}

// <= 32 bits: avoid writing more than N bytes by copying to buffer
template <typename T, size_t N, HWY_IF_LE32(T, N)>
HWY_API void StoreInterleaved2(const Vec128<T, N> v0, const Vec128<T, N> v1,
                               Simd<T, N, 0> /*tag*/,
                               T* HWY_RESTRICT unaligned) {
  alignas(16) T buf[2 * 8 / sizeof(T)];
  detail::Tuple2<T, N> tup = {{{v0.raw, v1.raw}}};
  detail::StoreInterleaved2(tup, buf);
  CopyBytes<N * 2 * sizeof(T)>(buf, unaligned);
}

#if HWY_ARCH_ARM_V7
// 64x2: split into two 64x1
template <typename T, HWY_IF_LANE_SIZE(T, 8)>
HWY_API void StoreInterleaved2(const Vec128<T> v0, const Vec128<T> v1,
                               Full128<T> d, T* HWY_RESTRICT unaligned) {
  const Half<decltype(d)> dh;
  StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), dh, unaligned);
  StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), dh, unaligned + 2);
}
#endif  // HWY_ARCH_ARM_V7

// ------------------------------ StoreInterleaved3

template <typename T, size_t N, HWY_IF_STORE_INT(T, N)>
HWY_API void StoreInterleaved3(const Vec128<T, N> v0, const Vec128<T, N> v1,
                               const Vec128<T, N> v2, Simd<T, N, 0> /*tag*/,
                               T* HWY_RESTRICT unaligned) {
  detail::Tuple3<T, N> tup = {{{v0.raw, v1.raw, v2.raw}}};
  detail::StoreInterleaved3(tup, unaligned);
}

// <= 32 bits: avoid writing more than N bytes by copying to buffer
template <typename T, size_t N, HWY_IF_LE32(T, N)>
HWY_API void StoreInterleaved3(const Vec128<T, N> v0, const Vec128<T, N> v1,
                               const Vec128<T, N> v2, Simd<T, N, 0> /*tag*/,
                               T* HWY_RESTRICT unaligned) {
  alignas(16) T buf[3 * 8 / sizeof(T)];
  detail::Tuple3<T, N> tup = {{{v0.raw, v1.raw, v2.raw}}};
  detail::StoreInterleaved3(tup, buf);
  CopyBytes<N * 3 * sizeof(T)>(buf, unaligned);
}

#if HWY_ARCH_ARM_V7
// 64x2: split into two 64x1
template <typename T, HWY_IF_LANE_SIZE(T, 8)>
HWY_API void StoreInterleaved3(const Vec128<T> v0, const Vec128<T> v1,
                               const Vec128<T> v2, Full128<T> d,
                               T* HWY_RESTRICT unaligned) {
  const Half<decltype(d)> dh;
  StoreInterleaved3(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2),
                    dh, unaligned);
  StoreInterleaved3(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2),
                    dh, unaligned + 3);
}
#endif  // HWY_ARCH_ARM_V7
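// Usage sketch (illustrative): re-interleave planar R, G, B vectors into 48
// bytes of packed RGB, the inverse of the LoadInterleaved3 example above. The
// function and pointer names are hypothetical.
HWY_API void ExampleStorePackedRGB(Vec128<uint8_t> r, Vec128<uint8_t> g,
                                   Vec128<uint8_t> b,
                                   uint8_t* HWY_RESTRICT rgb) {
  const Full128<uint8_t> d;
  // rgb[3*i] = r[i], rgb[3*i+1] = g[i], rgb[3*i+2] = b[i] (vst3 on NEON).
  StoreInterleaved3(r, g, b, d, rgb);
}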
// ------------------------------ StoreInterleaved4

template <typename T, size_t N, HWY_IF_STORE_INT(T, N)>
HWY_API void StoreInterleaved4(const Vec128<T, N> v0, const Vec128<T, N> v1,
                               const Vec128<T, N> v2, const Vec128<T, N> v3,
                               Simd<T, N, 0> /*tag*/,
                               T* HWY_RESTRICT unaligned) {
  detail::Tuple4<T, N> tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}};
  detail::StoreInterleaved4(tup, unaligned);
}

// <= 32 bits: avoid writing more than N bytes by copying to buffer
template <typename T, size_t N, HWY_IF_LE32(T, N)>
HWY_API void StoreInterleaved4(const Vec128<T, N> v0, const Vec128<T, N> v1,
                               const Vec128<T, N> v2, const Vec128<T, N> v3,
                               Simd<T, N, 0> /*tag*/,
                               T* HWY_RESTRICT unaligned) {
  alignas(16) T buf[4 * 8 / sizeof(T)];
  detail::Tuple4<T, N> tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}};
  detail::StoreInterleaved4(tup, buf);
  CopyBytes<N * 4 * sizeof(T)>(buf, unaligned);
}

#if HWY_ARCH_ARM_V7
// 64x2: split into two 64x1
template <typename T, HWY_IF_LANE_SIZE(T, 8)>
HWY_API void StoreInterleaved4(const Vec128<T> v0, const Vec128<T> v1,
                               const Vec128<T> v2, const Vec128<T> v3,
                               Full128<T> d, T* HWY_RESTRICT unaligned) {
  const Half<decltype(d)> dh;
  StoreInterleaved4(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2),
                    LowerHalf(dh, v3), dh, unaligned);
  StoreInterleaved4(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2),
                    UpperHalf(dh, v3), dh, unaligned + 4);
}
#endif  // HWY_ARCH_ARM_V7

#undef HWY_IF_STORE_INT

// ------------------------------ Lt128

template <typename T, size_t N>
HWY_INLINE Mask128<T, N> Lt128(Simd<T, N, 0> d, Vec128<T, N> a,
                               Vec128<T, N> b) {
  static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64");
  // Truth table of Eq and Lt for Hi and Lo u64.
  // (removed lines with (=H && cH) or (=L && cL) - cannot both be true)
  // =H =L cH cL  | out = cH | (=H & cL)
  //  0  0  0  0  |  0
  //  0  0  0  1  |  0
  //  0  0  1  0  |  1
  //  0  0  1  1  |  1
  //  0  1  0  0  |  0
  //  0  1  0  1  |  0
  //  0  1  1  0  |  1
  //  1  0  0  0  |  0
  //  1  0  0  1  |  1
  //  1  1  0  0  |  0
  const Mask128<T, N> eqHL = Eq(a, b);
  const Vec128<T, N> ltHL = VecFromMask(d, Lt(a, b));
  // We need to bring cL to the upper lane/bit corresponding to cH. Comparing
  // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the
  // comparison result leftwards requires only 4. IfThenElse compiles to the
  // same code as OrAnd().
  const Vec128<T, N> ltLx = DupEven(ltHL);
  const Vec128<T, N> outHx = IfThenElse(eqHL, ltLx, ltHL);
  return MaskFromVec(DupOdd(outHx));
}

template <typename T, size_t N>
HWY_INLINE Mask128<T, N> Lt128Upper(Simd<T, N, 0> d, Vec128<T, N> a,
                                    Vec128<T, N> b) {
  const Vec128<T, N> ltHL = VecFromMask(d, Lt(a, b));
  return MaskFromVec(InterleaveUpper(d, ltHL, ltHL));
}

// ------------------------------ Eq128

template <typename T, size_t N>
HWY_INLINE Mask128<T, N> Eq128(Simd<T, N, 0> d, Vec128<T, N> a,
                               Vec128<T, N> b) {
  static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64");
  const Vec128<T, N> eqHL = VecFromMask(d, Eq(a, b));
  return MaskFromVec(And(Reverse2(d, eqHL), eqHL));
}

template <typename T, size_t N>
HWY_INLINE Mask128<T, N> Eq128Upper(Simd<T, N, 0> d, Vec128<T, N> a,
                                    Vec128<T, N> b) {
  const Vec128<T, N> eqHL = VecFromMask(d, Eq(a, b));
  return MaskFromVec(InterleaveUpper(d, eqHL, eqHL));
}

// ------------------------------ Ne128

template <typename T, size_t N>
HWY_INLINE Mask128<T, N> Ne128(Simd<T, N, 0> d, Vec128<T, N> a,
                               Vec128<T, N> b) {
  static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64");
  const Vec128<T, N> neHL = VecFromMask(d, Ne(a, b));
  return MaskFromVec(Or(Reverse2(d, neHL), neHL));
}

template <typename T, size_t N>
HWY_INLINE Mask128<T, N> Ne128Upper(Simd<T, N, 0> d, Vec128<T, N> a,
                                    Vec128<T, N> b) {
  const Vec128<T, N> neHL = VecFromMask(d, Ne(a, b));
  return MaskFromVec(InterleaveUpper(d, neHL, neHL));
}
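// Usage sketch (illustrative): treat a Vec128<uint64_t> as a single 128-bit
// unsigned integer (lane 0 = low half, lane 1 = high half) and test a < b.
// The function name is hypothetical; AllTrue is used because Lt128 sets both
// lanes of the result mask to the comparison outcome.
HWY_API bool ExampleU128Less(Vec128<uint64_t> a, Vec128<uint64_t> b) {
  const Full128<uint64_t> d;
  return AllTrue(d, Lt128(d, a, b));
}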
// ------------------------------ Min128, Max128 (Lt128)

// Without a native OddEven, it seems infeasible to go faster than Lt128.
template <class D>
HWY_INLINE VFromD<D> Min128(D d, const VFromD<D> a, const VFromD<D> b) {
  return IfThenElse(Lt128(d, a, b), a, b);
}

template <class D>
HWY_INLINE VFromD<D> Max128(D d, const VFromD<D> a, const VFromD<D> b) {
  return IfThenElse(Lt128(d, b, a), a, b);
}

template <class D>
HWY_INLINE VFromD<D> Min128Upper(D d, const VFromD<D> a, const VFromD<D> b) {
  return IfThenElse(Lt128Upper(d, a, b), a, b);
}

template <class D>
HWY_INLINE VFromD<D> Max128Upper(D d, const VFromD<D> a, const VFromD<D> b) {
  return IfThenElse(Lt128Upper(d, b, a), a, b);
}

namespace detail {  // for code folding
#if HWY_ARCH_ARM_V7
#undef vuzp1_s8
#undef vuzp1_u8
#undef vuzp1_s16
#undef vuzp1_u16
#undef vuzp1_s32
#undef vuzp1_u32
#undef vuzp1_f32
#undef vuzp1q_s8
#undef vuzp1q_u8
#undef vuzp1q_s16
#undef vuzp1q_u16
#undef vuzp1q_s32
#undef vuzp1q_u32
#undef vuzp1q_f32
#undef vuzp2_s8
#undef vuzp2_u8
#undef vuzp2_s16
#undef vuzp2_u16
#undef vuzp2_s32
#undef vuzp2_u32
#undef vuzp2_f32
#undef vuzp2q_s8
#undef vuzp2q_u8
#undef vuzp2q_s16
#undef vuzp2q_u16
#undef vuzp2q_s32
#undef vuzp2q_u32
#undef vuzp2q_f32
#undef vzip1_s8
#undef vzip1_u8
#undef vzip1_s16
#undef vzip1_u16
#undef vzip1_s32
#undef vzip1_u32
#undef vzip1_f32
#undef vzip1q_s8
#undef vzip1q_u8
#undef vzip1q_s16
#undef vzip1q_u16
#undef vzip1q_s32
#undef vzip1q_u32
#undef vzip1q_f32
#undef vzip2_s8
#undef vzip2_u8
#undef vzip2_s16
#undef vzip2_u16
#undef vzip2_s32
#undef vzip2_u32
#undef vzip2_f32
#undef vzip2q_s8
#undef vzip2q_u8
#undef vzip2q_s16
#undef vzip2q_u16
#undef vzip2q_s32
#undef vzip2q_u32
#undef vzip2q_f32
#endif

#undef HWY_NEON_BUILD_ARG_1
#undef HWY_NEON_BUILD_ARG_2
#undef HWY_NEON_BUILD_ARG_3
#undef HWY_NEON_BUILD_PARAM_1
#undef HWY_NEON_BUILD_PARAM_2
#undef HWY_NEON_BUILD_PARAM_3
#undef HWY_NEON_BUILD_RET_1
#undef HWY_NEON_BUILD_RET_2
#undef HWY_NEON_BUILD_RET_3
#undef HWY_NEON_BUILD_TPL_1
#undef HWY_NEON_BUILD_TPL_2
#undef HWY_NEON_BUILD_TPL_3
#undef HWY_NEON_DEF_FUNCTION
#undef HWY_NEON_DEF_FUNCTION_ALL_FLOATS
#undef HWY_NEON_DEF_FUNCTION_ALL_TYPES
#undef HWY_NEON_DEF_FUNCTION_FLOAT_64
#undef HWY_NEON_DEF_FUNCTION_FULL_UI
#undef HWY_NEON_DEF_FUNCTION_INT_16
#undef HWY_NEON_DEF_FUNCTION_INT_32
#undef HWY_NEON_DEF_FUNCTION_INT_8
#undef HWY_NEON_DEF_FUNCTION_INT_8_16_32
#undef HWY_NEON_DEF_FUNCTION_INTS
#undef HWY_NEON_DEF_FUNCTION_INTS_UINTS
#undef HWY_NEON_DEF_FUNCTION_TPL
#undef HWY_NEON_DEF_FUNCTION_UIF81632
#undef HWY_NEON_DEF_FUNCTION_UINT_16
#undef HWY_NEON_DEF_FUNCTION_UINT_32
#undef HWY_NEON_DEF_FUNCTION_UINT_8
#undef HWY_NEON_DEF_FUNCTION_UINT_8_16_32
#undef HWY_NEON_DEF_FUNCTION_UINTS
#undef HWY_NEON_EVAL
}  // namespace detail

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();