From 36d22d82aa202bb199967e9512281e9a53db42c9 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Sun, 7 Apr 2024 21:33:14 +0200
Subject: Adding upstream version 115.7.0esr.

Signed-off-by: Daniel Baumann
---
 third_party/highway/hwy/ops/arm_neon-inl.h    | 6810 ++++++++++++++++++++++
 third_party/highway/hwy/ops/arm_sve-inl.h     | 3186 +++++
 third_party/highway/hwy/ops/emu128-inl.h      | 2503 +++++
 third_party/highway/hwy/ops/generic_ops-inl.h | 1560 ++++++
 third_party/highway/hwy/ops/rvv-inl.h         | 3451 ++++++
 third_party/highway/hwy/ops/scalar-inl.h      | 1626 ++++++
 third_party/highway/hwy/ops/set_macros-inl.h  |  444 ++
 third_party/highway/hwy/ops/shared-inl.h      |  332 ++
 third_party/highway/hwy/ops/wasm_128-inl.h    | 4591 +++++++++++++++
 third_party/highway/hwy/ops/wasm_256-inl.h    | 2003 +++++++
 third_party/highway/hwy/ops/x86_128-inl.h     | 7432 +++++++++++++++++++++++++
 third_party/highway/hwy/ops/x86_256-inl.h     | 5548 ++++++++++++++++++
 third_party/highway/hwy/ops/x86_512-inl.h     | 4605 +++++++++++++++
 13 files changed, 44091 insertions(+)
 create mode 100644 third_party/highway/hwy/ops/arm_neon-inl.h
 create mode 100644 third_party/highway/hwy/ops/arm_sve-inl.h
 create mode 100644 third_party/highway/hwy/ops/emu128-inl.h
 create mode 100644 third_party/highway/hwy/ops/generic_ops-inl.h
 create mode 100644 third_party/highway/hwy/ops/rvv-inl.h
 create mode 100644 third_party/highway/hwy/ops/scalar-inl.h
 create mode 100644 third_party/highway/hwy/ops/set_macros-inl.h
 create mode 100644 third_party/highway/hwy/ops/shared-inl.h
 create mode 100644 third_party/highway/hwy/ops/wasm_128-inl.h
 create mode 100644 third_party/highway/hwy/ops/wasm_256-inl.h
 create mode 100644 third_party/highway/hwy/ops/x86_128-inl.h
 create mode 100644 third_party/highway/hwy/ops/x86_256-inl.h
 create mode 100644 third_party/highway/hwy/ops/x86_512-inl.h

diff --git a/third_party/highway/hwy/ops/arm_neon-inl.h b/third_party/highway/hwy/ops/arm_neon-inl.h
new file mode 100644
index 0000000000..7c3759aa3d
--- /dev/null
+++ b/third_party/highway/hwy/ops/arm_neon-inl.h
@@ -0,0 +1,6810 @@
+// Copyright 2019 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// 128-bit ARM64 NEON vectors and operations.
+// External include guard in highway.h - see comment there.
+
+// ARM NEON intrinsics are documented at:
+// https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hwy/ops/shared-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+
+// Must come after HWY_BEFORE_NAMESPACE so that the intrinsics are compiled
+// with the same target attribute as our code, see #834.
+HWY_DIAGNOSTICS(push)
+HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized")
+#include <arm_neon.h>  // NOLINT(build/include_order)
+HWY_DIAGNOSTICS(pop)
+
+// Must come after arm_neon.h.
+namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { // for code folding and Raw128 + +// Macros used to define single and double function calls for multiple types +// for full and half vectors. These macros are undefined at the end of the file. + +// HWY_NEON_BUILD_TPL_* is the template<...> prefix to the function. +#define HWY_NEON_BUILD_TPL_1 +#define HWY_NEON_BUILD_TPL_2 +#define HWY_NEON_BUILD_TPL_3 + +// HWY_NEON_BUILD_RET_* is return type; type arg is without _t suffix so we can +// extend it to int32x4x2_t packs. +#define HWY_NEON_BUILD_RET_1(type, size) Vec128 +#define HWY_NEON_BUILD_RET_2(type, size) Vec128 +#define HWY_NEON_BUILD_RET_3(type, size) Vec128 + +// HWY_NEON_BUILD_PARAM_* is the list of parameters the function receives. +#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128 a +#define HWY_NEON_BUILD_PARAM_2(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_PARAM_3(type, size) \ + const Vec128 a, const Vec128 b, \ + const Vec128 c + +// HWY_NEON_BUILD_ARG_* is the list of arguments passed to the underlying +// function. +#define HWY_NEON_BUILD_ARG_1 a.raw +#define HWY_NEON_BUILD_ARG_2 a.raw, b.raw +#define HWY_NEON_BUILD_ARG_3 a.raw, b.raw, c.raw + +// We use HWY_NEON_EVAL(func, ...) to delay the evaluation of func until after +// the __VA_ARGS__ have been expanded. This allows "func" to be a macro on +// itself like with some of the library "functions" such as vshlq_u8. For +// example, HWY_NEON_EVAL(vshlq_u8, MY_PARAMS) where MY_PARAMS is defined as +// "a, b" (without the quotes) will end up expanding "vshlq_u8(a, b)" if needed. +// Directly writing vshlq_u8(MY_PARAMS) would fail since vshlq_u8() macro +// expects two arguments. +#define HWY_NEON_EVAL(func, ...) func(__VA_ARGS__) + +// Main macro definition that defines a single function for the given type and +// size of vector, using the underlying (prefix##infix##suffix) function and +// the template, return type, parameters and arguments defined by the "args" +// parameters passed here (see HWY_NEON_BUILD_* macros defined before). +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_CONCAT(HWY_NEON_BUILD_TPL_, args) \ + HWY_API HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size) \ + name(HWY_CONCAT(HWY_NEON_BUILD_PARAM_, args)(type, size)) { \ + return HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size)( \ + HWY_NEON_EVAL(prefix##infix##suffix, HWY_NEON_BUILD_ARG_##args)); \ + } + +// The HWY_NEON_DEF_FUNCTION_* macros define all the variants of a function +// called "name" using the set of neon functions starting with the given +// "prefix" for all the variants of certain types, as specified next to each +// macro. For example, the prefix "vsub" can be used to define the operator- +// using args=2. 
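To make the expansion concrete, here is roughly what one instantiation generates; the Vec128 template arguments shown follow the upstream Highway sources (they do not survive in this rendering of the patch), so treat the exact signature as approximate.

// Sketch: HWY_NEON_DEF_FUNCTION(uint8, 16, operator+, vaddq, _, u8, 2), as
// emitted by HWY_NEON_DEF_FUNCTION_UINT_8(operator+, vadd, _, 2) for the full
// vector, expands to approximately:
//   HWY_API Vec128<uint8_t, 16> operator+(const Vec128<uint8_t, 16> a,
//                                         const Vec128<uint8_t, 16> b) {
//     return Vec128<uint8_t, 16>(vaddq_u8(a.raw, b.raw));
//   }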
+ +// uint8_t +#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 8, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 4, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 2, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 1, name, prefix, infix, u8, args) + +// int8_t +#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 8, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 4, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 2, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 1, name, prefix, infix, s8, args) + +// uint16_t +#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 4, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 2, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 1, name, prefix, infix, u16, args) + +// int16_t +#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 4, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 2, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 1, name, prefix, infix, s16, args) + +// uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 2, name, prefix, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 1, name, prefix, infix, u32, args) + +// int32_t +#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 2, name, prefix, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 1, name, prefix, infix, s32, args) + +// uint64_t +#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) + +// int64_t +#define HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) + +// float +#define HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float32, 4, name, prefix##q, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 2, name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 1, name, prefix, infix, f32, args) + +// double +#if HWY_ARCH_ARM_A64 +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args) \ + HWY_NEON_DEF_FUNCTION(float64, 1, name, prefix, infix, f64, args) +#else +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) +#endif + +// float and double + +#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) + +// Helper macros to define for more than one type. 
+// uint8_t, uint16_t and uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) + +// int8_t, int16_t and int32_t +#define HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) + +// uint8_t, uint16_t, uint32_t and uint64_t +#define HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) + +// int8_t, int16_t, int32_t and int64_t +#define HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) + +// All int*_t and uint*_t up to 64 +#define HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) + +// All previous types. +#define HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) + +#define HWY_NEON_DEF_FUNCTION_UIF81632(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) + +// For eor3q, which is only defined for full vectors. +#define HWY_NEON_DEF_FUNCTION_FULL_UI(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) + +// Emulation of some intrinsics on armv7. 
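The defines that follow fill in the AArch64-only vuzp1/vuzp2/vzip1/vzip2 names on ARMv7, where only the combined vuzp_*/vzip_* intrinsics (which return an x2 struct) are available. A small sketch of the lane semantics; the helper name is hypothetical and only illustrative.

// For a = {a0,...,a7} and b = {b0,...,b7}:
//   vuzp_u8(a, b).val[0] == {a0,a2,a4,a6, b0,b2,b4,b6}  (what vuzp1_u8 yields)
//   vuzp_u8(a, b).val[1] == {a1,a3,a5,a7, b1,b3,b5,b7}  (what vuzp2_u8 yields)
//   vzip_u8(a, b).val[0] == {a0,b0,a1,b1, a2,b2,a3,b3}  (what vzip1_u8 yields)
inline uint8x8_t ConcatEvenU8Sketch(uint8x8_t a, uint8x8_t b) {
  return vuzp_u8(a, b).val[0];  // even-indexed lanes of a, then of b
}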
+#if HWY_ARCH_ARM_V7 +#define vuzp1_s8(x, y) vuzp_s8(x, y).val[0] +#define vuzp1_u8(x, y) vuzp_u8(x, y).val[0] +#define vuzp1_s16(x, y) vuzp_s16(x, y).val[0] +#define vuzp1_u16(x, y) vuzp_u16(x, y).val[0] +#define vuzp1_s32(x, y) vuzp_s32(x, y).val[0] +#define vuzp1_u32(x, y) vuzp_u32(x, y).val[0] +#define vuzp1_f32(x, y) vuzp_f32(x, y).val[0] +#define vuzp1q_s8(x, y) vuzpq_s8(x, y).val[0] +#define vuzp1q_u8(x, y) vuzpq_u8(x, y).val[0] +#define vuzp1q_s16(x, y) vuzpq_s16(x, y).val[0] +#define vuzp1q_u16(x, y) vuzpq_u16(x, y).val[0] +#define vuzp1q_s32(x, y) vuzpq_s32(x, y).val[0] +#define vuzp1q_u32(x, y) vuzpq_u32(x, y).val[0] +#define vuzp1q_f32(x, y) vuzpq_f32(x, y).val[0] +#define vuzp2_s8(x, y) vuzp_s8(x, y).val[1] +#define vuzp2_u8(x, y) vuzp_u8(x, y).val[1] +#define vuzp2_s16(x, y) vuzp_s16(x, y).val[1] +#define vuzp2_u16(x, y) vuzp_u16(x, y).val[1] +#define vuzp2_s32(x, y) vuzp_s32(x, y).val[1] +#define vuzp2_u32(x, y) vuzp_u32(x, y).val[1] +#define vuzp2_f32(x, y) vuzp_f32(x, y).val[1] +#define vuzp2q_s8(x, y) vuzpq_s8(x, y).val[1] +#define vuzp2q_u8(x, y) vuzpq_u8(x, y).val[1] +#define vuzp2q_s16(x, y) vuzpq_s16(x, y).val[1] +#define vuzp2q_u16(x, y) vuzpq_u16(x, y).val[1] +#define vuzp2q_s32(x, y) vuzpq_s32(x, y).val[1] +#define vuzp2q_u32(x, y) vuzpq_u32(x, y).val[1] +#define vuzp2q_f32(x, y) vuzpq_f32(x, y).val[1] +#define vzip1_s8(x, y) vzip_s8(x, y).val[0] +#define vzip1_u8(x, y) vzip_u8(x, y).val[0] +#define vzip1_s16(x, y) vzip_s16(x, y).val[0] +#define vzip1_u16(x, y) vzip_u16(x, y).val[0] +#define vzip1_f32(x, y) vzip_f32(x, y).val[0] +#define vzip1_u32(x, y) vzip_u32(x, y).val[0] +#define vzip1_s32(x, y) vzip_s32(x, y).val[0] +#define vzip1q_s8(x, y) vzipq_s8(x, y).val[0] +#define vzip1q_u8(x, y) vzipq_u8(x, y).val[0] +#define vzip1q_s16(x, y) vzipq_s16(x, y).val[0] +#define vzip1q_u16(x, y) vzipq_u16(x, y).val[0] +#define vzip1q_s32(x, y) vzipq_s32(x, y).val[0] +#define vzip1q_u32(x, y) vzipq_u32(x, y).val[0] +#define vzip1q_f32(x, y) vzipq_f32(x, y).val[0] +#define vzip2_s8(x, y) vzip_s8(x, y).val[1] +#define vzip2_u8(x, y) vzip_u8(x, y).val[1] +#define vzip2_s16(x, y) vzip_s16(x, y).val[1] +#define vzip2_u16(x, y) vzip_u16(x, y).val[1] +#define vzip2_s32(x, y) vzip_s32(x, y).val[1] +#define vzip2_u32(x, y) vzip_u32(x, y).val[1] +#define vzip2_f32(x, y) vzip_f32(x, y).val[1] +#define vzip2q_s8(x, y) vzipq_s8(x, y).val[1] +#define vzip2q_u8(x, y) vzipq_u8(x, y).val[1] +#define vzip2q_s16(x, y) vzipq_s16(x, y).val[1] +#define vzip2q_u16(x, y) vzipq_u16(x, y).val[1] +#define vzip2q_s32(x, y) vzipq_s32(x, y).val[1] +#define vzip2q_u32(x, y) vzipq_u32(x, y).val[1] +#define vzip2q_f32(x, y) vzipq_f32(x, y).val[1] +#endif + +// Wrappers over uint8x16x2_t etc. so we can define StoreInterleaved2 overloads +// for all vector types, even those (bfloat16_t) where the underlying vector is +// the same as others (uint16_t). 
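The specializations below lose their template arguments in this rendering; per the upstream Highway sources the pattern is approximately the following, which is what lets bfloat16_t and uint16_t share a raw NEON type while remaining distinct overload candidates.

// Approximate upstream form (full-width uint16_t vs. bfloat16_t):
//   template <typename T, size_t N> struct Tuple2;
//   template <> struct Tuple2<uint16_t, 8>   { uint16x8x2_t raw; };
//   template <> struct Tuple2<bfloat16_t, 8> { uint16x8x2_t raw; };  // same raw, distinct type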
+template +struct Tuple2; +template +struct Tuple3; +template +struct Tuple4; + +template <> +struct Tuple2 { + uint8x16x2_t raw; +}; +template +struct Tuple2 { + uint8x8x2_t raw; +}; +template <> +struct Tuple2 { + int8x16x2_t raw; +}; +template +struct Tuple2 { + int8x8x2_t raw; +}; +template <> +struct Tuple2 { + uint16x8x2_t raw; +}; +template +struct Tuple2 { + uint16x4x2_t raw; +}; +template <> +struct Tuple2 { + int16x8x2_t raw; +}; +template +struct Tuple2 { + int16x4x2_t raw; +}; +template <> +struct Tuple2 { + uint32x4x2_t raw; +}; +template +struct Tuple2 { + uint32x2x2_t raw; +}; +template <> +struct Tuple2 { + int32x4x2_t raw; +}; +template +struct Tuple2 { + int32x2x2_t raw; +}; +template <> +struct Tuple2 { + uint64x2x2_t raw; +}; +template +struct Tuple2 { + uint64x1x2_t raw; +}; +template <> +struct Tuple2 { + int64x2x2_t raw; +}; +template +struct Tuple2 { + int64x1x2_t raw; +}; + +template <> +struct Tuple2 { + uint16x8x2_t raw; +}; +template +struct Tuple2 { + uint16x4x2_t raw; +}; +template <> +struct Tuple2 { + uint16x8x2_t raw; +}; +template +struct Tuple2 { + uint16x4x2_t raw; +}; + +template <> +struct Tuple2 { + float32x4x2_t raw; +}; +template +struct Tuple2 { + float32x2x2_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple2 { + float64x2x2_t raw; +}; +template +struct Tuple2 { + float64x1x2_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template <> +struct Tuple3 { + uint8x16x3_t raw; +}; +template +struct Tuple3 { + uint8x8x3_t raw; +}; +template <> +struct Tuple3 { + int8x16x3_t raw; +}; +template +struct Tuple3 { + int8x8x3_t raw; +}; +template <> +struct Tuple3 { + uint16x8x3_t raw; +}; +template +struct Tuple3 { + uint16x4x3_t raw; +}; +template <> +struct Tuple3 { + int16x8x3_t raw; +}; +template +struct Tuple3 { + int16x4x3_t raw; +}; +template <> +struct Tuple3 { + uint32x4x3_t raw; +}; +template +struct Tuple3 { + uint32x2x3_t raw; +}; +template <> +struct Tuple3 { + int32x4x3_t raw; +}; +template +struct Tuple3 { + int32x2x3_t raw; +}; +template <> +struct Tuple3 { + uint64x2x3_t raw; +}; +template +struct Tuple3 { + uint64x1x3_t raw; +}; +template <> +struct Tuple3 { + int64x2x3_t raw; +}; +template +struct Tuple3 { + int64x1x3_t raw; +}; + +template <> +struct Tuple3 { + uint16x8x3_t raw; +}; +template +struct Tuple3 { + uint16x4x3_t raw; +}; +template <> +struct Tuple3 { + uint16x8x3_t raw; +}; +template +struct Tuple3 { + uint16x4x3_t raw; +}; + +template <> +struct Tuple3 { + float32x4x3_t raw; +}; +template +struct Tuple3 { + float32x2x3_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple3 { + float64x2x3_t raw; +}; +template +struct Tuple3 { + float64x1x3_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template <> +struct Tuple4 { + uint8x16x4_t raw; +}; +template +struct Tuple4 { + uint8x8x4_t raw; +}; +template <> +struct Tuple4 { + int8x16x4_t raw; +}; +template +struct Tuple4 { + int8x8x4_t raw; +}; +template <> +struct Tuple4 { + uint16x8x4_t raw; +}; +template +struct Tuple4 { + uint16x4x4_t raw; +}; +template <> +struct Tuple4 { + int16x8x4_t raw; +}; +template +struct Tuple4 { + int16x4x4_t raw; +}; +template <> +struct Tuple4 { + uint32x4x4_t raw; +}; +template +struct Tuple4 { + uint32x2x4_t raw; +}; +template <> +struct Tuple4 { + int32x4x4_t raw; +}; +template +struct Tuple4 { + int32x2x4_t raw; +}; +template <> +struct Tuple4 { + uint64x2x4_t raw; +}; +template +struct Tuple4 { + uint64x1x4_t raw; +}; +template <> +struct Tuple4 { + int64x2x4_t raw; +}; +template +struct Tuple4 { + int64x1x4_t raw; +}; + +template <> +struct 
Tuple4 { + uint16x8x4_t raw; +}; +template +struct Tuple4 { + uint16x4x4_t raw; +}; +template <> +struct Tuple4 { + uint16x8x4_t raw; +}; +template +struct Tuple4 { + uint16x4x4_t raw; +}; + +template <> +struct Tuple4 { + float32x4x4_t raw; +}; +template +struct Tuple4 { + float32x2x4_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple4 { + float64x2x4_t raw; +}; +template +struct Tuple4 { + float64x1x4_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template +struct Raw128; + +// 128 +template <> +struct Raw128 { + using type = uint8x16_t; +}; + +template <> +struct Raw128 { + using type = uint16x8_t; +}; + +template <> +struct Raw128 { + using type = uint32x4_t; +}; + +template <> +struct Raw128 { + using type = uint64x2_t; +}; + +template <> +struct Raw128 { + using type = int8x16_t; +}; + +template <> +struct Raw128 { + using type = int16x8_t; +}; + +template <> +struct Raw128 { + using type = int32x4_t; +}; + +template <> +struct Raw128 { + using type = int64x2_t; +}; + +template <> +struct Raw128 { + using type = uint16x8_t; +}; + +template <> +struct Raw128 { + using type = uint16x8_t; +}; + +template <> +struct Raw128 { + using type = float32x4_t; +}; + +#if HWY_ARCH_ARM_A64 +template <> +struct Raw128 { + using type = float64x2_t; +}; +#endif + +// 64 +template <> +struct Raw128 { + using type = uint8x8_t; +}; + +template <> +struct Raw128 { + using type = uint16x4_t; +}; + +template <> +struct Raw128 { + using type = uint32x2_t; +}; + +template <> +struct Raw128 { + using type = uint64x1_t; +}; + +template <> +struct Raw128 { + using type = int8x8_t; +}; + +template <> +struct Raw128 { + using type = int16x4_t; +}; + +template <> +struct Raw128 { + using type = int32x2_t; +}; + +template <> +struct Raw128 { + using type = int64x1_t; +}; + +template <> +struct Raw128 { + using type = uint16x4_t; +}; + +template <> +struct Raw128 { + using type = uint16x4_t; +}; + +template <> +struct Raw128 { + using type = float32x2_t; +}; + +#if HWY_ARCH_ARM_A64 +template <> +struct Raw128 { + using type = float64x1_t; +}; +#endif + +// 32 (same as 64) +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +// 16 (same as 64) +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +// 8 (same as 64) +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + HWY_INLINE Vec128() {} + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + HWY_INLINE explicit Vec128(const Raw raw) : raw(raw) {} + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +// FF..FF or 0. +template +class Mask128 { + // ARM C Language Extensions return and expect unsigned type. + using Raw = typename detail::Raw128, N>::type; + + public: + HWY_INLINE Mask128() {} + Mask128(const Mask128&) = default; + Mask128& operator=(const Mask128&) = default; + HWY_INLINE explicit Mask128(const Raw raw) : raw(raw) {} + + Raw raw; +}; + +template +using Mask64 = Mask128; + +template +using DFromV = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +namespace detail { + +// Converts from Vec128 to Vec128 using the +// vreinterpret*_u8_*() set of functions. +#define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \ + Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128 v +#define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw + +// Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined. +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return v; +} + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) + +// Special cases for [b]float16_t, which have the same Raw as uint16_t. 
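The two special-case overloads that follow also lose their template arguments in this rendering; per the upstream Highway sources they are approximately:

// Approximate upstream form:
//   template <size_t N>
//   HWY_INLINE Vec128<uint8_t, N * 2> BitCastToByte(Vec128<float16_t, N> v) {
//     return BitCastToByte(Vec128<uint16_t, N>(v.raw));  // same uint16x*_t raw
//   }
//   // ...and the same pattern for bfloat16_t.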
+template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} + +#undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return v; +} + +// 64-bit or less: + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s8_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_u16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_u32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_f32_u8(v.raw)); +} +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_u64_u8(v.raw)); +} +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_s64_u8(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_f64_u8(v.raw)); +} +#endif + +// 128-bit full: + +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s8_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u16_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s16_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_f32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u64_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s64_u8(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_f64_u8(v.raw)); +} +#endif + +// Special cases for [b]float16_t, which have the same Raw as uint16_t. +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(BitCastFromByte(Simd(), v).raw); +} +template +HWY_INLINE Vec128 BitCastFromByte( + Simd /* tag */, Vec128 v) { + return Vec128(BitCastFromByte(Simd(), v).raw); +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns a vector with all lanes set to "t". 
+#define HWY_NEON_BUILD_TPL_HWY_SET1 +#define HWY_NEON_BUILD_RET_HWY_SET1(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_SET1(type, size) \ + Simd /* tag */, const type##_t t +#define HWY_NEON_BUILD_ARG_HWY_SET1 t + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(Set, vdup, _n_, HWY_SET1) + +#undef HWY_NEON_BUILD_TPL_HWY_SET1 +#undef HWY_NEON_BUILD_RET_HWY_SET1 +#undef HWY_NEON_BUILD_PARAM_HWY_SET1 +#undef HWY_NEON_BUILD_ARG_HWY_SET1 + +// Returns an all-zero vector. +template +HWY_API Vec128 Zero(Simd d) { + return Set(d, 0); +} + +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128(Zero(Simd()).raw); +} + +template +using VFromD = decltype(Zero(D())); + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +// Returns a vector with uninitialized elements. +template +HWY_API Vec128 Undefined(Simd /*d*/) { + typename detail::Raw128::type a; + return Vec128(a); +} + +HWY_DIAGNOSTICS(pop) + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ GetLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_GET template +#define HWY_NEON_BUILD_RET_HWY_GET(type, size) type##_t +#define HWY_NEON_BUILD_PARAM_HWY_GET(type, size) Vec128 v +#define HWY_NEON_BUILD_ARG_HWY_GET v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(GetLane, vget, _lane_, HWY_GET) + +#undef HWY_NEON_BUILD_TPL_HWY_GET +#undef HWY_NEON_BUILD_RET_HWY_GET +#undef HWY_NEON_BUILD_PARAM_HWY_GET +#undef HWY_NEON_BUILD_ARG_HWY_GET + +} // namespace detail + +template +HWY_API TFromV GetLane(const V v) { + return detail::GetLane<0>(v); +} + +// ------------------------------ ExtractLane + +// Requires one overload per vector length because GetLane<3> is a compile error +// if v is a uint32x2_t. 
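A small usage sketch of the Set/Zero/Iota/GetLane ops defined above; the function name is hypothetical and only illustrative.

// For d = Full128<int32_t>(), Iota(d, 5) holds {5, 6, 7, 8} and GetLane
// returns lane 0; ExtractLane below accepts a runtime lane index instead.
template <class D>
inline TFromD<D> FirstIotaLaneSketch(D d) {
  return GetLane(Iota(d, 5));  // == 5
}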
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return detail::GetLane<0>(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + case 8: + return detail::GetLane<8>(v); + case 9: + return detail::GetLane<9>(v); + case 10: + return detail::GetLane<10>(v); + case 11: + return detail::GetLane<11>(v); + case 12: + return detail::GetLane<12>(v); + case 13: + return detail::GetLane<13>(v); + case 14: + return detail::GetLane<14>(v); + case 15: + return detail::GetLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_INSERT template +#define HWY_NEON_BUILD_RET_HWY_INSERT(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_INSERT(type, size) \ + Vec128 v, type##_t t +#define HWY_NEON_BUILD_ARG_HWY_INSERT t, v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(InsertLane, vset, _lane_, HWY_INSERT) + +#undef HWY_NEON_BUILD_TPL_HWY_INSERT +#undef HWY_NEON_BUILD_RET_HWY_INSERT +#undef HWY_NEON_BUILD_PARAM_HWY_INSERT +#undef HWY_NEON_BUILD_ARG_HWY_INSERT + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error. 
+ +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2) + +// ------------------------------ Subtraction +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator-, vsub, _, 2) + +// ------------------------------ SumsOf8 + +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v.raw)))); +} +HWY_API Vec64 SumsOf8(const Vec64 v) { + return Vec64(vpaddl_u32(vpaddl_u16(vpaddl_u8(v.raw)))); +} + +// ------------------------------ SaturatedAdd +// Only defined for uint8_t, uint16_t and their signed versions, as in other +// architectures. 
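A scalar model of the saturating semantics the vqadd_*/vqsub_* wrappers below expose: instead of wrapping modulo 2^N, results are clamped to the lane's representable range. The helper name is hypothetical.

// Illustrative scalar equivalent for uint8_t lanes:
inline uint8_t SaturatedAddU8Sketch(uint8_t a, uint8_t b) {
  const unsigned sum = static_cast<unsigned>(a) + b;
  return static_cast<uint8_t>(sum > 255 ? 255 : sum);  // e.g. 200 + 100 -> 255
}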
+ +// Returns a + b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INT_8(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_INT_16(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedAdd, vqadd, _, 2) + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INT_8(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_16(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedSub, vqsub, _, 2) + +// Not part of API, used in implementation. +namespace detail { +HWY_NEON_DEF_FUNCTION_UINT_32(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_64(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_32(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_64(SaturatedSub, vqsub, _, 2) +} // namespace detail + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 +HWY_NEON_DEF_FUNCTION_UINT_8(AverageRound, vrhadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(AverageRound, vrhadd, _, 2) + +// ------------------------------ Neg + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Neg, vneg, _, 1) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Neg, vneg, _, 1) // i64 implemented below + +HWY_API Vec64 Neg(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vneg_s64(v.raw)); +#else + return Zero(Full64()) - v; +#endif +} + +HWY_API Vec128 Neg(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vnegq_s64(v.raw)); +#else + return Zero(Full128()) - v; +#endif +} + +// ------------------------------ ShiftLeft + +// Customize HWY_NEON_DEF_FUNCTION to special-case count=0 (not supported). +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + template \ + HWY_API Vec128 name(const Vec128 v) { \ + return kBits == 0 ? v \ + : Vec128(HWY_NEON_EVAL( \ + prefix##infix##suffix, v.raw, HWY_MAX(1, kBits))); \ + } + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(ShiftLeft, vshl, _n_, ignored) + +HWY_NEON_DEF_FUNCTION_UINTS(ShiftRight, vshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_INTS(ShiftRight, vshr, _n_, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ RotateRight (ShiftRight, Or) + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// NOTE: vxarq_u64 can be applied to uint64_t, but we do not yet have a +// mechanism for checking for extensions to ARMv8. 
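A worked example of the RotateRight composition above, which ORs a right shift with the complementary left shift. The helper name is hypothetical.

// Scalar model for a 32-bit lane:
inline uint32_t RotateRight8Sketch(uint32_t x) {
  return (x >> 8) | (x << 24);  // 0x12345678 -> 0x78123456
}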
+ +// ------------------------------ Shl + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u8(v.raw, vreinterpretq_s8_u8(bits.raw))); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u8(v.raw, vreinterpret_s8_u8(bits.raw))); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u16(v.raw, vreinterpretq_s16_u16(bits.raw))); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u16(v.raw, vreinterpret_s16_u16(bits.raw))); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u32(v.raw, vreinterpretq_s32_u32(bits.raw))); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u32(v.raw, vreinterpret_s32_u32(bits.raw))); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u64(v.raw, vreinterpretq_s64_u64(bits.raw))); +} +HWY_API Vec64 operator<<(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s8(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s8(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s16(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s16(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s32(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s32(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s64(v.raw, bits.raw)); +} +HWY_API Vec64 operator<<(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_s64(v.raw, bits.raw)); +} + +// ------------------------------ Shr (Neg) + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int8x16_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u8(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int8x8_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u8(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int16x8_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u16(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int16x4_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u16(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int32x4_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u32(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int32x2_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u32(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int64x2_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u64(v.raw, neg_bits)); +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + const int64x1_t neg_bits = Neg(BitCast(Full64(), bits)).raw; + return 
Vec64(vshl_u64(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s8(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s16(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s32(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_s64(v.raw, Neg(bits).raw)); +} + +// ------------------------------ ShiftLeftSame (Shl) + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, int bits) { + return v << Set(Simd(), static_cast(bits)); +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, int bits) { + return v >> Set(Simd(), static_cast(bits)); +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_u16(a.raw, b.raw)); +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_u32(a.raw, b.raw)); +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_u16(a.raw, b.raw)); +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_u32(a.raw, b.raw)); +} + +// Signed +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_s16(a.raw, b.raw)); +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_s32(a.raw, b.raw)); +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_s16(a.raw, b.raw)); +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_s32(a.raw, b.raw)); +} + +// Returns the upper 16 bits of a * b in each lane. 
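A scalar model of MulHigh: the NEON implementations below widen with vmull_* and then keep the upper 16-bit half of each 32-bit product via uzp2. The helper names are hypothetical.

// Illustrative scalar equivalents:
inline int16_t MulHighI16Sketch(int16_t a, int16_t b) {
  return static_cast<int16_t>((static_cast<int32_t>(a) * b) >> 16);
}
inline uint16_t MulHighU16Sketch(uint16_t a, uint16_t b) {
  return static_cast<uint16_t>((static_cast<uint32_t>(a) * b) >> 16);
}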
+HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + int32x4_t rlo = vmull_s16(vget_low_s16(a.raw), vget_low_s16(b.raw)); +#if HWY_ARCH_ARM_A64 + int32x4_t rhi = vmull_high_s16(a.raw, b.raw); +#else + int32x4_t rhi = vmull_s16(vget_high_s16(a.raw), vget_high_s16(b.raw)); +#endif + return Vec128( + vuzp2q_s16(vreinterpretq_s16_s32(rlo), vreinterpretq_s16_s32(rhi))); +} +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + uint32x4_t rlo = vmull_u16(vget_low_u16(a.raw), vget_low_u16(b.raw)); +#if HWY_ARCH_ARM_A64 + uint32x4_t rhi = vmull_high_u16(a.raw, b.raw); +#else + uint32x4_t rhi = vmull_u16(vget_high_u16(a.raw), vget_high_u16(b.raw)); +#endif + return Vec128( + vuzp2q_u16(vreinterpretq_u16_u32(rlo), vreinterpretq_u16_u32(rhi))); +} + +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + int16x8_t hi_lo = vreinterpretq_s16_s32(vmull_s16(a.raw, b.raw)); + return Vec128(vget_low_s16(vuzp2q_s16(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + uint16x8_t hi_lo = vreinterpretq_u16_u32(vmull_u16(a.raw, b.raw)); + return Vec128(vget_low_u16(vuzp2q_u16(hi_lo, hi_lo))); +} + +HWY_API Vec128 MulFixedPoint15(Vec128 a, Vec128 b) { + return Vec128(vqrdmulhq_s16(a.raw, b.raw)); +} +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + return Vec128(vqrdmulh_s16(a.raw, b.raw)); +} + +// ------------------------------ Floating-point mul / div + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2) + +// Approximate reciprocal +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128(vrecpeq_f32(v.raw)); +} +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128(vrecpe_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator/, vdiv, _, 2) +#else +// Not defined on armv7: approximate +namespace detail { + +HWY_INLINE Vec128 ReciprocalNewtonRaphsonStep( + const Vec128 recip, const Vec128 divisor) { + return Vec128(vrecpsq_f32(recip.raw, divisor.raw)); +} +template +HWY_INLINE Vec128 ReciprocalNewtonRaphsonStep( + const Vec128 recip, Vec128 divisor) { + return Vec128(vrecps_f32(recip.raw, divisor.raw)); +} + +} // namespace detail + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + auto x = ApproximateReciprocal(b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + return a * x; +} +#endif + +// ------------------------------ Absolute value of difference. + +HWY_API Vec128 AbsDiff(const Vec128 a, const Vec128 b) { + return Vec128(vabdq_f32(a.raw, b.raw)); +} +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Vec128(vabd_f32(a.raw, b.raw)); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns add + mul * x +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfma_f32(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, + const Vec128 add) { + return Vec128(vfmaq_f32(add.raw, mul.raw, x.raw)); +} +#else +// Emulate FMA for floats. 
+template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return mul * x + add; +} +#endif + +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 MulAdd(const Vec64 mul, const Vec64 x, + const Vec64 add) { + return Vec64(vfma_f64(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, + const Vec128 add) { + return Vec128(vfmaq_f64(add.raw, mul.raw, x.raw)); +} +#endif + +// Returns add - mul * x +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfms_f32(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128 NegMulAdd(const Vec128 mul, const Vec128 x, + const Vec128 add) { + return Vec128(vfmsq_f32(add.raw, mul.raw, x.raw)); +} +#else +// Emulate FMA for floats. +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return add - mul * x; +} +#endif + +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 NegMulAdd(const Vec64 mul, const Vec64 x, + const Vec64 add) { + return Vec64(vfms_f64(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfmsq_f64(add.raw, mul.raw, x.raw)); +} +#endif + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return MulAdd(mul, x, Neg(sub)); +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return Neg(MulAdd(mul, x, sub)); +} + +#if HWY_ARCH_ARM_A64 +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return MulAdd(mul, x, Neg(sub)); +} +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return Neg(MulAdd(mul, x, sub)); +} +#endif + +// ------------------------------ Floating-point square root (IfThenZeroElse) + +// Approximate reciprocal square root +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128(vrsqrteq_f32(v.raw)); +} +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128(vrsqrte_f32(v.raw)); +} + +// Full precision square root +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Sqrt, vsqrt, _, 1) +#else +namespace detail { + +HWY_INLINE Vec128 ReciprocalSqrtStep(const Vec128 root, + const Vec128 recip) { + return Vec128(vrsqrtsq_f32(root.raw, recip.raw)); +} +template +HWY_INLINE Vec128 ReciprocalSqrtStep(const Vec128 root, + Vec128 recip) { + return Vec128(vrsqrts_f32(root.raw, recip.raw)); +} + +} // namespace detail + +// Not defined on armv7: approximate +template +HWY_API Vec128 Sqrt(const Vec128 v) { + auto recip = ApproximateReciprocalSqrt(v); + + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + + const auto root = v * recip; + return IfThenZeroElse(v == Zero(Simd()), root); +} +#endif + +// ================================================== LOGICAL + +// ------------------------------ Not + +// There is no 64-bit vmvn, so cast instead of using HWY_NEON_DEF_FUNCTION. 
+template +HWY_API Vec128 Not(const Vec128 v) { + const Full128 d; + const Repartition d8; + return BitCast(d, Vec128(vmvnq_u8(BitCast(d8, v).raw))); +} +template +HWY_API Vec128 Not(const Vec128 v) { + const Simd d; + const Repartition d8; + using V8 = decltype(Zero(d8)); + return BitCast(d, V8(vmvn_u8(BitCast(d8, v).raw))); +} + +// ------------------------------ And +HWY_NEON_DEF_FUNCTION_INTS_UINTS(And, vand, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 And(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) & BitCast(du, b)); +} + +// ------------------------------ AndNot + +namespace detail { +// reversed_andnot returns a & ~b. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(reversed_andnot, vbic, _, 2) +} // namespace detail + +// Returns ~not_mask & mask. +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return detail::reversed_andnot(mask, not_mask); +} + +// Uses the u32/64 defined above. +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + const DFromV d; + const RebindToUnsigned du; + VFromD ret = + detail::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask)); + return BitCast(d, ret); +} + +// ------------------------------ Or + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Or, vorr, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 Or(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) | BitCast(du, b)); +} + +// ------------------------------ Xor + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Xor, veor, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 Xor(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) ^ BitCast(du, b)); +} + +// ------------------------------ Xor3 +#if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SHA3) +HWY_NEON_DEF_FUNCTION_FULL_UI(Xor3, veor3, _, 3) + +// Half vectors are not natively supported. Two Xor are likely more efficient +// than Combine to 128-bit. 
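Because XOR is associative, the two-Xor fallback below produces the same bits as a single EOR3 instruction; a scalar illustration with a hypothetical helper name:

// Three-way XOR modeled on scalars:
inline uint32_t Xor3Sketch(uint32_t a, uint32_t b, uint32_t c) {
  return a ^ (b ^ c);  // e.g. 0b1100 ^ (0b1010 ^ 0b0110) == 0b0000
}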
+template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} + +template +HWY_API Vec128 Xor3(const Vec128 x1, const Vec128 x2, + const Vec128 x3) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); +} + +#else +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} +#endif + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, Vec128 v) { + const Full128 d8; + return Vec128(vcntq_u8(BitCast(d8, v).raw)); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + const Simd d8; + return Vec128(vcnt_u8(BitCast(d8, v).raw)); +} + +// ARM lacks popcount for lane sizes > 1, so take pairwise sums of the bytes. +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u8(bytes)); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u8(bytes)); +} + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u16(vpaddlq_u8(bytes))); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u16(vpaddl_u8(bytes))); +} + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(bytes)))); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u32(vpaddl_u16(vpaddl_u8(bytes)))); +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +// ================================================== SIGN + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. 
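A scalar illustration of the edge case noted above: for int8_t, +128 is not representable, so the absolute value of the minimum wraps back to itself. The helper name is hypothetical.

// Matches vabs_s8 semantics: Abs(-128) stays -128.
inline int8_t AbsI8Sketch(int8_t x) {
  return (x == -128) ? static_cast<int8_t>(-128)
                     : static_cast<int8_t>(x < 0 ? -x : x);
}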
+HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s8(v.raw)); +} +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s16(v.raw)); +} +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s32(v.raw)); +} +// i64 is implemented after BroadcastSignBit. +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_f32(v.raw)); +} + +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s8(v.raw)); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s16(v.raw)); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s32(v.raw)); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabs_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_f64(v.raw)); +} + +HWY_API Vec64 Abs(const Vec64 v) { + return Vec64(vabs_f64(v.raw)); +} +#endif + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Simd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Simd()), sign)); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} + +// ================================================== MASK + +// ------------------------------ To/from vector + +// Mask and Vec have the same representation (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const Simd, N, 0> du; + return Mask128(BitCast(du, v).raw); +} + +template +HWY_API Vec128 VecFromMask(Simd d, const Mask128 v) { + return BitCast(d, Vec128, N>(v.raw)); +} + +// ------------------------------ RebindMask + +template +HWY_API Mask128 RebindMask(Simd dto, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return MaskFromVec(BitCast(dto, VecFromMask(Simd(), m))); +} + +// ------------------------------ IfThenElse(mask, yes, no) = mask ? b : a. + +#define HWY_NEON_BUILD_TPL_HWY_IF +#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ + const Mask128 mask, const Vec128 yes, \ + const Vec128 no +#define HWY_NEON_BUILD_ARG_HWY_IF mask.raw, yes.raw, no.raw + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) + +#undef HWY_NEON_BUILD_TPL_HWY_IF +#undef HWY_NEON_BUILD_RET_HWY_IF +#undef HWY_NEON_BUILD_PARAM_HWY_IF +#undef HWY_NEON_BUILD_ARG_HWY_IF + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(const Mask128 mask, + const Vec128 yes) { + return yes & VecFromMask(Simd(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(const Mask128 mask, + const Vec128 no) { + return AndNot(VecFromMask(Simd(), mask), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Simd d; + const RebindToSigned di; + + Mask128 m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + const auto zero = Zero(Simd()); + return Max(zero, v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +// ------------------------------ Shuffle2301 (for i64 compares) + +// Swap 32-bit halves in 64-bits +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_u32(v.raw)); +} +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_s32(v.raw)); +} +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_f32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_u32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_s32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_f32(v.raw)); +} + +#define HWY_NEON_BUILD_TPL_HWY_COMPARE +#define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_COMPARE(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_ARG_HWY_COMPARE a.raw, b.raw + +// ------------------------------ Equality +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator==, vceq, _, HWY_COMPARE) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator==, vceq, _, HWY_COMPARE) +#else +// No 64-bit comparisons on armv7: emulate them below, after Shuffle2301. 
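Before the 64-bit emulation itself appears further below, a scalar model of the two ARMv7 tricks it relies on (helper names here are illustrative, not Highway API; the signed path uses the compiler's __int128 extension to model saturating subtraction):

#include <cstdint>

// 64-bit equality from two 32-bit equalities; the vector version builds the
// same result by ANDing the 32-bit compare mask with its swapped halves.
bool Eq64ViaHalves(uint64_t a, uint64_t b) {
  const bool lo = static_cast<uint32_t>(a) == static_cast<uint32_t>(b);
  const bool hi = static_cast<uint32_t>(a >> 32) == static_cast<uint32_t>(b >> 32);
  return lo && hi;
}

// Signed a < b from the sign of the saturating difference a - b; saturation
// keeps the sign correct even when the plain subtraction would overflow.
bool Lt64ViaSaturatedSub(int64_t a, int64_t b) {
  __int128 diff = static_cast<__int128>(a) - b;  // compiler extension
  if (diff > INT64_MAX) diff = INT64_MAX;
  if (diff < INT64_MIN) diff = INT64_MIN;
  return static_cast<int64_t>(diff) < 0;
}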
+HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator==, vceq, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator==, vceq, _, HWY_COMPARE) +#endif + +// ------------------------------ Strict inequality (signed, float) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator<, vclt, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator<, vclt, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<, vclt, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<, vclt, _, HWY_COMPARE) + +// ------------------------------ Weak inequality (float) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<=, vcle, _, HWY_COMPARE) + +#undef HWY_NEON_BUILD_TPL_HWY_COMPARE +#undef HWY_NEON_BUILD_RET_HWY_COMPARE +#undef HWY_NEON_BUILD_PARAM_HWY_COMPARE +#undef HWY_NEON_BUILD_ARG_HWY_COMPARE + +// ------------------------------ ARMv7 i64 compare (Shuffle2301, Eq) + +#if HWY_ARCH_ARM_V7 + +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + const int64x2_t sub = vqsubq_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec128(sub))); +} +HWY_API Mask128 operator<(const Vec64 a, + const Vec64 b) { + const int64x1_t sub = vqsub_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec64(sub))); +} + +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = AndNot(a, b) | AndNot(a ^ b, a - b); + return MaskFromVec(BitCast(du, BroadcastSignBit(BitCast(di, msb)))); +} + +#endif + +// ------------------------------ operator!= (operator==) + +// Customize HWY_NEON_DEF_FUNCTION to call 2 functions. +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +// This cannot have _any_ template argument (in x86_128 we can at least have N +// as an argument), otherwise it is not more specialized than rewritten +// operator== in C++20, leading to compile errors. +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_API Mask128 name(Vec128 a, \ + Vec128 b) { \ + return Not(a == b); \ + } + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator!=, ignored, ignored, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return operator<(b, a); +} +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return operator<=(b, a); +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask128 FirstN(const Simd d, size_t num) { + const RebindToSigned di; // Signed comparisons are cheaper. 
+ return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +} + +// ------------------------------ TestBit (Eq) + +#define HWY_NEON_BUILD_TPL_HWY_TESTBIT +#define HWY_NEON_BUILD_RET_HWY_TESTBIT(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_TESTBIT(type, size) \ + Vec128 v, Vec128 bit +#define HWY_NEON_BUILD_ARG_HWY_TESTBIT v.raw, bit.raw + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(TestBit, vtst, _, HWY_TESTBIT) +#else +// No 64-bit versions on armv7 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) + +template +HWY_API Mask128 TestBit(Vec128 v, + Vec128 bit) { + return (v & bit) == bit; +} +template +HWY_API Mask128 TestBit(Vec128 v, + Vec128 bit) { + return (v & bit) == bit; +} + +#endif +#undef HWY_NEON_BUILD_TPL_HWY_TESTBIT +#undef HWY_NEON_BUILD_RET_HWY_TESTBIT +#undef HWY_NEON_BUILD_PARAM_HWY_TESTBIT +#undef HWY_NEON_BUILD_ARG_HWY_TESTBIT + +// ------------------------------ Abs i64 (IfThenElse, BroadcastSignBit) +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vabsq_s64(v.raw)); +#else + const auto zero = Zero(Full128()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} +HWY_API Vec64 Abs(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vabs_s64(v.raw)); +#else + const auto zero = Zero(Full64()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +// ------------------------------ Min (IfThenElse, BroadcastSignBit) + +// Unsigned +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Min, vmin, _, 2) + +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const DFromV du; + const RebindToSigned di; + return BitCast(du, BitCast(di, a) - BitCast(di, detail::SaturatedSub(a, b))); +#endif +} + +// Signed +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Min, vmin, _, 2) + +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const Vec128 sign = detail::SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), a, b); +#endif +} + +// Float: IEEE minimumNumber on v8, otherwise NaN if any is NaN. +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vminnm, _, 2) +#else +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vmin, _, 2) +#endif + +// ------------------------------ Max (IfThenElse, BroadcastSignBit) + +// Unsigned (no u64) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Max, vmax, _, 2) + +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const DFromV du; + const RebindToSigned di; + return BitCast(du, BitCast(di, b) + BitCast(di, detail::SaturatedSub(a, b))); +#endif +} + +// Signed (no i64) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Max, vmax, _, 2) + +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const Vec128 sign = detail::SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), b, a); +#endif +} + +// Float: IEEE maximumNumber on v8, otherwise NaN if any is NaN. 
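A scalar model of the distinction the comment above draws between the two Max flavors (illustrative only): the A64 vmaxnm path implements IEEE maximumNumber, which returns the non-NaN operand when exactly one input is NaN, while the ARMv7 vmax path yields NaN whenever either input is NaN.

#include <cmath>

// Models the A64 vmaxnm behavior (IEEE maximumNumber): a single quiet NaN
// input is ignored and the other operand is returned.
float MaxNumberStyle(float a, float b) {
  if (std::isnan(a)) return b;
  if (std::isnan(b)) return a;
  return a > b ? a : b;
}

// Models the ARMv7 vmax fallback: NaN in either input propagates.
float MaxPropagateNaN(float a, float b) {
  if (std::isnan(a) || std::isnan(b)) return NAN;
  return a > b ? a : b;
}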
+#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmaxnm, _, 2) +#else +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmax, _, 2) +#endif + +// ================================================== MEMORY + +// ------------------------------ Load 128 + +HWY_API Vec128 LoadU(Full128 /* tag */, + const uint8_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u8(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const uint16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u16(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const uint32_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u32(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const uint64_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u64(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const int8_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s8(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const int16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s16(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const int32_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s32(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const int64_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s64(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const float* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f32(unaligned)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec128 LoadU(Full128 /* tag */, + const double* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f64(unaligned)); +} +#endif + +// ------------------------------ Load 64 + +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint8_t* HWY_RESTRICT p) { + return Vec64(vld1_u8(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint16_t* HWY_RESTRICT p) { + return Vec64(vld1_u16(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint32_t* HWY_RESTRICT p) { + return Vec64(vld1_u32(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint64_t* HWY_RESTRICT p) { + return Vec64(vld1_u64(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const int8_t* HWY_RESTRICT p) { + return Vec64(vld1_s8(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const int16_t* HWY_RESTRICT p) { + return Vec64(vld1_s16(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const int32_t* HWY_RESTRICT p) { + return Vec64(vld1_s32(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const int64_t* HWY_RESTRICT p) { + return Vec64(vld1_s64(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const float* HWY_RESTRICT p) { + return Vec64(vld1_f32(p)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 LoadU(Full64 /* tag */, + const double* HWY_RESTRICT p) { + return Vec64(vld1_f64(p)); +} +#endif +// ------------------------------ Load 32 + +// Actual 32-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. 
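The loads that follow avoid reinterpret_cast of narrow pointers; for the 1- and 2-byte lane types the bytes are first copied into a correctly typed wider buffer. A scalar model of that pattern (std::memcpy stands in for the CopyBytes helper, which is assumed to behave like memcpy):

#include <cstdint>
#include <cstring>

// Copy the bytes into a properly typed 32-bit buffer instead of
// reinterpret_casting the narrow pointer, which miscompiles on some GCC
// versions as noted above.
uint32_t Load4BytesAsU32(const void* p) {
  uint32_t buf;
  std::memcpy(&buf, p, sizeof(buf));
  return buf;
}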
+HWY_API Vec32 LoadU(Full32 /*tag*/, + const uint32_t* HWY_RESTRICT p) { + return Vec32(vld1_dup_u32(p)); +} +HWY_API Vec32 LoadU(Full32 /*tag*/, + const int32_t* HWY_RESTRICT p) { + return Vec32(vld1_dup_s32(p)); +} +HWY_API Vec32 LoadU(Full32 /*tag*/, const float* HWY_RESTRICT p) { + return Vec32(vld1_dup_f32(p)); +} + +template // 1 or 2 bytes +HWY_API Vec32 LoadU(Full32 d, const T* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} + +// ------------------------------ Load 16 + +// Actual 16-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. +HWY_API Vec128 LoadU(Simd /*tag*/, + const uint16_t* HWY_RESTRICT p) { + return Vec128(vld1_dup_u16(p)); +} +HWY_API Vec128 LoadU(Simd /*tag*/, + const int16_t* HWY_RESTRICT p) { + return Vec128(vld1_dup_s16(p)); +} + +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + const Repartition d16; + uint16_t buf; + CopyBytes<2>(p, &buf); + return BitCast(d, LoadU(d16, &buf)); +} + +// ------------------------------ Load 8 + +HWY_API Vec128 LoadU(Simd, + const uint8_t* HWY_RESTRICT p) { + return Vec128(vld1_dup_u8(p)); +} + +HWY_API Vec128 LoadU(Simd, + const int8_t* HWY_RESTRICT p) { + return Vec128(vld1_dup_s8(p)); +} + +// [b]float16_t use the same Raw as uint16_t, so forward to that. +template +HWY_API Vec128 LoadU(Simd d, + const float16_t* HWY_RESTRICT p) { + const RebindToUnsigned du16; + const auto pu16 = reinterpret_cast(p); + return Vec128(LoadU(du16, pu16).raw); +} +template +HWY_API Vec128 LoadU(Simd d, + const bfloat16_t* HWY_RESTRICT p) { + const RebindToUnsigned du16; + const auto pu16 = reinterpret_cast(p); + return Vec128(LoadU(du16, pu16).raw); +} + +// On ARM, Load is the same as LoadU. +template +HWY_API Vec128 Load(Simd d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. 
+template +HWY_API Vec128 LoadDup128(Simd d, + const T* const HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store 128 + +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + uint8_t* HWY_RESTRICT unaligned) { + vst1q_u8(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + uint16_t* HWY_RESTRICT unaligned) { + vst1q_u16(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + uint32_t* HWY_RESTRICT unaligned) { + vst1q_u32(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + uint64_t* HWY_RESTRICT unaligned) { + vst1q_u64(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + int8_t* HWY_RESTRICT unaligned) { + vst1q_s8(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + int16_t* HWY_RESTRICT unaligned) { + vst1q_s16(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + int32_t* HWY_RESTRICT unaligned) { + vst1q_s32(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + int64_t* HWY_RESTRICT unaligned) { + vst1q_s64(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT unaligned) { + vst1q_f32(unaligned, v.raw); +} +#if HWY_ARCH_ARM_A64 +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT unaligned) { + vst1q_f64(unaligned, v.raw); +} +#endif + +// ------------------------------ Store 64 + +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + uint8_t* HWY_RESTRICT p) { + vst1_u8(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + uint16_t* HWY_RESTRICT p) { + vst1_u16(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + uint32_t* HWY_RESTRICT p) { + vst1_u32(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + uint64_t* HWY_RESTRICT p) { + vst1_u64(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + int8_t* HWY_RESTRICT p) { + vst1_s8(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + int16_t* HWY_RESTRICT p) { + vst1_s16(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + int32_t* HWY_RESTRICT p) { + vst1_s32(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + int64_t* HWY_RESTRICT p) { + vst1_s64(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + float* HWY_RESTRICT p) { + vst1_f32(p, v.raw); +} +#if HWY_ARCH_ARM_A64 +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + double* HWY_RESTRICT p) { + vst1_f64(p, v.raw); +} +#endif + +// ------------------------------ Store 32 + +HWY_API void StoreU(const Vec32 v, Full32, + uint32_t* HWY_RESTRICT p) { + vst1_lane_u32(p, v.raw, 0); +} +HWY_API void StoreU(const Vec32 v, Full32, + int32_t* HWY_RESTRICT p) { + vst1_lane_s32(p, v.raw, 0); +} +HWY_API void StoreU(const Vec32 v, Full32, + float* HWY_RESTRICT p) { + vst1_lane_f32(p, v.raw, 0); +} + +template // 1 or 2 bytes +HWY_API void StoreU(const Vec32 v, Full32 d, T* HWY_RESTRICT p) { + const Repartition d32; + const uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} + +// ------------------------------ Store 16 + +HWY_API void StoreU(const Vec128 v, Simd, + uint16_t* HWY_RESTRICT p) { + vst1_lane_u16(p, v.raw, 0); +} +HWY_API void StoreU(const Vec128 v, Simd, + int16_t* HWY_RESTRICT p) { + vst1_lane_s16(p, v.raw, 0); +} + +template +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { + const Repartition d16; + const uint16_t buf = 
GetLane(BitCast(d16, v)); + CopyBytes<2>(&buf, p); +} + +// ------------------------------ Store 8 + +HWY_API void StoreU(const Vec128 v, Simd, + uint8_t* HWY_RESTRICT p) { + vst1_lane_u8(p, v.raw, 0); +} +HWY_API void StoreU(const Vec128 v, Simd, + int8_t* HWY_RESTRICT p) { + vst1_lane_s8(p, v.raw, 0); +} + +// [b]float16_t use the same Raw as uint16_t, so forward to that. +template +HWY_API void StoreU(Vec128 v, Simd d, + float16_t* HWY_RESTRICT p) { + const RebindToUnsigned du16; + const auto pu16 = reinterpret_cast(p); + return StoreU(Vec128(v.raw), du16, pu16); +} +template +HWY_API void StoreU(Vec128 v, Simd d, + bfloat16_t* HWY_RESTRICT p) { + const RebindToUnsigned du16; + const auto pu16 = reinterpret_cast(p); + return StoreU(Vec128(v.raw), du16, pu16); +} + +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +// On ARM, Store is the same as StoreU. +template +HWY_API void Store(Vec128 v, Simd d, T* HWY_RESTRICT aligned) { + StoreU(v, d, aligned); +} + +HWY_DIAGNOSTICS(pop) + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + // Treat as unsigned so that we correctly support float16. + const RebindToUnsigned du; + const auto blended = + IfThenElse(RebindMask(du, m), BitCast(du, v), BitCast(du, LoadU(d, p))); + StoreU(BitCast(d, blended), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_API void Stream(const Vec128 v, Simd d, + T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend to full vector. +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_u8(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec32 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vmovl_u16(vget_low_u16(a))); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_u16(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_u32(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 d, const Vec64 v) { + return BitCast(d, Vec128(vmovl_u8(v.raw))); +} +HWY_API Vec128 PromoteTo(Full128 d, const Vec32 v) { + uint16x8_t a = vmovl_u8(v.raw); + return BitCast(d, Vec128(vmovl_u16(vget_low_u16(a)))); +} +HWY_API Vec128 PromoteTo(Full128 d, const Vec64 v) { + return BitCast(d, Vec128(vmovl_u16(v.raw))); +} + +// Unsigned: zero-extend to half vector. 
+template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u16(vmovl_u8(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vget_low_u32(vmovl_u16(vget_low_u16(a)))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u32(vmovl_u16(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u64(vmovl_u32(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd d, + const Vec128 v) { + return BitCast(d, Vec128(vget_low_u16(vmovl_u8(v.raw)))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint16x8_t a = vmovl_u8(v.raw); + uint32x4_t b = vmovl_u16(vget_low_u16(a)); + return Vec128(vget_low_s32(vreinterpretq_s32_u32(b))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint32x4_t a = vmovl_u16(v.raw); + return Vec128(vget_low_s32(vreinterpretq_s32_u32(a))); +} + +// Signed: replicate sign bit to full vector. +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_s8(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec32 v) { + int16x8_t a = vmovl_s8(v.raw); + return Vec128(vmovl_s16(vget_low_s16(a))); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_s16(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_s32(v.raw)); +} + +// Signed: replicate sign bit to half vector. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s16(vmovl_s8(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + int16x8_t a = vmovl_s8(v.raw); + int32x4_t b = vmovl_s16(vget_low_s16(a)); + return Vec128(vget_low_s32(b)); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s32(vmovl_s16(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s64(vmovl_s32(v.raw))); +} + +#if __ARM_FP & 2 + +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); + return Vec128(f32); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); + return Vec128(vget_low_f32(f32)); +} + +#else + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. 
+ const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +#endif + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vcvt_f64_f32(v.raw)); +} + +HWY_API Vec64 PromoteTo(Full64 /* tag */, + const Vec32 v) { + return Vec64(vget_low_f64(vcvt_f64_f32(v.raw))); +} + +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + const int64x2_t i64 = vmovl_s32(v.raw); + return Vec128(vcvtq_f64_s64(i64)); +} + +HWY_API Vec64 PromoteTo(Full64 /* tag */, + const Vec32 v) { + const int64x1_t i64 = vget_low_s64(vmovl_s32(v.raw)); + return Vec64(vcvt_f64_s64(i64)); +} + +#endif + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +// From full vector to half or quarter +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovun_s32(v.raw)); +} +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovn_s32(v.raw)); +} +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec128 v) { + const uint16x4_t a = vqmovun_s32(v.raw); + return Vec32(vqmovn_u16(vcombine_u16(a, a))); +} +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovun_s16(v.raw)); +} +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec128 v) { + const int16x4_t a = vqmovn_s32(v.raw); + return Vec32(vqmovn_s16(vcombine_s16(a, a))); +} +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovn_s16(v.raw)); +} + +// From half vector to partial half +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovun_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovn_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const uint16x4_t a = vqmovun_s32(vcombine_s32(v.raw, v.raw)); + return Vec128(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovun_s16(vcombine_s16(v.raw, v.raw))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const int16x4_t a = vqmovn_s32(vcombine_s32(v.raw, v.raw)); + return Vec128(vqmovn_s16(vcombine_s16(a, a))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovn_s16(vcombine_s16(v.raw, v.raw))); +} + +#if __ARM_FP & 2 + +HWY_API Vec128 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec128{vreinterpret_u16_f16(vcvt_f16_f32(v.raw))}; +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const float16x4_t f16 = vcvt_f16_f32(vcombine_f32(v.raw, v.raw)); + return Vec128(vreinterpret_u16_f16(f16)); +} + +#else + +template +HWY_API Vec128 DemoteTo(Simd df16, + const Vec128 v) { + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = 
ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128(DemoteTo(du16, bits16).raw); +} + +#endif + +template +HWY_API Vec128 DemoteTo(Simd dbf16, + const Vec128 v) { + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { + return Vec64(vcvt_f32_f64(v.raw)); +} +HWY_API Vec32 DemoteTo(Full32 /* tag */, const Vec64 v) { + return Vec32(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); +} + +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + const int64x2_t i64 = vcvtq_s64_f64(v.raw); + return Vec64(vqmovn_s64(i64)); +} +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec64 v) { + const int64x1_t i64 = vcvt_s64_f64(v.raw); + // There is no i64x1 -> i32x1 narrow, so expand to int64x2_t first. + const int64x2_t i64x2 = vcombine_s64(i64, i64); + return Vec32(vqmovn_s64(i64x2)); +} + +#endif + +HWY_API Vec32 U8FromU32(const Vec128 v) { + const uint8x16_t org_v = detail::BitCastToByte(v).raw; + const uint8x16_t w = vuzp1q_u8(org_v, org_v); + return Vec32(vget_low_u8(vuzp1q_u8(w, w))); +} +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const uint8x8_t org_v = detail::BitCastToByte(v).raw; + const uint8x8_t w = vuzp1_u8(org_v, org_v); + return Vec128(vuzp1_u8(w, w)); +} + +// In the following DemoteTo functions, |b| is purposely undefined. +// The value a needs to be extended to 128 bits so that vqmovn can be +// used and |b| is undefined so that no extra overhead is introduced. 
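A sketch of the widen-then-narrow pattern the comment above describes, written with the upper half duplicated from the input rather than left undefined, so it compiles without the warning suppression that follows (requires an ARM target and arm_neon.h; only the lanes coming from the low half matter either way):

#include <arm_neon.h>

// Widen the 64-bit input to 128 bits so the narrowing instruction can be
// used; the upper half is a don't-care, duplicated here instead of being
// left undefined as in the functions below.
int8x8_t NarrowI16x4ToI8(int16x4_t a) {
  const int16x8_t widened = vcombine_s16(a, a);
  return vqmovn_s16(widened);  // saturating narrow to i8
}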
+HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 a = DemoteTo(Simd(), v); + Vec128 b; + uint16x8_t c = vcombine_u16(a.raw, b.raw); + return Vec128(vqmovn_u16(c)); +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 a = DemoteTo(Simd(), v); + Vec128 b; + int16x8_t c = vcombine_s16(a.raw, b.raw); + return Vec128(vqmovn_s16(c)); +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Convert integer <=> floating-point + +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f32_s32(v.raw)); +} +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_f32_s32(v.raw)); +} + +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f32_u32(v.raw)); +} +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_f32_u32(v.raw)); +} + +// Truncates (rounds toward zero). +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_s32_f32(v.raw)); +} +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_s32_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f64_s64(v.raw)); +} +HWY_API Vec64 ConvertTo(Full64 /* tag */, + const Vec64 v) { + return Vec64(vcvt_f64_s64(v.raw)); +} + +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f64_u64(v.raw)); +} +HWY_API Vec64 ConvertTo(Full64 /* tag */, + const Vec64 v) { + return Vec64(vcvt_f64_u64(v.raw)); +} + +// Truncates (rounds toward zero). +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_s64_f64(v.raw)); +} +HWY_API Vec64 ConvertTo(Full64 /* tag */, + const Vec64 v) { + return Vec64(vcvt_s64_f64(v.raw)); +} + +#endif + +// ------------------------------ Round (IfThenElse, mask, logical) + +#if HWY_ARCH_ARM_A64 +// Toward nearest integer +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Round, vrndn, _, 1) + +// Toward zero, aka truncate +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Trunc, vrnd, _, 1) + +// Toward +infinity, aka ceiling +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Ceil, vrndp, _, 1) + +// Toward -infinity, aka floor +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Floor, vrndm, _, 1) +#else + +// ------------------------------ Trunc + +// ARMv7 only supports truncation to integer. We can either convert back to +// float (3 floating-point and 2 logic operations) or manipulate the binary32 +// representation, clearing the lowest 23-exp mantissa bits. This requires 9 +// integer operations and 3 constants, which is likely more expensive. + +namespace detail { + +// The original value is already the desired result if NaN or the magnitude is +// large (i.e. the value is already an integer). 
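A scalar model of the ARMv7 fallback defined next (illustrative only): magnitudes of at least 2^23, i.e. MantissaEnd for float, are already integral and NaN must pass through unchanged, so only smaller finite values take the conversion path; Round additionally relies on the add-then-subtract-a-large-constant trick under the default round-to-nearest-even mode, which a compiler may defeat under fast-math, as the comments below note.

#include <cmath>
#include <cstdint>

float TruncEmulated(float v) {
  const float kMantissaEnd = 8388608.0f;       // 2^23
  if (!(std::fabs(v) < kMantissaEnd)) return v;  // NaN or already integral
  return static_cast<float>(static_cast<int32_t>(v));  // rounds toward zero
}

float RoundEmulated(float v) {
  const float kMantissaEnd = 8388608.0f;
  if (!(std::fabs(v) < kMantissaEnd)) return v;
  // Adding then subtracting 2^23 leaves no fractional mantissa bits, so the
  // hardware's nearest-even rounding does the rounding for us.
  const float large = std::copysign(kMantissaEnd, v);
  return (v + large) - large;
}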
+template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +template +HWY_API Vec128 Trunc(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), int_f, v); +} + +template +HWY_API Vec128 Round(const Vec128 v) { + const DFromV df; + + // ARMv7 also lacks a native NearestInt, but we can instead rely on rounding + // (we assume the current mode is nearest-even) after addition with a large + // value such that no mantissa bits remain. We may need a compiler flag for + // precise floating-point to prevent this from being "optimized" out. + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +template +HWY_API Vec128 Ceil(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +template +HWY_API Vec128 Floor(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#endif + +// ------------------------------ NearestInt (Round) + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s32_f32(v.raw)); +} +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtn_s32_f32(v.raw)); +} + +#else + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + const RebindToSigned> di; + return ConvertTo(di, Round(v)); +} + +#endif + +// ------------------------------ Floating-point classification +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return v != v; +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +// <= 64 bit: just return different type +template +HWY_API Vec128 LowerHalf(const Vec128 v) { + return Vec128(v.raw); +} + +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u8(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u16(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u32(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u64(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s8(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s16(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s32(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s64(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_f32(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_f64(v.raw)); +} +#endif +HWY_API Vec64 LowerHalf(const Vec128 v) { + const Full128 du; + const Full64 dbh; + return BitCast(dbh, LowerHalf(BitCast(du, v))); +} + +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return LowerHalf(v); +} + +// ------------------------------ CombineShiftRightBytes + +// 128-bit +template > +HWY_API V128 CombineShiftRightBytes(Full128 d, V128 hi, V128 lo) { + static_assert(0 < kBytes && kBytes < 16, "kBytes must be in [1, 15]"); + const Repartition d8; + uint8x16_t v8 = vextq_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, Vec128(v8)); +} + +// 64-bit +template +HWY_API Vec64 CombineShiftRightBytes(Full64 d, Vec64 hi, Vec64 lo) { + static_assert(0 < kBytes && kBytes < 8, "kBytes must be in [1, 7]"); + const Repartition d8; + uint8x8_t v8 = vext_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, VFromD(v8)); +} + +// <= 32-bit defined after ShiftLeftBytes. + +// ------------------------------ Shift vector by constant #bytes + +namespace detail { + +// Partially specialize because kBytes = 0 and >= size are compile errors; +// callers replace the latter with 0xFF for easier specialization. +template +struct ShiftLeftBytesT { + // Full + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + const Full128 d; + return CombineShiftRightBytes<16 - kBytes>(d, v, Zero(d)); + } + + // Partial + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + // Expand to 64-bit so we only use the native EXT instruction. + const Full64 d64; + const auto zero64 = Zero(d64); + const decltype(zero64) v64(v.raw); + return Vec128( + CombineShiftRightBytes<8 - kBytes>(d64, v64, zero64).raw); + } +}; +template <> +struct ShiftLeftBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; +template <> +struct ShiftLeftBytesT<0xFF> { + template + HWY_INLINE Vec128 operator()(const Vec128 /* v */) { + return Zero(Simd()); + } +}; + +template +struct ShiftRightBytesT { + template + HWY_INLINE Vec128 operator()(Vec128 v) { + const Simd d; + // For < 64-bit vectors, zero undefined lanes so we shift in zeros. + if (N * sizeof(T) < 8) { + constexpr size_t kReg = N * sizeof(T) == 16 ? 
16 : 8; + const Simd dreg; + v = Vec128( + IfThenElseZero(FirstN(dreg, N), VFromD(v.raw)).raw); + } + return CombineShiftRightBytes(d, Zero(d), v); + } +}; +template <> +struct ShiftRightBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; +template <> +struct ShiftRightBytesT<0xFF> { + template + HWY_INLINE Vec128 operator()(const Vec128 /* v */) { + return Zero(Simd()); + } +}; + +} // namespace detail + +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + return detail::ShiftLeftBytesT < kBytes >= N * sizeof(T) ? 0xFF + : kBytes > ()(v); +} + +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + return ShiftLeftBytes(Simd(), v); +} + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(Simd(), v); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + return detail::ShiftRightBytesT < kBytes >= N * sizeof(T) ? 0xFF + : kBytes > ()(v); +} + +template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// Calls ShiftLeftBytes +template +HWY_API Vec128 CombineShiftRightBytes(Simd d, Vec128 hi, + Vec128 lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full64 d_full8; + const Repartition d_full; + using V64 = VFromD; + const V64 hi64(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V64 lo64 = ShiftLeftBytes<8 - kSize>(V64(BitCast(d8, lo).raw)); + const V64 r = CombineShiftRightBytes<8 - kSize + kBytes>(d_full8, hi64, lo64); + // After casting to full 64-bit vector of correct type, shrink to 32-bit + return Vec128(BitCast(d_full, r).raw); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u8(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u16(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u32(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u64(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s8(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s16(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s32(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s64(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64(vget_high_f32(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_f64(v.raw)); +} +#endif + +HWY_API Vec64 UpperHalf(Full64 dbh, + const Vec128 v) { + const RebindToUnsigned duh; + const Twice du; + return BitCast(dbh, UpperHalf(duh, BitCast(du, v))); +} + +// Partial +template +HWY_API Vec128 UpperHalf(Half> /* tag */, + Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); + return Vec128(upper.raw); +} + +// 
------------------------------ Broadcast/splat any lane + +#if HWY_ARCH_ARM_A64 +// Unsigned +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_u64(v.raw, kLane)); +} +// Vec64 is defined below. + +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_s64(v.raw, kLane)); +} +// Vec64 is defined below. + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_f32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_f64(v.raw, kLane)); +} +template +HWY_API Vec64 Broadcast(const Vec64 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} + +#else +// No vdupq_laneq_* on armv7: use vgetq_lane_* + vdupq_n_*. + +// Unsigned +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_u16(vgetq_lane_u16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_u32(vgetq_lane_u32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_u64(vgetq_lane_u64(v.raw, kLane))); +} +// Vec64 is defined below. 
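The pattern used throughout this ARMv7 branch, sketched for one concrete lane (illustrative only; requires an ARM target and arm_neon.h): since vdupq_laneq_* is unavailable, the lane is extracted to a scalar with vgetq_lane_* and then splatted with vdupq_n_*.

#include <arm_neon.h>

// Extract lane 2 to a scalar, then splat it to all four lanes; this is the
// vgetq_lane_* + vdupq_n_* pair used in place of the missing vdupq_laneq_*.
uint32x4_t BroadcastLane2(uint32x4_t v) {
  return vdupq_n_u32(vgetq_lane_u32(v, 2));
}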
+ +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_s16(vgetq_lane_s16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_s32(vgetq_lane_s32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_s64(vgetq_lane_s64(v.raw, kLane))); +} +// Vec64 is defined below. + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_f32(vgetq_lane_f32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} + +#endif + +template +HWY_API Vec64 Broadcast(const Vec64 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} +template +HWY_API Vec64 Broadcast(const Vec64 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + typename detail::Raw128::type raw; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(N))))); +#endif + + const Repartition d8; + using V8 = VFromD; + const Repartition d16; + + // Broadcast each lane index to all bytes of T and shift to bytes + static_assert(sizeof(T) == 4 || sizeof(T) == 8, ""); + if (sizeof(T) == 4) { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + const V8 sum = Add(byte_indices, Load(d8, kByteOffsets)); + return Indices128{BitCast(d, sum).raw}; + } else { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<3>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7}; + const V8 sum = Add(byte_indices, Load(d8, kByteOffsets)); + return Indices128{BitCast(d, sum).raw}; + } +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + const RebindToSigned di; + return BitCast( + d, TableLookupBytes(BitCast(di, v), BitCast(di, Vec128{idx.raw}))); 
+} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +// Single lane: no change +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return v; +} + +// Two lanes: shuffle +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return Vec128(Shuffle2301(v)); +} + +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32_u16(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse2(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32q_u16(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64_u32(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse2(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u32(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64_u16(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse4(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u16(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { + return Reverse(d, v); +} + +template +HWY_API Vec128 Reverse8(Simd, const Vec128) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ Other shuffles (TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return CombineShiftRightBytes<8>(Full128(), v, v); +} +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + return CombineShiftRightBytes<8>(Full128(), v, v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return CombineShiftRightBytes<4>(Full128(), v, v); +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return CombineShiftRightBytes<12>(Full128(), v, v); +} + +// Reverse +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". 
To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveLower, vzip1, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveLower, vzip1, _, 2) + +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_u64(a.raw, b.raw)); +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_s64(a.raw, b.raw)); +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_f64(a.raw, b.raw)); +} +#else +// ARMv7 emulation. +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return CombineShiftRightBytes<8>(Full128(), b, Shuffle01(a)); +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return CombineShiftRightBytes<8>(Full128(), b, Shuffle01(a)); +} +#endif + +// Floats +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_f32(a.raw, b.raw)); +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1_f32(a.raw, b.raw)); +} + +// < 64 bit parts +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128(InterleaveLower(Vec64(a.raw), Vec64(b.raw)).raw); +} + +// Additional overload for the optional Simd<> tag. +template > +HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveUpper, vzip2, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveUpper, vzip2, _, 2) + +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vzip2q_u64(a.raw, b.raw)); +} +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + return Vec128(vzip2q_s64(a.raw, b.raw)); +} +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + return Vec128(vzip2q_f64(a.raw, b.raw)); +} +#else +// ARMv7 emulation. +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return CombineShiftRightBytes<8>(Full128(), Shuffle01(b), a); +} +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + return CombineShiftRightBytes<8>(Full128(), Shuffle01(b), a); +} +#endif + +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + return Vec128(vzip2q_f32(a.raw, b.raw)); +} +HWY_API Vec64 InterleaveUpper(const Vec64 a, + const Vec64 b) { + return Vec64(vzip2_f32(a.raw, b.raw)); +} + +} // namespace detail + +// Full register +template > +HWY_API V InterleaveUpper(Simd /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template > +HWY_API V InterleaveUpper(Simd d, V a, V b) { + const Half d2; + return InterleaveLower(d, V(UpperHalf(d2, a).raw), V(UpperHalf(d2, b).raw)); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
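A scalar model of what the double-width result means for a single pair of u32 lanes (illustrative only): the lane taken from the first operand occupies the lower 32 bits of the combined u64 lane, because lane order is little-endian within the wider lane.

#include <cstdint>

// One pair of u32 lanes zipped into a u64 lane: `a` supplies the
// least-significant half, `b` the most-significant half.
uint64_t ZipLowerScalar(uint32_t a, uint32_t b) {
  return (static_cast<uint64_t>(b) << 32) | a;
}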
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { + const Rebind du32; + using VU32 = VFromD; + const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 + // Avoid ZipLower/Upper so this also works on big-endian systems. + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +HWY_API Vec128 ReorderWidenMulAccumulate(Full128 /*d32*/, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmlal_high_s16(sum1.raw, a.raw, b.raw)); +#else + const Full64 dh; + sum1 = Vec128( + vmlal_s16(sum1.raw, UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + return Vec128( + vmlal_s16(sum0.raw, LowerHalf(a).raw, LowerHalf(b).raw)); +} + +HWY_API Vec64 ReorderWidenMulAccumulate(Full64 d32, + Vec64 a, + Vec64 b, + const Vec64 sum0, + Vec64& sum1) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. + const Vec128 mul_3210(vmull_s16(a.raw, b.raw)); + const Vec64 mul_32 = UpperHalf(d32, mul_3210); + sum1 += mul_32; + return sum0 + LowerHalf(mul_3210); +} + +HWY_API Vec32 ReorderWidenMulAccumulate(Full32 d32, + Vec32 a, + Vec32 b, + const Vec32 sum0, + Vec32& sum1) { + const Vec128 mul_xx10(vmull_s16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(d32, mul_10); + const Vec32 mul1 = UpperHalf(d32, mul_10); + sum1 += mul1; + return sum0 + mul0; +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// Full result +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u8(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, + Vec64 hi, Vec64 lo) { + return Vec128(vcombine_u16(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, + Vec64 hi, Vec64 lo) { + return Vec128(vcombine_u32(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, + Vec64 hi, Vec64 lo) { + return Vec128(vcombine_u64(lo.raw, hi.raw)); +} + +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s8(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s16(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s32(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s64(lo.raw, hi.raw)); +} + +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_f32(lo.raw, hi.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_f64(lo.raw, hi.raw)); +} 
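+
+// Lane-order sketch (illustrative; Iota is assumed from earlier in this
+// header): Combine places `lo` in the lower half of the result and `hi` in
+// the upper half, e.g. for u32:
+//   const Full128<uint32_t> d;
+//   const Full64<uint32_t> dh;
+//   Combine(d, /*hi=*/Iota(dh, 2), /*lo=*/Iota(dh, 0));  // lanes 0,1,2,3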
+#endif + +// < 64bit input, <= 64 bit result +template +HWY_API Vec128 Combine(Simd d, Vec128 hi, + Vec128 lo) { + // First double N (only lower halves will be used). + const Vec128 hi2(hi.raw); + const Vec128 lo2(lo.raw); + // Repartition to two unsigned lanes (each the size of the valid input). + const Simd, 2, 0> du; + return BitCast(d, InterleaveLower(BitCast(du, lo2), BitCast(du, hi2))); +} + +// ------------------------------ RearrangeToOddPlusEven (Combine) + +template +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + const Vec128 sum1) { + return Add(sum0, sum1); +} + +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + const Vec128 sum1) { +// vmlal_s16 multiplied the lower half into sum0 and upper into sum1. +#if HWY_ARCH_ARM_A64 // pairwise sum is available and what we want + return Vec128(vpaddq_s32(sum0.raw, sum1.raw)); +#else + const Full128 d; + const Half d64; + const Vec64 hi( + vpadd_s32(LowerHalf(d64, sum1).raw, UpperHalf(d64, sum1).raw)); + const Vec64 lo( + vpadd_s32(LowerHalf(d64, sum0).raw, UpperHalf(d64, sum0).raw)); + return Combine(Full128(), hi, lo); +#endif +} + +HWY_API Vec64 RearrangeToOddPlusEven(const Vec64 sum0, + const Vec64 sum1) { + // vmlal_s16 multiplied the lower half into sum0 and upper into sum1. + return Vec64(vpadd_s32(sum0.raw, sum1.raw)); +} + +HWY_API Vec32 RearrangeToOddPlusEven(const Vec32 sum0, + const Vec32 sum1) { + // Only one widened sum per register, so add them for sum of odd and even. + return sum0 + sum1; +} + +// ------------------------------ ZeroExtendVector (Combine) + +template +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ ConcatLowerLower + +// 64 or 128-bit input: just interleave +template +HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, + Vec128 lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition, decltype(d)> du; + return BitCast(d, InterleaveLower(BitCast(du, lo), BitCast(du, hi))); +} + +namespace detail { +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_UIF81632(InterleaveEven, vtrn1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF81632(InterleaveOdd, vtrn2, _, 2) +#else + +// vtrn returns a struct with even and odd result. +#define HWY_NEON_BUILD_TPL_HWY_TRN +#define HWY_NEON_BUILD_RET_HWY_TRN(type, size) type##x##size##x2_t +// Pass raw args so we can accept uint16x2 args, for which there is no +// corresponding uint16x2x2 return type. +#define HWY_NEON_BUILD_PARAM_HWY_TRN(TYPE, size) \ + Raw128::type a, Raw128::type b +#define HWY_NEON_BUILD_ARG_HWY_TRN a, b + +// Cannot use UINT8 etc. type macros because the x2_t tuples are only defined +// for full and half vectors. 
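+//
+// Semantics sketch (illustrative): vtrn interleaves the even lanes of its
+// inputs into .val[0] and the odd lanes into .val[1]. For u32 inputs
+// a = {a0,a1,a2,a3} and b = {b0,b1,b2,b3}:
+//   InterleaveEvenOdd(a.raw, b.raw).val[0]  // a0,b0,a2,b2
+//   InterleaveEvenOdd(a.raw, b.raw).val[1]  // a1,b1,a3,b3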
+HWY_NEON_DEF_FUNCTION(uint8, 16, InterleaveEvenOdd, vtrnq, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint8, 8, InterleaveEvenOdd, vtrn, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 8, InterleaveEvenOdd, vtrnq, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 4, InterleaveEvenOdd, vtrn, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 4, InterleaveEvenOdd, vtrnq, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 2, InterleaveEvenOdd, vtrn, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 16, InterleaveEvenOdd, vtrnq, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 8, InterleaveEvenOdd, vtrn, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 8, InterleaveEvenOdd, vtrnq, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 4, InterleaveEvenOdd, vtrn, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 4, InterleaveEvenOdd, vtrnq, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 2, InterleaveEvenOdd, vtrn, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 4, InterleaveEvenOdd, vtrnq, _, f32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 2, InterleaveEvenOdd, vtrn, _, f32, HWY_TRN) +#endif +} // namespace detail + +// <= 32-bit input/output +template +HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, + Vec128 lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveEven(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[0])); +#endif +} + +// ------------------------------ ConcatUpperUpper + +// 64 or 128-bit input: just interleave +template +HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, + Vec128 lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition, decltype(d)> du; + return BitCast(d, InterleaveUpper(du, BitCast(du, lo), BitCast(du, hi))); +} + +// <= 32-bit input/output +template +HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, + Vec128 lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveOdd(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[1])); +#endif +} + +// ------------------------------ ConcatLowerUpper (ShiftLeftBytes) + +// 64 or 128-bit input: extract from concatenated +template +HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, + Vec128 lo) { + return CombineShiftRightBytes(d, hi, lo); +} + +// <= 32-bit input/output +template +HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, + Vec128 lo) { + constexpr size_t kSize = N * sizeof(T); + const Repartition d8; + const Full64 d8x8; + const Full64 d64; + using V8x8 = VFromD; + const V8x8 hi8x8(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V8x8 lo8x8 = ShiftLeftBytes<8 - kSize>(V8x8(BitCast(d8, lo).raw)); + const V8x8 r = CombineShiftRightBytes<8 - kSize / 2>(d8x8, hi8x8, lo8x8); + // Back to original lane type, then shrink N. + return Vec128(BitCast(d64, r).raw); +} + +// ------------------------------ ConcatUpperLower + +// Works for all N. +template +HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, + Vec128 lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd (InterleaveUpper) + +namespace detail { +// There is no vuzpq_u64. 
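+// De-interleave sketch (illustrative): for u32 inputs lo = {l0,l1,l2,l3} and
+// hi = {h0,h1,h2,h3}, the wrappers below behave as
+//   ConcatEven(lo, hi)  // l0,l2,h0,h2  (even lanes, lo first)
+//   ConcatOdd(lo, hi)   // l1,l3,h1,h3  (odd lanes, lo first)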
+HWY_NEON_DEF_FUNCTION_UIF81632(ConcatEven, vuzp1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF81632(ConcatOdd, vuzp2, _, 2) +} // namespace detail + +// Full/half vector +template = 8>* = nullptr> +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + return detail::ConcatOdd(lo, hi); +} + +// 8-bit x4 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + const Twice d2; + const Repartition dw2; + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + const VFromD Hx1Lx1 = BitCast(dw2, ConcatOdd(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec128(BitCast(d2, ConcatEven(dw2, Hx1Lx1, Hx1Lx1)).raw); +} + +// Any type x2 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// Full/half vector +template = 8>* = nullptr> +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + return detail::ConcatEven(lo, hi); +} + +// 8-bit x4 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + const Twice d2; + const Repartition dw2; + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + const VFromD Hx0Lx0 = BitCast(dw2, ConcatEven(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec128(BitCast(d2, ConcatEven(dw2, Hx0Lx0, Hx0Lx0)).raw); +} + +// Any type x2 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[0]); +#endif +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(Simd(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[1]); +#endif +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(Simd(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + const Simd d; + const Repartition d8; + alignas(16) constexpr uint8_t kBytes[16] = { + ((0 / sizeof(T)) & 1) ? 0 : 0xFF, ((1 / sizeof(T)) & 1) ? 0 : 0xFF, + ((2 / sizeof(T)) & 1) ? 0 : 0xFF, ((3 / sizeof(T)) & 1) ? 0 : 0xFF, + ((4 / sizeof(T)) & 1) ? 0 : 0xFF, ((5 / sizeof(T)) & 1) ? 0 : 0xFF, + ((6 / sizeof(T)) & 1) ? 0 : 0xFF, ((7 / sizeof(T)) & 1) ? 0 : 0xFF, + ((8 / sizeof(T)) & 1) ? 0 : 0xFF, ((9 / sizeof(T)) & 1) ? 0 : 0xFF, + ((10 / sizeof(T)) & 1) ? 0 : 0xFF, ((11 / sizeof(T)) & 1) ? 0 : 0xFF, + ((12 / sizeof(T)) & 1) ? 0 : 0xFF, ((13 / sizeof(T)) & 1) ? 0 : 0xFF, + ((14 / sizeof(T)) & 1) ? 0 : 0xFF, ((15 / sizeof(T)) & 1) ? 
0 : 0xFF, + }; + const auto vec = BitCast(d, Load(d8, kBytes)); + return IfThenElse(MaskFromVec(vec), b, a); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + +// ------------------------------ ReorderDemote2To (OddEven) + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + const RebindToUnsigned du16; + const Repartition du32; + const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec128 ReorderDemote2To(Full128 d16, + Vec128 a, Vec128 b) { + const Vec64 a16(vqmovn_s32(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d16; + return Vec128(vqmovn_high_s32(a16.raw, b.raw)); +#else + const Vec64 b16(vqmovn_s32(b.raw)); + return Combine(d16, a16, b16); +#endif +} + +HWY_API Vec64 ReorderDemote2To(Full64 /*d16*/, + Vec64 a, Vec64 b) { + const Full128 d32; + const Vec128 ab = Combine(d32, a, b); + return Vec64(vqmovn_s32(ab.raw)); +} + +HWY_API Vec32 ReorderDemote2To(Full32 /*d16*/, + Vec32 a, Vec32 b) { + const Full128 d32; + const Vec64 ab(vzip1_s32(a.raw, b.raw)); + return Vec32(vqmovn_s32(Combine(d32, ab, ab).raw)); +} + +// ================================================== CRYPTO + +#if defined(__ARM_FEATURE_AES) || \ + (HWY_HAVE_RUNTIME_DISPATCH && HWY_ARCH_ARM_A64) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + // NOTE: it is important that AESE and AESMC be consecutive instructions so + // they can be fused. AESE includes AddRoundKey, which is a different ordering + // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual + // round key (the compiler will hopefully optimize this for multiple rounds). 
+ return Vec128(vaesmcq_u8(vaeseq_u8(state.raw, vdupq_n_u8(0)))) ^ + round_key; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128(vaeseq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; +} + +HWY_API Vec128 CLMulLower(Vec128 a, Vec128 b) { + return Vec128((uint64x2_t)vmull_p64(GetLane(a), GetLane(b))); +} + +HWY_API Vec128 CLMulUpper(Vec128 a, Vec128 b) { + return Vec128( + (uint64x2_t)vmull_high_p64((poly64x2_t)a.raw, (poly64x2_t)b.raw)); +} + +#endif // __ARM_FEATURE_AES + +// ================================================== MISC + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Truncations + +template * = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return Vec128{v1.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + const auto v4 = detail::ConcatEven(v3, v3); + return LowerHalf(LowerHalf(LowerHalf(v4))); +} + +HWY_API Vec32 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +HWY_API Vec64 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +// ------------------------------ MulEven (ConcatEven) + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
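+//
+// Sketch (illustrative; Iota is assumed from earlier in this header), for u32:
+//   const Full128<uint32_t> d;
+//   MulEven(Iota(d, 1), Iota(d, 1));  // u64 lanes: 1*1 = 1 and 3*3 = 9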
+HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const Full128 d; + int32x4_t a_packed = ConcatEven(d, a, a).raw; + int32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const Full128 d; + uint32x4_t a_packed = ConcatEven(d, a, a).raw; + uint32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); +} + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + int32x2_t a_packed = ConcatEven(d, a, a).raw; + int32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_s64(vmull_s32(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + uint32x2_t a_packed = ConcatEven(d, a, a).raw; + uint32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_u64(vmull_u32(a_packed, b_packed))); +} + +HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { + uint64_t hi; + uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 0), vgetq_lane_u64(b.raw, 0), &hi); + return Vec128(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); +} + +HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { + uint64_t hi; + uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 1), vgetq_lane_u64(b.raw, 1), &hi); + return Vec128(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); +} + +// ------------------------------ TableLookupBytes (Combine, LowerHalf) + +// Both full +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const Full128 d; + const Repartition d8; +#if HWY_ARCH_ARM_A64 + return BitCast(d, Vec128(vqtbl1q_u8(BitCast(d8, bytes).raw, + BitCast(d8, from).raw))); +#else + uint8x16_t table0 = BitCast(d8, bytes).raw; + uint8x8x2_t table; + table.val[0] = vget_low_u8(table0); + table.val[1] = vget_high_u8(table0); + uint8x16_t idx = BitCast(d8, from).raw; + uint8x8_t low = vtbl2_u8(table, vget_low_u8(idx)); + uint8x8_t hi = vtbl2_u8(table, vget_high_u8(idx)); + return BitCast(d, Vec128(vcombine_u8(low, hi))); +#endif +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const Full128 d_full; + const Vec64 from64(from.raw); + const auto idx_full = Combine(d_full, from64, from64); + const auto out_full = TableLookupBytes(bytes, idx_full); + return Vec128(LowerHalf(Half(), out_full).raw); +} + +// Partial table vector +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const Full128 d_full; + return TableLookupBytes(Combine(d_full, bytes, bytes), from); +} + +// Partial both +template +HWY_API VFromD>> TableLookupBytes( + Vec128 bytes, Vec128 from) { + const Simd d; + const Simd d_idx; + const Repartition d_idx8; + // uint8x8 + const auto bytes8 = BitCast(Repartition(), bytes); + const auto from8 = BitCast(d_idx8, from); + const VFromD v8(vtbl1_u8(bytes8.raw, from8.raw)); + return BitCast(d_idx, v8); +} + +// For all vector widths; ARM anyway zeroes if >= 0x10. 
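+//
+// Sketch (illustrative; Iota and Load are assumed from earlier in this
+// header). An index with its high bit set selects zero, which is the
+// TableLookupBytesOr0 contract:
+//   const Full128<uint8_t> d8;
+//   const auto bytes = Iota(d8, 0x41);          // 'A','B','C',..
+//   alignas(16) const uint8_t idx[16] = {3, 0, 0x80, 1 /* rest zero */};
+//   TableLookupBytesOr0(bytes, Load(d8, idx));  // 'D','A',0,'B','A',..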
+template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} + +// full vectors +#if HWY_ARCH_ARM_A64 +#define HWY_NEON_BUILD_RET_REDUCTION(type, size) Vec128 +#define HWY_NEON_DEF_REDUCTION(type, size, name, prefix, infix, suffix, dup) \ + HWY_API HWY_NEON_BUILD_RET_REDUCTION(type, size) \ + name(hwy::SizeTag, const Vec128 v) { \ + return HWY_NEON_BUILD_RET_REDUCTION( \ + type, size)(dup##suffix(HWY_NEON_EVAL(prefix##infix##suffix, v.raw))); \ + } + +#define HWY_NEON_DEF_REDUCTION_CORE_TYPES(name, prefix) \ + HWY_NEON_DEF_REDUCTION(uint8, 8, name, prefix, _, u8, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(uint8, 16, name, prefix##q, _, u8, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(uint16, 4, name, prefix, _, u16, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(uint16, 8, name, prefix##q, _, u16, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(uint32, 2, name, prefix, _, u32, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(uint32, 4, name, prefix##q, _, u32, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(int8, 8, name, prefix, _, s8, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(int8, 16, name, prefix##q, _, s8, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(int16, 4, name, prefix, _, s16, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(int16, 8, name, prefix##q, _, s16, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(int32, 2, name, prefix, _, s32, vdup_n_) \ + 
HWY_NEON_DEF_REDUCTION(int32, 4, name, prefix##q, _, s32, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(float32, 2, name, prefix, _, f32, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(float32, 4, name, prefix##q, _, f32, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(float64, 2, name, prefix##q, _, f64, vdupq_n_) + +HWY_NEON_DEF_REDUCTION_CORE_TYPES(MinOfLanes, vminv) +HWY_NEON_DEF_REDUCTION_CORE_TYPES(MaxOfLanes, vmaxv) + +// u64/s64 don't have horizontal min/max for some reason, but do have add. +#define HWY_NEON_DEF_REDUCTION_ALL_TYPES(name, prefix) \ + HWY_NEON_DEF_REDUCTION_CORE_TYPES(name, prefix) \ + HWY_NEON_DEF_REDUCTION(uint64, 2, name, prefix##q, _, u64, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(int64, 2, name, prefix##q, _, s64, vdupq_n_) + +HWY_NEON_DEF_REDUCTION_ALL_TYPES(SumOfLanes, vaddv) + +#undef HWY_NEON_DEF_REDUCTION_ALL_TYPES +#undef HWY_NEON_DEF_REDUCTION_CORE_TYPES +#undef HWY_NEON_DEF_REDUCTION +#undef HWY_NEON_BUILD_RET_REDUCTION + +// Need some fallback implementations for [ui]64x2 and [ui]16x2. +#define HWY_IF_SUM_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 1 << 2) +#define HWY_IF_MINMAX_REDUCTION(T) \ + HWY_IF_LANE_SIZE_ONE_OF(T, (1 << 8) | (1 << 2)) + +#else +// u32/i32/f32: N=2 +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return v10 + Shuffle2301(v10); +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Min(v10, Shuffle2301(v10)); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Max(v10, Shuffle2301(v10)); +} + +// ARMv7 version for everything except doubles. +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + uint32x4x2_t v0 = vuzpq_u32(v.raw, v.raw); + uint32x4_t c0 = vaddq_u32(v0.val[0], v0.val[1]); + uint32x4x2_t v1 = vuzpq_u32(c0, c0); + return Vec128(vaddq_u32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + int32x4x2_t v0 = vuzpq_s32(v.raw, v.raw); + int32x4_t c0 = vaddq_s32(v0.val[0], v0.val[1]); + int32x4x2_t v1 = vuzpq_s32(c0, c0); + return Vec128(vaddq_s32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + float32x4x2_t v0 = vuzpq_f32(v.raw, v.raw); + float32x4_t c0 = vaddq_f32(v0.val[0], v0.val[1]); + float32x4x2_t v1 = vuzpq_f32(c0, c0); + return Vec128(vaddq_f32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v) { + return v + Shuffle01(v); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v) { + return v + Shuffle01(v); +} + +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +#define HWY_NEON_BUILD_TYPE_T(type, size) type##x##size##_t +#define HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) Vec128 +#define HWY_NEON_DEF_PAIRWISE_REDUCTION(type, size, name, prefix, suffix) \ + HWY_API HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) \ + name(hwy::SizeTag, const Vec128 v) { \ + HWY_NEON_BUILD_TYPE_T(type, size) tmp = 
prefix##_##suffix(v.raw, v.raw); \ + if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + return HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION( \ + type, size)(HWY_NEON_EVAL(vdup##_lane_##suffix, tmp, 0)); \ + } +#define HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(type, size, half, name, prefix, \ + suffix) \ + HWY_API HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) \ + name(hwy::SizeTag, const Vec128 v) { \ + HWY_NEON_BUILD_TYPE_T(type, half) tmp; \ + tmp = prefix##_##suffix(vget_high_##suffix(v.raw), \ + vget_low_##suffix(v.raw)); \ + if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 8) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + tmp = vdup_lane_##suffix(tmp, 0); \ + return HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION( \ + type, size)(HWY_NEON_EVAL(vcombine_##suffix, tmp, tmp)); \ + } + +#define HWY_NEON_DEF_PAIRWISE_REDUCTIONS(name, prefix) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(uint16, 4, name, prefix, u16) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(uint8, 8, name, prefix, u8) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(int16, 4, name, prefix, s16) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(int8, 8, name, prefix, s8) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint16, 8, 4, name, prefix, u16) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint8, 16, 8, name, prefix, u8) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int16, 8, 4, name, prefix, s16) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int8, 16, 8, name, prefix, s8) + +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(SumOfLanes, vpadd) +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(MinOfLanes, vpmin) +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(MaxOfLanes, vpmax) + +#undef HWY_NEON_DEF_PAIRWISE_REDUCTIONS +#undef HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION +#undef HWY_NEON_DEF_PAIRWISE_REDUCTION +#undef HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION +#undef HWY_NEON_BUILD_TYPE_T + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. 
+ return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +// Need fallback min/max implementations for [ui]64x2. +#define HWY_IF_SUM_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 0) +#define HWY_IF_MINMAX_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 1 << 8) + +#endif + +// [ui]16/[ui]64: N=2 -- special case for pairs of very small or large lanes +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag /* tag */, + const Vec128 v10) { + return v10 + Reverse2(Simd(), v10); +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag /* tag */, + const Vec128 v10) { + return Min(v10, Reverse2(Simd(), v10)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag /* tag */, + const Vec128 v10) { + return Max(v10, Reverse2(Simd(), v10)); +} + +#undef HWY_IF_SUM_REDUCTION +#undef HWY_IF_MINMAX_REDUCTION + +} // namespace detail + +template +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// Helper function to set 64 bits and potentially return a smaller vector. The +// overload is required to call the q vs non-q intrinsics. Note that 8-bit +// LoadMaskBits only requires 16 bits, but 64 avoids casting. +template +HWY_INLINE Vec128 Set64(Simd /* tag */, uint64_t mask_bits) { + const auto v64 = Vec64(vdup_n_u64(mask_bits)); + return Vec128(BitCast(Full64(), v64).raw); +} +template +HWY_INLINE Vec128 Set64(Full128 d, uint64_t mask_bits) { + return BitCast(d, Vec128(vdupq_n_u64(mask_bits))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const auto vmask_bits = Set64(du, mask_bits); + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
+ alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vmask_bits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask128 LoadMaskBits(Simd d, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(N + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Returns mask[i]? 0xF : 0 in each nibble. This is more efficient than +// BitsFromMask for use in (partial) CountTrue, FindFirstTrue and AllFalse. +template +HWY_INLINE uint64_t NibblesFromMask(const Full128 d, Mask128 mask) { + const Full128 du16; + const Vec128 vu16 = BitCast(du16, VecFromMask(d, mask)); + const Vec64 nib(vshrn_n_u16(vu16.raw, 4)); + return GetLane(BitCast(Full64(), nib)); +} + +template +HWY_INLINE uint64_t NibblesFromMask(const Full64 d, Mask64 mask) { + // There is no vshrn_n_u16 for uint16x4, so zero-extend. + const Twice d2; + const Vec128 v128 = ZeroExtendVector(d2, VecFromMask(d, mask)); + // No need to mask, upper half is zero thanks to ZeroExtendVector. + return NibblesFromMask(d2, MaskFromVec(v128)); +} + +template +HWY_INLINE uint64_t NibblesFromMask(Simd /*d*/, Mask128 mask) { + const Mask64 mask64(mask.raw); + const uint64_t nib = NibblesFromMask(Full64(), mask64); + // Clear nibbles from upper half of 64-bits + constexpr size_t kBytes = sizeof(T) * N; + return nib & ((1ull << (kBytes * 4)) - 1); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint8_t kSliceLanes[16] = { + 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, + }; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(Full128(), mask)) & Load(du, kSliceLanes); + +#if HWY_ARCH_ARM_A64 + // Can't vaddv - we need two separate bytes (16 bits). + const uint8x8_t x2 = vget_low_u8(vpaddq_u8(values.raw, values.raw)); + const uint8x8_t x4 = vpadd_u8(x2, x2); + const uint8x8_t x8 = vpadd_u8(x4, x4); + return vget_lane_u64(vreinterpret_u64_u8(x8), 0); +#else + // Don't have vpaddq, so keep doubling lane size. 
+ const uint16x8_t x2 = vpaddlq_u8(values.raw); + const uint32x4_t x4 = vpaddlq_u16(x2); + const uint64x2_t x8 = vpaddlq_u32(x4); + return (vgetq_lane_u64(x8, 1) << 8) | vgetq_lane_u64(x8, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; + +#if HWY_ARCH_ARM_A64 + return vaddv_u8(values.raw); +#else + const uint16x4_t x2 = vpaddl_u8(values.raw); + const uint32x2_t x4 = vpaddl_u16(x2); + const uint64x1_t x8 = vpaddl_u32(x4); + return vget_lane_u64(x8, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint16_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u16(values.raw); +#else + const uint32x4_t x2 = vpaddlq_u16(values.raw); + const uint64x2_t x4 = vpaddlq_u32(x2); + return vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8}; + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return vaddv_u16(values.raw); +#else + const uint32x2_t x2 = vpaddl_u16(values.raw); + const uint64x1_t x4 = vpaddl_u32(x2); + return vget_lane_u64(x4, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint32_t kSliceLanes[4] = {1, 2, 4, 8}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u32(values.raw); +#else + const uint64x2_t x2 = vpaddlq_u32(values.raw); + return vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. 
+ alignas(8) constexpr uint32_t kSliceLanes[2] = {1, 2}; + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return vaddv_u32(values.raw); +#else + const uint64x1_t x2 = vpaddl_u32(values.raw); + return vget_lane_u64(x2, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + alignas(16) constexpr uint64_t kSliceLanes[2] = {1, 2}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, m)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u64(values.raw); +#else + return vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 m) { + const Full64 d; + const Full64 du; + const Vec64 values = BitCast(du, VecFromMask(d, m)) & Set(du, 1); + return vget_lane_u64(values.raw, 0); +} + +// Returns the lowest N for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) >= 8) ? bits : (bits & ((1ull << N) - 1)); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +// Returns number of lanes whose mask is set. +// +// Masks are either FF..FF or 0. Unfortunately there is no reduce-sub op +// ("vsubv"). ANDing with 1 would work but requires a constant. Negating also +// changes each lane to 1 (if mask set) or 0. +// NOTE: PopCount also operates on vectors, so we still have to do horizontal +// sums separately. We specialize CountTrue for full vectors (negating instead +// of PopCount because it avoids an extra shift), and use PopCount of +// NibblesFromMask for partial vectors. 
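+//
+// Sketch (illustrative): NibblesFromMask yields one 0xF nibble per active mask
+// byte, so a partial CountTrue is PopCount(nibbles) / (4 * sizeof(T)). E.g.
+//   const Full64<uint16_t> d;   // 4 lanes
+//   FirstN(d, 3);               // nibbles = 0x00FFFFFF -> PopCount = 24
+//                               // 24 / (4 * 2) = 3 lanes set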
+ +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { + const Full128 di; + const int8x16_t ones = + vnegq_s8(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s8(ones)); +#else + const int16x8_t x2 = vpaddlq_s8(ones); + const int32x4_t x4 = vpaddlq_s16(x2); + const int64x2_t x8 = vpaddlq_s32(x4); + return static_cast(vgetq_lane_s64(x8, 0) + vgetq_lane_s64(x8, 1)); +#endif +} +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> /*tag*/, const Mask128 mask) { + const Full128 di; + const int16x8_t ones = + vnegq_s16(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s16(ones)); +#else + const int32x4_t x2 = vpaddlq_s16(ones); + const int64x2_t x4 = vpaddlq_s32(x2); + return static_cast(vgetq_lane_s64(x4, 0) + vgetq_lane_s64(x4, 1)); +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 mask) { + const Full128 di; + const int32x4_t ones = + vnegq_s32(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s32(ones)); +#else + const int64x2_t x2 = vpaddlq_s32(ones); + return static_cast(vgetq_lane_s64(x2, 0) + vgetq_lane_s64(x2, 1)); +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128 mask) { +#if HWY_ARCH_ARM_A64 + const Full128 di; + const int64x2_t ones = + vnegq_s64(BitCast(di, VecFromMask(Full128(), mask)).raw); + return static_cast(vaddvq_s64(ones)); +#else + const Full128 du; + const auto mask_u = VecFromMask(du, RebindMask(du, mask)); + const uint64x2_t ones = vshrq_n_u64(mask_u.raw, 63); + return static_cast(vgetq_lane_u64(ones, 0) + vgetq_lane_u64(ones, 1)); +#endif +} + +} // namespace detail + +// Full +template +HWY_API size_t CountTrue(Full128 /* tag */, const Mask128 mask) { + return detail::CountTrue(hwy::SizeTag(), mask); +} + +// Partial +template +HWY_API size_t CountTrue(Simd d, const Mask128 mask) { + constexpr int kDiv = 4 * sizeof(T); + return PopCount(detail::NibblesFromMask(d, mask)) / kDiv; +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd d, + const Mask128 mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + constexpr size_t kDiv = 4 * sizeof(T); + return Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv; +} + +template +HWY_API intptr_t FindFirstTrue(const Simd d, + const Mask128 mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + if (nib == 0) return -1; + constexpr int kDiv = 4 * sizeof(T); + return static_cast(Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv); +} + +// `p` points to at least 8 writable bytes. 
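+//
+// Round-trip sketch (illustrative; FirstN is assumed from earlier in this
+// header): bit i of the output corresponds to lane i.
+//   const Full128<uint8_t> d;               // 16 lanes -> writes 2 bytes
+//   uint8_t bits[8] = {0};
+//   StoreMaskBits(d, FirstN(d, 3), bits);   // bits[0] == 0x07, bits[1] == 0
+//   LoadMaskBits(d, bits);                  // lanes 0..2 active again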
+template +HWY_API size_t StoreMaskBits(Simd /* tag */, const Mask128 mask, + uint8_t* bits) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API bool AllFalse(const Simd d, const Mask128 m) { + return detail::NibblesFromMask(d, m) == 0; +} + +// Full +template +HWY_API bool AllTrue(const Full128 d, const Mask128 m) { + return detail::NibblesFromMask(d, m) == ~0ull; +} +// Partial +template +HWY_API bool AllTrue(const Simd d, const Mask128 m) { + constexpr size_t kBytes = sizeof(T) * N; + return detail::NibblesFromMask(d, m) == (1ull << (kBytes * 4)) - 1; +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +namespace detail { + +// Load 8 bytes, replicate into upper half so ZipLower can use the lower half. +HWY_INLINE Vec128 Load8Bytes(Full128 /*d*/, + const uint8_t* bytes) { + return Vec128(vreinterpretq_u8_u64( + vld1q_dup_u64(reinterpret_cast(bytes)))); +} + +// Load 8 bytes and return half-reg with N <= 8 bytes. +template +HWY_INLINE Vec128 Load8Bytes(Simd d, + const uint8_t* bytes) { + return Load(d, bytes); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<2> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // ARM does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 
2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 
4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<2> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // ARM does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 
4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 
8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<4> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<4> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
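+ // (16 possible mask_bits values x 16 index bytes = a 256-byte table.)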
+ alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<8> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<8> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#endif + +// Helper function called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. +template +HWY_INLINE Vec128 Compress(Vec128 v, const uint64_t mask_bits) { + const auto idx = + detail::IdxFromBits(hwy::SizeTag(), mask_bits); + using D = Simd; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec128 CompressNot(Vec128 v, const uint64_t mask_bits) { + const auto idx = + detail::IdxFromNotBits(hwy::SizeTag(), mask_bits); + using D = Simd; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
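+ // e.g. mask = {0, 1}: maskL (both lanes = mask[0]) is zero and maskH (both
+ // lanes = mask[1]) is all-ones, so swap = AndNot(maskL, maskH) is all-ones
+ // and the two lanes are exchanged; for any other mask, swap is zero and v is
+ // returned unchanged.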
+ const Simd d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { + return detail::Compress(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::Compress(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNot(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits + +template +HWY_INLINE Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, + Simd d, T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + using TU = TFromD; + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Mask128 store_mask = RebindMask(d, FirstN(du, count)); + const Vec128 compressed = detail::Compress(BitCast(du, v), mask_bits); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. 
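+// Toggling the flag (rather than only defining it) allows this header to be
+// re-included once per target; generic_ops-inl.h checks it to decide whether
+// to emit its fallback implementation.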
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#define HWY_NEON_BUILD_ARG_HWY_LOAD_INT from + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_LOAD_INT(T, N) HWY_IF_GE64(T, N) +#define HWY_NEON_DEF_FUNCTION_LOAD_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64 +#define HWY_IF_LOAD_INT(T, N) \ + hwy::EnableIf= 8 && (N == 1 || sizeof(T) < 8)>* = nullptr +#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +// Must return raw tuple because Tuple2 lack a ctor, and we cannot use +// brace-initialization in HWY_NEON_DEF_FUNCTION because some functions return +// void. +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple2().raw) +// Tuple tag arg allows overloading (cannot just overload on return type) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple2 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved2, vld2, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple3().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple3 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved3, vld3, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple4().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple4 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved4, vld4, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#undef HWY_NEON_DEF_FUNCTION_LOAD_INT +#undef HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#undef HWY_NEON_BUILD_ARG_HWY_LOAD_INT +} // namespace detail + +template +HWY_API void LoadInterleaved2(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1) { + auto raw = detail::LoadInterleaved2(unaligned, detail::Tuple2()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); +} + +// <= 32 bits: avoid loading more than N bytes by copying to buffer +template +HWY_API void LoadInterleaved2(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1) { + // The smallest vector registers are 64-bits and we want space for two. 
+ alignas(16) T buf[2 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved2(buf, detail::Tuple2()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void LoadInterleaved2(Full128 d, T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1) { + const Half dh; + VFromD v00, v10, v01, v11; + LoadInterleaved2(dh, unaligned, v00, v10); + LoadInterleaved2(dh, unaligned + 2, v01, v11); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved3 + +template +HWY_API void LoadInterleaved3(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1, Vec128& v2) { + auto raw = detail::LoadInterleaved3(unaligned, detail::Tuple3()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); + v2 = Vec128(raw.val[2]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void LoadInterleaved3(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1, Vec128& v2) { + // The smallest vector registers are 64-bits and we want space for three. + alignas(16) T buf[3 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved3(buf, detail::Tuple3()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); + v2 = Vec128(raw.val[2]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void LoadInterleaved3(Full128 d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, Vec128& v2) { + const Half dh; + VFromD v00, v10, v20, v01, v11, v21; + LoadInterleaved3(dh, unaligned, v00, v10, v20); + LoadInterleaved3(dh, unaligned + 3, v01, v11, v21); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved4 + +template +HWY_API void LoadInterleaved4(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1, Vec128& v2, + Vec128& v3) { + auto raw = detail::LoadInterleaved4(unaligned, detail::Tuple4()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); + v2 = Vec128(raw.val[2]); + v3 = Vec128(raw.val[3]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void LoadInterleaved4(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1, Vec128& v2, + Vec128& v3) { + alignas(16) T buf[4 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved4(buf, detail::Tuple4()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); + v2 = Vec128(raw.val[2]); + v3 = Vec128(raw.val[3]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void LoadInterleaved4(Full128 d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, Vec128& v2, + Vec128& v3) { + const Half dh; + VFromD v00, v10, v20, v30, v01, v11, v21, v31; + LoadInterleaved4(dh, unaligned, v00, v10, v20, v30); + LoadInterleaved4(dh, unaligned + 4, v01, v11, v21, v31); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); + v3 = Combine(d, v31, v30); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_LOAD_INT + +// ------------------------------ StoreInterleaved2 + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_STORE_INT +#define HWY_NEON_BUILD_RET_HWY_STORE_INT(type, size) void +#define HWY_NEON_BUILD_ARG_HWY_STORE_INT to, tup.raw + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_STORE_INT(T, N) 
HWY_IF_GE64(T, N) +#define HWY_NEON_DEF_FUNCTION_STORE_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64 +#define HWY_IF_STORE_INT(T, N) \ + hwy::EnableIf= 8 && (N == 1 || sizeof(T) < 8)>* = nullptr +#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple2 tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved2, vst2, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple3 tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved3, vst3, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple4 tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved4, vst4, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#undef HWY_NEON_DEF_FUNCTION_STORE_INT +#undef HWY_NEON_BUILD_TPL_HWY_STORE_INT +#undef HWY_NEON_BUILD_RET_HWY_STORE_INT +#undef HWY_NEON_BUILD_ARG_HWY_STORE_INT +} // namespace detail + +template +HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, + Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, + Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[2 * 8 / sizeof(T)]; + detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, buf); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, + Full128 d, T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), dh, unaligned); + StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), dh, unaligned + 2); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved3 + +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[3 * 8 / sizeof(T)]; + detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, buf); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Full128 d, + T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved3(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), dh, + unaligned); + StoreInterleaved3(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), dh, + unaligned + 3); +} +#endif // HWY_ARCH_ARM_V7 + +// 
------------------------------ StoreInterleaved4 + +template +HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, + const Vec128 v2, const Vec128 v3, + Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, + const Vec128 v2, const Vec128 v3, + Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[4 * 8 / sizeof(T)]; + detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, buf); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, + const Vec128 v2, const Vec128 v3, + Full128 d, T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved4(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), + LowerHalf(dh, v3), dh, unaligned); + StoreInterleaved4(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), + UpperHalf(dh, v3), dh, unaligned + 4); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_STORE_INT + +// ------------------------------ Lt128 + +template +HWY_INLINE Mask128 Lt128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const Mask128 eqHL = Eq(a, b); + const Vec128 ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). + const Vec128 ltLx = DupEven(ltHL); + const Vec128 outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template +HWY_INLINE Mask128 Lt128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template +HWY_INLINE Mask128 Eq128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + const Vec128 eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template +HWY_INLINE Mask128 Eq128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template +HWY_INLINE Mask128 Ne128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + const Vec128 neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template +HWY_INLINE Mask128 Ne128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. 
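+// Min128/Max128 therefore select whole 128-bit keys via IfThenElse on Lt128.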
+template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_INLINE VFromD Min128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +namespace detail { // for code folding +#if HWY_ARCH_ARM_V7 +#undef vuzp1_s8 +#undef vuzp1_u8 +#undef vuzp1_s16 +#undef vuzp1_u16 +#undef vuzp1_s32 +#undef vuzp1_u32 +#undef vuzp1_f32 +#undef vuzp1q_s8 +#undef vuzp1q_u8 +#undef vuzp1q_s16 +#undef vuzp1q_u16 +#undef vuzp1q_s32 +#undef vuzp1q_u32 +#undef vuzp1q_f32 +#undef vuzp2_s8 +#undef vuzp2_u8 +#undef vuzp2_s16 +#undef vuzp2_u16 +#undef vuzp2_s32 +#undef vuzp2_u32 +#undef vuzp2_f32 +#undef vuzp2q_s8 +#undef vuzp2q_u8 +#undef vuzp2q_s16 +#undef vuzp2q_u16 +#undef vuzp2q_s32 +#undef vuzp2q_u32 +#undef vuzp2q_f32 +#undef vzip1_s8 +#undef vzip1_u8 +#undef vzip1_s16 +#undef vzip1_u16 +#undef vzip1_s32 +#undef vzip1_u32 +#undef vzip1_f32 +#undef vzip1q_s8 +#undef vzip1q_u8 +#undef vzip1q_s16 +#undef vzip1q_u16 +#undef vzip1q_s32 +#undef vzip1q_u32 +#undef vzip1q_f32 +#undef vzip2_s8 +#undef vzip2_u8 +#undef vzip2_s16 +#undef vzip2_u16 +#undef vzip2_s32 +#undef vzip2_u32 +#undef vzip2_f32 +#undef vzip2q_s8 +#undef vzip2q_u8 +#undef vzip2q_s16 +#undef vzip2q_u16 +#undef vzip2q_s32 +#undef vzip2q_u32 +#undef vzip2q_f32 +#endif + +#undef HWY_NEON_BUILD_ARG_1 +#undef HWY_NEON_BUILD_ARG_2 +#undef HWY_NEON_BUILD_ARG_3 +#undef HWY_NEON_BUILD_PARAM_1 +#undef HWY_NEON_BUILD_PARAM_2 +#undef HWY_NEON_BUILD_PARAM_3 +#undef HWY_NEON_BUILD_RET_1 +#undef HWY_NEON_BUILD_RET_2 +#undef HWY_NEON_BUILD_RET_3 +#undef HWY_NEON_BUILD_TPL_1 +#undef HWY_NEON_BUILD_TPL_2 +#undef HWY_NEON_BUILD_TPL_3 +#undef HWY_NEON_DEF_FUNCTION +#undef HWY_NEON_DEF_FUNCTION_ALL_FLOATS +#undef HWY_NEON_DEF_FUNCTION_ALL_TYPES +#undef HWY_NEON_DEF_FUNCTION_FLOAT_64 +#undef HWY_NEON_DEF_FUNCTION_FULL_UI +#undef HWY_NEON_DEF_FUNCTION_INT_16 +#undef HWY_NEON_DEF_FUNCTION_INT_32 +#undef HWY_NEON_DEF_FUNCTION_INT_8 +#undef HWY_NEON_DEF_FUNCTION_INT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_INTS +#undef HWY_NEON_DEF_FUNCTION_INTS_UINTS +#undef HWY_NEON_DEF_FUNCTION_TPL +#undef HWY_NEON_DEF_FUNCTION_UIF81632 +#undef HWY_NEON_DEF_FUNCTION_UINT_16 +#undef HWY_NEON_DEF_FUNCTION_UINT_32 +#undef HWY_NEON_DEF_FUNCTION_UINT_8 +#undef HWY_NEON_DEF_FUNCTION_UINT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UINTS +#undef HWY_NEON_EVAL +} // namespace detail + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/arm_sve-inl.h b/third_party/highway/hwy/ops/arm_sve-inl.h new file mode 100644 index 0000000000..5b83017172 --- /dev/null +++ b/third_party/highway/hwy/ops/arm_sve-inl.h @@ -0,0 +1,3186 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ARM SVE[2] vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +// If running on hardware whose vector length is known to be a power of two, we +// can skip fixups for non-power of two sizes. +#undef HWY_SVE_IS_POW2 +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +#define HWY_SVE_IS_POW2 1 +#else +#define HWY_SVE_IS_POW2 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t>::type; + +template +using TFromV = TFromD>; + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// Unsigned: +#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 64, 32, NAME, OP) + +// Signed: +#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP) + +// Float: +#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 16, 16, NAME, OP) +#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 64, 32, NAME, OP) + +// For all element sizes: +#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + +// Commonly used type categories for a given element size: +#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + 
HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + +// Commonly used type categories: +#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +// Assemble types for use in x-macros +#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t +#define HWY_SVE_D(BASE, BITS, N, POW2) Simd +#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t + +} // namespace detail + +#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using type = ScalableTag; \ + }; + +HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _) +#undef HWY_SPECIALIZE + +// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX +// instructions, and we anyway only use it when the predicate is ptrue. + +// vector = f(vector), e.g. Not +#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } +#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +// vector = f(vector, scalar), e.g. detail::AddN +#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS(a, b, c); \ + } + +// ------------------------------ Lanes + +namespace detail { + +// Returns actual lanes of a hardware vector without rounding to a power of two. 
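+// (svcntb/svcnth/svcntw/svcntd with SV_ALL return the number of 8/16/32/64-bit
+// lanes in a full hardware vector.)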
+HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<1> /* tag */) { + return svcntb_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<2> /* tag */) { + return svcnth_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<4> /* tag */) { + return svcntw_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<8> /* tag */) { + return svcntd_pat(SV_ALL); +} + +// All-true mask from a macro +#define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL) + +#if HWY_SVE_IS_POW2 +#define HWY_SVE_PTRUE(BITS) HWY_SVE_ALL_PTRUE(BITS) +#else +#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2) + +// Returns actual lanes of a hardware vector, rounded down to a power of two. +template +HWY_INLINE size_t HardwareLanes() { + return svcntb_pat(SV_POW2); +} +template +HWY_INLINE size_t HardwareLanes() { + return svcnth_pat(SV_POW2); +} +template +HWY_INLINE size_t HardwareLanes() { + return svcntw_pat(SV_POW2); +} +template +HWY_INLINE size_t HardwareLanes() { + return svcntd_pat(SV_POW2); +} + +#endif // HWY_SVE_IS_POW2 + +} // namespace detail + +// Returns actual number of lanes after capping by N and shifting. May return 0 +// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8). +#if HWY_TARGET == HWY_SVE_256 +template +HWY_API constexpr size_t Lanes(Simd /* d */) { + return HWY_MIN(detail::ScaleByPower(32 / sizeof(T), kPow2), N); +} +#elif HWY_TARGET == HWY_SVE2_128 +template +HWY_API constexpr size_t Lanes(Simd /* d */) { + return HWY_MIN(detail::ScaleByPower(16 / sizeof(T), kPow2), N); +} +#else +template +HWY_API size_t Lanes(Simd d) { + const size_t actual = detail::HardwareLanes(); + // Common case of full vectors: avoid any extra instructions. + if (detail::IsFull(d)) return actual; + return HWY_MIN(detail::ScaleByPower(actual, kPow2), N); +} +#endif // HWY_TARGET + +// ================================================== MASK INIT + +// One mask bit per byte; only the one belonging to the lowest byte is valid. + +// ------------------------------ FirstN +#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) { \ + const size_t limit = detail::IsFull(d) ? count : HWY_MIN(Lanes(d), count); \ + return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast(limit)); \ + } +HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt) +#undef HWY_SVE_FIRSTN + +template +using MFromD = decltype(FirstN(D(), 0)); + +namespace detail { + +#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_PTRUE(BITS); \ + } \ + template \ + HWY_API svbool_t All##NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_ALL_PTRUE(BITS); \ + } + +HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true +#undef HWY_SVE_WRAP_PTRUE + +HWY_API svbool_t PFalse() { return svpfalse_b(); } + +// Returns all-true if d is HWY_FULL or FirstN(N) after capping N. +// +// This is used in functions that load/store memory; other functions (e.g. +// arithmetic) can ignore d and use PTrue instead. +template +svbool_t MakeMask(D d) { + return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d)); +} + +} // namespace detail + +// ================================================== INIT + +// ------------------------------ Set +// vector = f(d, scalar), e.g. 
Set +#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) arg) { \ + return sv##OP##_##CHAR##BITS(arg); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) +#undef HWY_SVE_SET + +// Required for Zero and VFromD +template +svuint16_t Set(Simd d, bfloat16_t arg) { + return Set(RebindToUnsigned(), arg.bits); +} + +template +using VFromD = decltype(Set(D(), TFromD())); + +// ------------------------------ Zero + +template +VFromD Zero(D d) { + // Cast to support bfloat16_t. + const RebindToUnsigned du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ Undefined + +#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return sv##OP##_##CHAR##BITS(); \ + } + +HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) + +// ------------------------------ BitCast + +namespace detail { + +// u8: no change +#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } + +// All other types +#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u8_##CHAR##BITS(v); \ + } \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \ + return sv##OP##_##CHAR##BITS##_u8(v); \ + } + +HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _) +HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret) + +#undef HWY_SVE_CAST_NOP +#undef HWY_SVE_CAST + +template +HWY_INLINE svuint16_t BitCastFromByte(Simd /* d */, + svuint8_t v) { + return BitCastFromByte(Simd(), v); +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ================================================== LOGICAL + +// detail::*N() functions accept a scalar argument to avoid extra Set(). 
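+// e.g. detail::AddN(v, 1) maps to svadd_n_*_x and thus avoids a separate
+// Set(d, 1).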
+ +// ------------------------------ Not +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not ) // NOLINT + +// ------------------------------ And + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and) + +template +HWY_API V And(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, And(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Or + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr) + +template +HWY_API V Or(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Or(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Xor + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor) + +template +HWY_API V Xor(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Xor(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ AndNot + +namespace detail { +#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n) +#undef HWY_SVE_RETV_ARGPVN_SWAP +} // namespace detail + +#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic) +#undef HWY_SVE_RETV_ARGPVV_SWAP + +template +HWY_API V AndNot(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Xor3 + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3) + +template +HWY_API V Xor3(const V x1, const V x2, const V x3) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); +} + +#else +template +HWY_API V Xor3(V x1, V x2, V x3) { + return Xor(x1, Xor(x2, x3)); +} +#endif + +// ------------------------------ Or3 +template +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +// Need to return original type instead of unsigned. 
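+// (svcnt returns per-lane counts as an unsigned vector, hence the BitCast back
+// to the original lane type below.)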
+#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return BitCast(DFromV(), \ + sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt) +#undef HWY_SVE_POPCNT + +// ================================================== SIGN + +// ------------------------------ Neg +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg) + +// ------------------------------ Abs +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) + +// ------------------------------ CopySign[ToAbs] + +template +HWY_API V CopySign(const V magn, const V sign) { + const auto msb = SignBit(DFromV()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API V CopySignToAbs(const V abs, const V sign) { + const auto msb = SignBit(DFromV()); + return Or(abs, And(msb, sign)); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Add + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n) +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add) + +// ------------------------------ Sub + +namespace detail { +// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg. +#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_z(pg, a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n) +#undef HWY_SVE_RETV_ARGPVN_MASK +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub) + +// ------------------------------ SumsOf8 +HWY_API svuint64_t SumsOf8(const svuint8_t v) { + const ScalableTag du32; + const ScalableTag du64; + const svbool_t pg = detail::PTrue(du64); + + const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1); + // Compute pairwise sum of u32 and extend to u64. + // TODO(janwas): on SVE2, we can instead use svaddp. 
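+ // svdot with an all-ones (1) multiplier sums each group of four consecutive
+ // u8 into a u32 lane; the shift/extend below adds the two u32 halves of each
+ // u64 lane, yielding the sum of 8 consecutive bytes.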
+ const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended) + const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4)); + return Add(hi, lo); +} + +// ------------------------------ SaturatedAdd + +HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) + +// ------------------------------ SaturatedSub + +HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) + +// ------------------------------ AbsDiff +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPVV, AbsDiff, abd) + +// ------------------------------ ShiftLeft[Same] + +#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ + } \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n) + +// ------------------------------ ShiftRight[Same] + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n) + +#undef HWY_SVE_SHIFT_N + +// ------------------------------ RotateRight + +// TODO(janwas): svxar on SVE2 +template +HWY_API V RotateRight(const V v) { + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shl/r + +#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToUnsigned> du; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + BitCast(du, bits)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl) + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr) + +#undef HWY_SVE_SHIFT + +// ------------------------------ Min/Max + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm) + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n) +} // namespace detail + +// ------------------------------ Mul +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, Mul, mul) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_RETV_ARGPVV, Mul, mul) + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. 
+#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// ------------------------------ MulHigh +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) +// Not part of API, used internally: +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) +HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) + +// ------------------------------ MulFixedPoint15 +HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + return svqrdmulh_s16(a, b); +#else + const DFromV d; + const RebindToUnsigned du; + + const svuint16_t lo = BitCast(du, Mul(a, b)); + const svint16_t hi = MulHigh(a, b); + // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must + // carry that into the result. Instead isolate the top two bits because only + // they can influence the result. + const svuint16_t lo_top2 = ShiftRight<14>(lo); + // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0. + const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1)); + return Add(Add(hi, hi), BitCast(d, rounding)); +#endif +} + +// ------------------------------ Div +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div) + +// ------------------------------ ApproximateReciprocal +HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe) + +// ------------------------------ Sqrt +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt) + +// ------------------------------ ApproximateReciprocalSqrt +HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte) + +// ------------------------------ MulAdd +#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ + HWY_SVE_V(BASE, BITS) add) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulAdd, mad) + +// ------------------------------ NegMulAdd +HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulAdd, msb) + +// ------------------------------ MulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb) + +// ------------------------------ NegMulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad) + +#undef HWY_SVE_FMA + +// ------------------------------ Round etc. + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz) + +// ================================================== MASK + +// ------------------------------ RebindMask +template +HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) { + return mask; +} + +// ------------------------------ Mask logical + +HWY_API svbool_t Not(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Per ARM, such bits are ignored. + return svnot_b_z(HWY_SVE_PTRUE(8), m); +} +HWY_API svbool_t And(svbool_t a, svbool_t b) { + return svand_b_z(b, b, a); // same order as AndNot for consistency +} +HWY_API svbool_t AndNot(svbool_t a, svbool_t b) { + return svbic_b_z(b, b, a); // reversed order like NEON +} +HWY_API svbool_t Or(svbool_t a, svbool_t b) { + return svsel_b(a, a, b); // a ? true : b +} +HWY_API svbool_t Xor(svbool_t a, svbool_t b) { + return svsel_b(a, svnand_b_z(a, a, b), b); // a ? !(a & b) : b. 
+} + +HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) { + return svnor_b_z(HWY_SVE_PTRUE(8), a, b); // !a && !b, undefined if a && b. +} + +// ------------------------------ CountTrue + +#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \ + return sv##OP##_b##BITS(detail::MakeMask(d), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp) +#undef HWY_SVE_COUNT_TRUE + +// For 16-bit Compress: full vector, not limited to SV_POW2. +namespace detail { + +#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \ + return sv##OP##_b##BITS(svptrue_b##BITS(), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp) +#undef HWY_SVE_COUNT_TRUE_FULL + +} // namespace detail + +// ------------------------------ AllFalse +template +HWY_API bool AllFalse(D d, svbool_t m) { + return !svptest_any(detail::MakeMask(d), m); +} + +// ------------------------------ AllTrue +template +HWY_API bool AllTrue(D d, svbool_t m) { + return CountTrue(d, m) == Lanes(d); +} + +// ------------------------------ FindFirstTrue +template +HWY_API intptr_t FindFirstTrue(D d, svbool_t m) { + return AllFalse(d, m) ? intptr_t{-1} + : static_cast( + CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m))); +} + +// ------------------------------ FindKnownFirstTrue +template +HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) { + return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m)); +} + +// ------------------------------ IfThenElse +#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(m, yes, no); \ + } + +HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) +#undef HWY_SVE_IF_THEN_ELSE + +// ------------------------------ IfThenElseZero +template +HWY_API V IfThenElseZero(const svbool_t mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV())); +} + +// ------------------------------ IfThenZeroElse +template +HWY_API V IfThenZeroElse(const svbool_t mask, const V no) { + return IfThenElse(mask, Zero(DFromV()), no); +} + +// ================================================== COMPARE + +// mask = f(vector, vector) +#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } + +// ------------------------------ Eq +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n) +} // namespace detail + +// ------------------------------ Ne +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n) +} // namespace detail + +// ------------------------------ Lt +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n) +} // namespace detail + +// ------------------------------ Le +HWY_SVE_FOREACH_F(HWY_SVE_COMPARE, Le, cmple) + +#undef HWY_SVE_COMPARE +#undef HWY_SVE_COMPARE_N + +// ------------------------------ Gt/Ge (swapped order) 
+template +HWY_API svbool_t Gt(const V a, const V b) { + return Lt(b, a); +} +template +HWY_API svbool_t Ge(const V a, const V b) { + return Le(b, a); +} + +// ------------------------------ TestBit +template +HWY_API svbool_t TestBit(const V a, const V bit) { + return detail::NeN(And(a, bit), 0); +} + +// ------------------------------ MaskFromVec (Ne) +template +HWY_API svbool_t MaskFromVec(const V v) { + return detail::NeN(v, static_cast>(0)); +} + +// ------------------------------ VecFromMask +template +HWY_API VFromD VecFromMask(const D d, svbool_t mask) { + const RebindToSigned di; + // This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which + // requires an extra instruction plus M0 pipeline. + return BitCast(d, IfThenElseZero(mask, Set(di, -1))); +} + +// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +#define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \ + HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(yes, no, mask); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl) +#undef HWY_SVE_IF_VEC + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +#else + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +#endif // HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +// ------------------------------ Floating-point classification (Ne) + +template +HWY_API svbool_t IsNaN(const V v) { + return Ne(v, v); // could also use cmpuo +} + +template +HWY_API svbool_t IsInf(const V v) { + using T = TFromV; + const DFromV d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, detail::EqN(Add(vi, vi), hwy::MaxExponentTimes2())); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API svbool_t IsFinite(const V v) { + using T = TFromV; + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField())); +} + +// ================================================== MEMORY + +// ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream + +#define HWY_SVE_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return sv##OP##_##CHAR##BITS(detail::MakeMask(d), p); \ + } + +#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return sv##OP##_##CHAR##BITS(m, p); \ + } + +#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + /* All-true predicate to load all 128 bits. */ \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p); \ + } + +#define HWY_SVE_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), p, v); \ + } + +#define HWY_SVE_BLENDED_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(m, p, v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1) +HWY_SVE_FOREACH(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1) +HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDup128, ld1rq) +HWY_SVE_FOREACH(HWY_SVE_STORE, Store, st1) +HWY_SVE_FOREACH(HWY_SVE_STORE, Stream, stnt1) +HWY_SVE_FOREACH(HWY_SVE_BLENDED_STORE, BlendedStore, st1) + +#undef HWY_SVE_LOAD +#undef HWY_SVE_MASKED_LOAD +#undef HWY_SVE_LOAD_DUP128 +#undef HWY_SVE_STORE +#undef HWY_SVE_BLENDED_STORE + +// BF16 is the same as svuint16_t because BF16 is optional before v8.6. +template +HWY_API svuint16_t Load(Simd d, + const bfloat16_t* HWY_RESTRICT p) { + return Load(RebindToUnsigned(), + reinterpret_cast(p)); +} + +template +HWY_API void Store(svuint16_t v, Simd d, + bfloat16_t* HWY_RESTRICT p) { + Store(v, RebindToUnsigned(), + reinterpret_cast(p)); +} + +// ------------------------------ Load/StoreU + +// SVE only requires lane alignment, not natural alignment of the entire +// vector. 
+template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +template +HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ ScatterOffset/Index + +#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \ + v); \ + } + +#define HWY_SVE_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, HWY_SVE_V(int, BITS) index) { \ + sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, index, v); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_INDEX, ScatterIndex, st1_scatter) +#undef HWY_SVE_SCATTER_OFFSET +#undef HWY_SVE_SCATTER_INDEX + +// ------------------------------ GatherOffset/Index + +#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \ + offset); \ + } +#define HWY_SVE_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) index) { \ + return sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, \ + index); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_INDEX, GatherIndex, ld1_gather) +#undef HWY_SVE_GATHER_OFFSET +#undef HWY_SVE_GATHER_INDEX + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. 
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +#define HWY_SVE_LOAD2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \ + const sv##BASE##BITS##x2_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget2(tuple, 0); \ + v1 = svget2(tuple, 1); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) + +#undef HWY_SVE_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_SVE_LOAD3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2) { \ + const sv##BASE##BITS##x3_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget3(tuple, 0); \ + v1 = svget3(tuple, 1); \ + v2 = svget3(tuple, 2); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) + +#undef HWY_SVE_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_SVE_LOAD4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \ + const sv##BASE##BITS##x4_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget4(tuple, 0); \ + v1 = svget4(tuple, 1); \ + v2 = svget4(tuple, 2); \ + v3 = svget4(tuple, 3); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4) + +#undef HWY_SVE_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x2_t tuple = svcreate2##_##CHAR##BITS(v0, v1); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, tuple); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) + +#undef HWY_SVE_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x3_t triple = svcreate3##_##CHAR##BITS(v0, v1, v2); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, triple); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) + +#undef HWY_SVE_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x4_t quad = \ + svcreate4##_##CHAR##BITS(v0, v1, v2, v3); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, quad); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4) + +#undef HWY_SVE_STORE4 + +// 
================================================== CONVERT + +// ------------------------------ PromoteTo + +// Same sign +#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) + +// 2x +template +HWY_API svuint32_t PromoteTo(Simd dto, svuint8_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template +HWY_API svint32_t PromoteTo(Simd dto, svint8_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} + +// Sign change +template +HWY_API svint16_t PromoteTo(Simd dto, svuint8_t vfrom) { + const RebindToUnsigned du; + return BitCast(dto, PromoteTo(du, vfrom)); +} +template +HWY_API svint32_t PromoteTo(Simd dto, svuint16_t vfrom) { + const RebindToUnsigned du; + return BitCast(dto, PromoteTo(du, vfrom)); +} +template +HWY_API svint32_t PromoteTo(Simd dto, svuint8_t vfrom) { + const Repartition> du16; + const Repartition di16; + return PromoteTo(dto, BitCast(di16, PromoteTo(du16, vfrom))); +} + +// ------------------------------ PromoteTo F + +// Unlike Highway's ZipLower, this returns the same type. +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLowerSame, zip1) +} // namespace detail + +template +HWY_API svfloat32_t PromoteTo(Simd /* d */, + const svfloat16_t v) { + // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so + // first replicate each lane once. + const svfloat16_t vv = detail::ZipLowerSame(v, v); + return svcvt_f32_f16_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_f32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svint32_t v) { + const svint32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_s32_x(detail::PTrue(Simd()), vv); +} + +// For 16-bit Compress +namespace detail { +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +#undef HWY_SVE_PROMOTE_TO + +template +HWY_API svfloat32_t PromoteUpperTo(Simd df, svfloat16_t v) { + const RebindToUnsigned du; + const RepartitionToNarrow dn; + return BitCast(df, PromoteUpperTo(du, BitCast(dn, v))); +} + +} // namespace detail + +// ------------------------------ DemoteTo U + +namespace detail { + +// Saturates unsigned vectors to half/quarter-width TN. +template +VU SaturateU(VU v) { + return detail::MinN(v, static_cast>(LimitsMax())); +} + +// Saturates unsigned vectors to half/quarter-width TN. +template +VI SaturateI(VI v) { + return detail::MinN(detail::MaxN(v, LimitsMin()), LimitsMax()); +} + +} // namespace detail + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint16_t v) { + const DFromV di; + const RebindToUnsigned du; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. 
+ const svuint8_t vn = BitCast(dn, detail::SaturateU(clamped)); + return svuzp1_u8(vn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svint32_t v) { + const DFromV di; + const RebindToUnsigned du; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint16_t vn = BitCast(dn, detail::SaturateU(clamped)); + return svuzp1_u16(vn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint32_t v) { + const DFromV di; + const RebindToUnsigned du; + const RepartitionToNarrow d2; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and quarter the width. + const svuint16_t cast16 = BitCast(d2, detail::SaturateU(clamped)); + const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16)); + return svuzp1_u8(x2, x2); +} + +HWY_API svuint8_t U8FromU32(const svuint32_t v) { + const DFromV du32; + const RepartitionToNarrow du16; + const RepartitionToNarrow du8; + + const svuint16_t cast16 = BitCast(du16, v); + const svuint16_t x2 = svuzp1_u16(cast16, cast16); + const svuint8_t cast8 = BitCast(du8, x2); + return svuzp1_u8(cast8, cast8); +} + +// ------------------------------ Truncations + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + const svuint8_t v3 = svuzp1_u8(v2, v2); + return svuzp1_u8(v3, v3); +} + +template +HWY_API svuint16_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint16_t v1 = BitCast(d, v); + const svuint16_t v2 = svuzp1_u16(v1, v1); + return svuzp1_u16(v2, v2); +} + +template +HWY_API svuint32_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint32_t v1 = BitCast(d, v); + return svuzp1_u32(v1, v1); +} + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint32_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + return svuzp1_u8(v2, v2); +} + +template +HWY_API svuint16_t TruncateTo(Simd /* tag */, + const svuint32_t v) { + const DFromV d; + const svuint16_t v1 = BitCast(d, v); + return svuzp1_u16(v1, v1); +} + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint16_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + return svuzp1_u8(v1, v1); +} + +// ------------------------------ DemoteTo I + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint16_t v) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + const svint8_t vn = BitCast(dn, svqxtnb_s16(v)); +#else + using TN = TFromD; + const svint8_t vn = BitCast(dn, detail::SaturateI(v)); +#endif + return svuzp1_s8(vn, vn); +} + +template +HWY_API svint16_t DemoteTo(Simd dn, const svint32_t v) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + const svint16_t vn = BitCast(dn, svqxtnb_s32(v)); +#else + using TN = TFromD; + const svint16_t vn = BitCast(dn, detail::SaturateI(v)); +#endif + return svuzp1_s16(vn, vn); +} + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint32_t v) { + const RepartitionToWide d2; +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v))); +#else + using TN = TFromD; + const svint16_t cast16 = BitCast(d2, 
detail::SaturateI(v)); +#endif + const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16)); + return BitCast(dn, svuzp1_s8(v2, v2)); +} + +// ------------------------------ ConcatEven/ConcatOdd + +// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the +// full vector length, not rounded down to a power of two as we require). +namespace detail { + +#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2) +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) +#endif +#undef HWY_SVE_CONCAT_EVERY_SECOND + +// Used to slide up / shift whole register left; mask indicates which range +// to take from lo, and the rest is filled from hi starting at its lowest. +#define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice) +#undef HWY_SVE_SPLICE + +} // namespace detail + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_SVE_IS_POW2 + (void)d; + return detail::ConcatOddFull(hi, lo); +#else + const VFromD hi_odd = detail::ConcatOddFull(hi, hi); + const VFromD lo_odd = detail::ConcatOddFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_SVE_IS_POW2 + (void)d; + return detail::ConcatEvenFull(hi, lo); +#else + const VFromD hi_odd = detail::ConcatEvenFull(hi, hi); + const VFromD lo_odd = detail::ConcatEvenFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +#endif +} + +// ------------------------------ DemoteTo F + +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat32_t v) { + const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svuint16_t DemoteTo(Simd /* d */, svfloat32_t v) { + const svuint16_t in_even = BitCast(ScalableTag(), v); + return detail::ConcatOddFull(in_even, in_even); // lower half +} + +template +HWY_API svfloat32_t DemoteTo(Simd d, const svfloat64_t v) { + const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svint32_t DemoteTo(Simd d, const svfloat64_t v) { + const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +// ------------------------------ ConvertTo F + +#define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ + /* signed integers */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* unsigned integers */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Truncates 
(rounds toward zero). */ \ + template \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } + +// API only requires f32 but we provide f64 for use by Iota. +HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) +#undef HWY_SVE_CONVERT + +// ------------------------------ NearestInt (Round, ConvertTo) +template >> +HWY_API VFromD NearestInt(VF v) { + // No single instruction, round then truncate. + return ConvertTo(DI(), Round(v)); +} + +// ------------------------------ Iota (Add, ConvertTo) + +#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) first) { \ + return sv##OP##_##CHAR##BITS(first, 1); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) +#undef HWY_SVE_IOTA + +template +HWY_API VFromD Iota(const D d, TFromD first) { + const RebindToSigned di; + return detail::AddN(ConvertTo(d, Iota(di, 0)), first); +} + +// ------------------------------ InterleaveLower + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipLowerSame(a, b); +#else + // Move lower halves of blocks to lower half of vector. + const Repartition d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatEvenFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatEvenFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +template +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV(), a, b); +} + +// ------------------------------ InterleaveUpper + +// Only use zip2 if vector are a powers of two, otherwise getting the actual +// "upper half" requires MaskUpperHalf. +#if HWY_TARGET == HWY_SVE2_128 +namespace detail { +// Unlike Highway's ZipUpper, this returns the same type. +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2) +} // namespace detail +#endif + +// Full vector: guaranteed to have at least one block +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipUpperSame(a, b); +#else + // Move upper halves of blocks to lower half of vector. 
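+  // (Viewing the vector as u64 lanes, ConcatOddFull keeps the odd lanes,
+  // which are precisely the upper 64-bit half of every 128-bit block.)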
+ const Repartition d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatOddFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatOddFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +// Capped/fraction: need runtime check +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { + // Less than one block: treat as capped + if (Lanes(d) * sizeof(TFromD) < 16) { + const Half d2; + return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); + } + return InterleaveUpper(DFromV(), a, b); +} + +// ================================================== COMBINE + +namespace detail { + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 32: + return svptrue_pat_b8(SV_VL16); + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + default: + return svptrue_pat_b8(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 16: + return svptrue_pat_b16(SV_VL8); + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + default: + return svptrue_pat_b16(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 8: + return svptrue_pat_b32(SV_VL4); + case 4: + return svptrue_pat_b32(SV_VL2); + default: + return svptrue_pat_b32(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 4: + return svptrue_pat_b64(SV_VL2); + default: + return svptrue_pat_b64(SV_VL1); + } +} +#endif +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b8(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b16(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + return svptrue_pat_b32(Lanes(d) == 4 ? SV_VL2 : SV_VL1); +} +template +svbool_t MaskLowerHalf(D /*d*/) { + return svptrue_pat_b64(SV_VL1); +} +#endif // HWY_TARGET == HWY_SVE2_128 +#if HWY_TARGET != HWY_SVE_256 && HWY_TARGET != HWY_SVE2_128 +template +svbool_t MaskLowerHalf(D d) { + return FirstN(d, Lanes(d) / 2); +} +#endif + +template +svbool_t MaskUpperHalf(D d) { + // TODO(janwas): WHILEGE on pow2 SVE2 + if (HWY_SVE_IS_POW2 && IsFull(d)) { + return Not(MaskLowerHalf(d)); + } + + // For Splice to work as intended, make sure bits above Lanes(d) are zero. + return AndNot(MaskLowerHalf(d), detail::MakeMask(d)); +} + +// Right-shift vector pair by constexpr; can be used to slide down (=N) or up +// (=Lanes()-N). 
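+// For example, Ext<1>(hi, lo) yields {lo[1], .., lo[Lanes()-1], hi[0]}.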
+#define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi, kIndex); \ + } +HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext) +#undef HWY_SVE_EXT + +} // namespace detail + +// ------------------------------ ConcatUpperLower +template +HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) { + return IfThenElse(detail::MaskLowerHalf(d), lo, hi); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatEvenBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveLower(du64, lo64, BitCast(du64, hi))); +#endif + } + return detail::Splice(hi, lo, detail::MaskLowerHalf(d)); +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 // constexpr Lanes + if (detail::IsFull(d)) { + return detail::Ext(hi, lo); + } +#endif + return detail::Splice(hi, lo, detail::MaskUpperHalf(d)); +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatOddBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveUpper(du64, lo64, BitCast(du64, hi))); +#endif + } + const svbool_t mask_upper = detail::MaskUpperHalf(d); + const V lo_upper = detail::Splice(lo, lo, mask_upper); + return IfThenElse(mask_upper, hi, lo_upper); +} + +// ------------------------------ Combine +template +HWY_API VFromD Combine(const D d, const V2 hi, const V2 lo) { + return ConcatLowerLower(d, hi, lo); +} + +// ------------------------------ ZeroExtendVector +template +HWY_API V ZeroExtendVector(const D d, const V lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ Lower/UpperHalf + +template +HWY_API V LowerHalf(D2 /* tag */, const V v) { + return v; +} + +template +HWY_API V LowerHalf(const V v) { + return v; +} + +template +HWY_API V UpperHalf(const DH dh, const V v) { + const Twice d; + // Cast so that we support bfloat16_t. + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 // constexpr Lanes + return BitCast(d, detail::Ext(vu, vu)); +#else + const MFromD mask = detail::MaskUpperHalf(du); + return BitCast(d, detail::Splice(vu, vu, mask)); +#endif +} + +// ================================================== REDUCE + +// These return T, whereas the Highway op returns a broadcasted vector. +namespace detail { +#define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + /* The intrinsic returns [u]int64_t; truncate to T so we can broadcast. 
*/ \ + using T = HWY_SVE_T(BASE, BITS); \ + using TU = MakeUnsigned; \ + constexpr uint64_t kMask = LimitsMax(); \ + return static_cast(static_cast( \ + static_cast(sv##OP##_##CHAR##BITS(pg, v)) & kMask)); \ + } + +#define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(pg, v); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv) + +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanesM, minv) +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanesM, maxv) +// NaN if all are +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanesM, minnmv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv) + +#undef HWY_SVE_REDUCE +#undef HWY_SVE_REDUCE_ADD +} // namespace detail + +template +V SumOfLanes(D d, V v) { + return Set(d, detail::SumOfLanesM(detail::MakeMask(d), v)); +} + +template +V MinOfLanes(D d, V v) { + return Set(d, detail::MinOfLanesM(detail::MakeMask(d), v)); +} + +template +V MaxOfLanes(D d, V v) { + return Set(d, detail::MaxOfLanesM(detail::MakeMask(d), v)); +} + + +// ================================================== SWIZZLE + +// ------------------------------ GetLane + +namespace detail { +#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_T(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLaneM, lasta) +#undef HWY_SVE_GET_LANE +} // namespace detail + +template +HWY_API TFromV GetLane(V v) { + return detail::GetLaneM(v, detail::PFalse()); +} + +// ------------------------------ ExtractLane +template +HWY_API TFromV ExtractLane(V v, size_t i) { + return detail::GetLaneM(v, FirstN(DFromV(), i)); +} + +// ------------------------------ InsertLane (IfThenElse) +template +HWY_API V InsertLane(const V v, size_t i, TFromV t) { + const DFromV d; + const auto is_i = detail::EqN(Iota(d, 0), static_cast>(i)); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// ------------------------------ DupEven + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) +} // namespace detail + +template +HWY_API V DupEven(const V v) { + return detail::InterleaveEven(v, v); +} + +// ------------------------------ DupOdd + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) +} // namespace detail + +template +HWY_API V DupOdd(const V v) { + return detail::InterleaveOdd(v, v); +} + +// ------------------------------ OddEven + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +#define HWY_SVE_ODD_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) odd, HWY_SVE_V(BASE, BITS) even) { \ + return sv##OP##_##CHAR##BITS(even, odd, /*xor=*/0); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ODD_EVEN, OddEven, eortb_n) +#undef HWY_SVE_ODD_EVEN + +template +HWY_API V OddEven(const V odd, const V even) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, OddEven(BitCast(du, odd), BitCast(du, even))); +} + +#else + +template +HWY_API V OddEven(const V odd, const V even) { + const auto odd_in_even = detail::Ext<1>(odd, odd); + return detail::InterleaveEven(even, odd_in_even); +} + +#endif // HWY_TARGET + +// ------------------------------ OddEvenBlocks +template +HWY_API V OddEvenBlocks(const V odd, const V even) { + const DFromV d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatUpperLower(d, odd, even); 
+#elif HWY_TARGET == HWY_SVE2_128 + (void)odd; + (void)d; + return even; +#else + const RebindToUnsigned du; + using TU = TFromD; + constexpr size_t kShift = CeilLog2(16 / sizeof(TU)); + const auto idx_block = ShiftRight(Iota(du, 0)); + const auto lsb = detail::AndN(idx_block, static_cast(1)); + const svbool_t is_even = detail::EqN(lsb, static_cast(0)); + return IfThenElse(is_even, even, odd); +#endif +} + +// ------------------------------ TableLookupLanes + +template +HWY_API VFromD> IndicesFromVec(D d, VI vec) { + using TI = TFromV; + static_assert(sizeof(TFromD) == sizeof(TI), "Index/lane size mismatch"); + const RebindToUnsigned du; + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllTrue(du, detail::LtN(indices, static_cast(Lanes(d))))); +#else + (void)d; +#endif + return indices; +} + +template +HWY_API VFromD> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +// <32bit are not part of Highway API, but used in Broadcast. +#define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \ + return sv##OP##_##CHAR##BITS(v, idx); \ + } + +HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) +#undef HWY_SVE_TABLE + +// ------------------------------ SwapAdjacentBlocks (TableLookupLanes) + +namespace detail { + +template +constexpr size_t LanesPerBlock(Simd /* tag */) { + // We might have a capped vector smaller than a block, so honor that. + return HWY_MIN(16 / sizeof(T), detail::ScaleByPower(N, kPow2)); +} + +} // namespace detail + +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatLowerUpper(d, v, v); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#else + const RebindToUnsigned du; + constexpr auto kLanesPerBlock = + static_cast>(detail::LanesPerBlock(d)); + const VFromD idx = detail::XorN(Iota(du, 0), kLanesPerBlock); + return TableLookupLanes(v, idx); +#endif +} + +// ------------------------------ Reverse + +namespace detail { + +#define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev) +#undef HWY_SVE_REVERSE + +} // namespace detail + +template +HWY_API V Reverse(D d, V v) { + using T = TFromD; + const auto reversed = detail::ReverseFull(v); + if (HWY_SVE_IS_POW2 && detail::IsFull(d)) return reversed; + // Shift right to remove extra (non-pow2 and remainder) lanes. + // TODO(janwas): on SVE2, use WHILEGE. + // Avoids FirstN truncating to the return vector size. Must also avoid Not + // because that is limited to SV_POW2. 
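+  // The mask built below covers the top Lanes(d) hardware lanes of 'reversed'
+  // (the reversal of v's valid lanes); Splice slides that contiguous run down
+  // so it starts at lane 0.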
+ const ScalableTag dfull; + const svbool_t all_true = detail::AllPTrue(dfull); + const size_t all_lanes = detail::AllHardwareLanes(hwy::SizeTag()); + const svbool_t mask = + svnot_b_z(all_true, FirstN(dfull, all_lanes - Lanes(d))); + return detail::Splice(reversed, reversed, mask); +} + +// ------------------------------ Reverse2 + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { // 3210 +#if HWY_TARGET == HWY_SVE2_128 + if (detail::IsFull(d)) { + return detail::Ext<1>(v, v); + } +#endif + (void)d; + const auto odd_in_even = detail::Ext<1>(v, v); // x321 + return detail::InterleaveEven(odd_in_even, v); // 2301 +} +// ------------------------------ Reverse4 (TableLookupLanes) +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + if (HWY_TARGET == HWY_SVE_256 && sizeof(TFromD) == 8 && + detail::IsFull(d)) { + return detail::ReverseFull(v); + } + // TODO(janwas): is this approach faster than Shuffle0123? + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 3); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse8 (TableLookupLanes) +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 7); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Compress (PromoteTo) + +template +struct CompressIsPartition { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + // Optimization for 64-bit lanes (could also be applied to 32-bit, but that + // requires a larger table). + enum { value = (sizeof(T) == 8) }; +#else + enum { value = 0 }; +#endif // HWY_TARGET == HWY_SVE_256 +}; + +#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +HWY_SVE_FOREACH_UI32(HWY_SVE_COMPRESS, Compress, compact) +HWY_SVE_FOREACH_F32(HWY_SVE_COMPRESS, Compress, compact) +#else +HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact) +#endif +#undef HWY_SVE_COMPRESS + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +HWY_API V Compress(V v, svbool_t mask) { + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. + alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompress64x4Tables + 0, 1, 2, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 2, 0, 1, 3, 0, 2, + 1, 3, 1, 2, 0, 3, 0, 1, 2, 3, 3, 0, 1, 2, 0, 3, 1, 2, 1, 3, 0, 2, + 0, 1, 3, 2, 2, 3, 0, 1, 0, 2, 3, 1, 1, 2, 3, 0, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +} + +#endif // HWY_TARGET == HWY_SVE_256 +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE +template +HWY_API V Compress(V v, svbool_t mask) { + // If mask == 10: swap via splice. 
A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot keep 10 + // unchanged and map everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(maskLL, mask)); +} + +#endif // HWY_TARGET == HWY_SVE2_128 + +template +HWY_API V Compress(V v, svbool_t mask16) { + static_assert(!IsSame(), "Must use overload"); + const DFromV d16; + + // Promote vector and mask to 32-bit + const RepartitionToWide dw; + const auto v32L = PromoteTo(dw, v); + const auto v32H = detail::PromoteUpperTo(dw, v); + const svbool_t mask32L = svunpklo_b(mask16); + const svbool_t mask32H = svunpkhi_b(mask16); + + const auto compressedL = Compress(v32L, mask32L); + const auto compressedH = Compress(v32H, mask32H); + + // Demote to 16-bit (already in range) - separately so we can splice + const V evenL = BitCast(d16, compressedL); + const V evenH = BitCast(d16, compressedH); + const V v16L = detail::ConcatEvenFull(evenL, evenL); // lower half + const V v16H = detail::ConcatEvenFull(evenH, evenH); + + // We need to combine two vectors of non-constexpr length, so the only option + // is Splice, which requires us to synthesize a mask. NOTE: this function uses + // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt. + const size_t countL = detail::CountTrueFull(dw, mask32L); + const auto compressed_maskL = FirstN(d16, countL); + return detail::Splice(v16H, v16L, compressed_maskL); +} + +// Must treat float16_t as integers so we can ConcatEven. +HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) { + const DFromV df; + const RebindToSigned di; + return BitCast(df, Compress(BitCast(di, v), mask16)); +} + +// ------------------------------ CompressNot + +// 2 or 4 bytes +template , HWY_IF_LANE_SIZE_ONE_OF(T, 0x14)> +HWY_API V CompressNot(V v, const svbool_t mask) { + return Compress(v, Not(mask)); +} + +template +HWY_API V CompressNot(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE + // If mask == 01: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot map + // 01 to 10, and everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(mask, maskLL)); +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. 
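+  // 'offset' equals 4 * mask_bits (lane i contributed 4 << i above), so it
+  // selects one 4-entry permutation group of the table below.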
+ alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompressNot64x4Tables + 0, 1, 2, 3, 1, 2, 3, 0, 0, 2, 3, 1, 2, 3, 0, 1, 0, 1, 3, 2, 1, 3, + 0, 2, 0, 3, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 0, 3, 0, 2, 1, 3, + 2, 0, 1, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif // HWY_TARGET == HWY_SVE_256 + + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API svuint64_t CompressBlocksNot(svuint64_t v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 + (void)mask; + return v; +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + uint64_t bits = 0; // predicate reg is 32-bit + CopyBytes<4>(&mask, &bits); // not same size - 64-bit more efficient + // Concatenate LSB for upper and lower blocks, pre-scale by 4 for table idx. + const size_t offset = ((bits & 1) ? 4u : 0u) + ((bits & 0x10000) ? 8u : 0u); + // See CompressIsPartition. Manually generated; flip halves if mask = [0, 1]. + alignas(16) static constexpr uint64_t table[4 * 4] = {0, 1, 2, 3, 2, 3, 0, 1, + 0, 1, 2, 3, 0, 1, 2, 3}; + const ScalableTag d; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif + + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const V v, const svbool_t mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const V v, const svbool_t mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + const svbool_t store_mask = FirstN(d, count); + BlendedStore(Compress(v, mask), store_mask, d, unaligned); + return count; +} + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes + +// Prevent accidentally using these for 128-bit vectors - should not be +// necessary. +#if HWY_TARGET != HWY_SVE2_128 +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. 
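+// For example, with 8-bit lanes OffsetsOf128BitBlocks returns
+// {0, 0, .., 0, 16, 16, .., 16, 32, ..}: the first lane index of each block.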
+template +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return detail::AndNotN(static_cast(LanesPerBlock(d) - 1), iota0); +} + +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint8_t idx_mod = + svdupq_n_u8(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock, 8 % kLanesPerBlock, + 9 % kLanesPerBlock, 10 % kLanesPerBlock, 11 % kLanesPerBlock, + 12 % kLanesPerBlock, 13 % kLanesPerBlock, 14 % kLanesPerBlock, + 15 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint16_t idx_mod = + svdupq_n_u16(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint32_t idx_mod = + svdupq_n_u32(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint64_t idx_mod = + svdupq_n_u64(0 % kLanesPerBlock, 1 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SVE2_128 + +template > +HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) { + const Repartition d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, detail::Ext(hi8, lo8)); +#else + const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes)); + const auto lo_down = detail::Ext(lo8, lo8); + const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +#endif +} + +// ------------------------------ Shuffle2301 +template +HWY_API V Shuffle2301(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return Reverse2(d, v); +} + +// ------------------------------ Shuffle2103 +template +HWY_API V Shuffle2103(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0321 +template +HWY_API V Shuffle0321(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8)); +} + +// ------------------------------ Shuffle1032 +template +HWY_API V Shuffle1032(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle01 +template +HWY_API V Shuffle01(const V v) { + const DFromV d; + const 
Repartition d8; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0123 +template +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { +#if HWY_TARGET == HWY_SVE_256 + if (detail::IsFull(d)) { + return SwapAdjacentBlocks(v); + } else if (detail::IsFull(Twice())) { + return v; + } +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#endif + const Repartition du64; + return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v)))); +} + +// ------------------------------ TableLookupBytes + +template +HWY_API VI TableLookupBytes(const V v, const VI idx) { + const DFromV d; + const Repartition du8; +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, TableLookupLanes(BitCast(du8, v), BitCast(du8, idx))); +#else + const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0)); + const auto idx8 = Add(BitCast(du8, idx), offsets128); + return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8)); +#endif +} + +template +HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { + const DFromV d; + // Mask size must match vector type, so cast everything to this type. + const Repartition di8; + + auto idx8 = BitCast(di8, idx); + const auto msb = detail::LtN(idx8, 0); + + const auto lookup = TableLookupBytes(BitCast(di8, v), idx8); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Broadcast + +#if HWY_TARGET == HWY_SVE2_128 +namespace detail { +#define HWY_SVE_BROADCAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v, kLane); \ + } + +HWY_SVE_FOREACH(HWY_SVE_BROADCAST, BroadcastLane, dup_lane) +#undef HWY_SVE_BROADCAST +} // namespace detail +#endif + +template +HWY_API V Broadcast(const V v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane"); +#if HWY_TARGET == HWY_SVE2_128 + return detail::BroadcastLane(v); +#else + auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0)); + if (kLane != 0) { + idx = detail::AddN(idx, kLane); + } + return TableLookupLanes(v, idx); +#endif +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API V ShiftLeftLanes(D d, const V v) { + const auto zero = Zero(d); + const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes)); +#if HWY_TARGET == HWY_SVE2_128 + return shifted; +#else + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + return IfThenElse(detail::FirstNPerBlock(d), zero, shifted); +#endif +} + +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightLanes +template > +HWY_API V ShiftRightLanes(D d, V v) { + // For capped/fractional vectors, clear upper lanes so we shift in zeros. 
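+  // (Ext below would otherwise slide those undefined lanes into the result.)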
+ if (!detail::IsFull(d)) { + v = IfThenElseZero(detail::MakeMask(d), v); + } + +#if HWY_TARGET == HWY_SVE2_128 + return detail::Ext(Zero(d), v); +#else + const auto shifted = detail::Ext(v, v); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + const svbool_t mask = detail::FirstNPerBlock(d); + return IfThenElseZero(mask, shifted); +#endif +} + +// ------------------------------ ShiftLeftBytes + +template > +HWY_API V ShiftLeftBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); +} + +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); +} + +// ------------------------------ ZipLower + +template >> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(const V a, const V b) { + return BitCast(DW(), InterleaveLower(D(), a, b)); +} + +// ------------------------------ ZipUpper +template >> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== Ops with dependencies + +// ------------------------------ PromoteTo bfloat16 (ZipLower) +template +HWY_API svfloat32_t PromoteTo(Simd df32, + const svuint16_t v) { + return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), v)); +} + +// ------------------------------ ReorderDemote2To (OddEven) + +template +HWY_API svuint16_t ReorderDemote2To(Simd dbf16, + svfloat32_t a, svfloat32_t b) { + const RebindToUnsigned du16; + const Repartition du32; + const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +template +HWY_API svint16_t ReorderDemote2To(Simd d16, svint32_t a, + svint32_t b) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + (void)d16; + const svint16_t a_in_even = svqxtnb_s32(a); + return svqxtnt_s32(a_in_even, b); +#else + const Half dh; + const svint16_t a16 = BitCast(dh, detail::SaturateI(a)); + const svint16_t b16 = BitCast(dh, detail::SaturateI(b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +// ------------------------------ ZeroIfNegative (Lt, IfThenElse) +template +HWY_API V ZeroIfNegative(const V v) { + return IfThenZeroElse(detail::LtN(v, 0), v); +} + +// ------------------------------ BroadcastSignBit (ShiftRight) +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + const svbool_t m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +// ------------------------------ AverageRound (ShiftRight) + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) +HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, 
rhadd) +#else +template +V AverageRound(const V a, const V b) { + return ShiftRight<1>(detail::AddN(Add(a, b), 1)); +} +#endif // HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const svuint8_t iota = Iota(du, 0); + + // Load correct number of bytes (bits/8) with 7 zeros after each. + const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits)); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota)); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(rep8, bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const Repartition du8; + + // There may be up to 128 bits; avoid reading past the end. + const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits); + + // Replicate bytes 16x such that each lane contains the bit that governs it. + const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0))); + + const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(BitCast(du, rep16), bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const Repartition du8; + + // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable, + // so we can skip computing the actual length (Lanes(du)+7)/8. + const svuint8_t bytes = svld1(FirstN(du8, 8), bits); + + // Replicate bytes 32x such that each lane contains the bit that governs it. + const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0))); + + // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 .. + const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7)); + + return TestBit(BitCast(du, rep32), bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + + // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane. + // The "at least 8 byte" guarantee in quick_reference ensures this is safe. + uint32_t mask_bits; + CopyBytes<4>(bits, &mask_bits); // copy from bytes + const auto vbits = Set(du, mask_bits); + + // 2 ^ {0,1, .., 31}, will not have more lanes than that. + const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0)); + + return TestBit(vbits, bit); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +// For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes. +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return svdup_n_u8_z(m, 1); +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d8; + const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1)); + return detail::ConcatEvenFull(b16, b16); // lower half +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return U8FromU32(svdup_n_u32_z(m, 1)); +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d32; + const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1)); + return U8FromU32(detail::ConcatEvenFull(b64, b64)); // lower half +} + +// Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane. 
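+// Each step ORs in a right-shifted copy: after the u16 shift, even bytes hold
+// 2 bits; after the u32 shift, bytes 0 and 4 hold 4 bits; after the u64 shift,
+// byte 0 of each u64 holds all 8.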
+HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) { + const ScalableTag d8; + const ScalableTag d16; + const ScalableTag d32; + const ScalableTag d64; + // TODO(janwas): could use SVE2 BDEP, but it's optional. + x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); + x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); + x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); + return BitCast(d64, x); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +// TODO(janwas): specialize for HWY_SVE_256 +template +HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) { + svuint64_t bits_in_u64 = + detail::BitsFromBool(detail::BoolFromMask>(m)); + + const size_t num_bits = Lanes(d); + const size_t num_bytes = (num_bits + 8 - 1) / 8; // Round up, see below + + // Truncate each u64 to 8 bits and store to u8. + svst1b_u64(FirstN(ScalableTag(), num_bytes), bits, bits_in_u64); + + // Non-full byte, need to clear the undefined upper bits. Can happen for + // capped/fractional vectors or large T and small hardware vectors. + if (num_bits < 8) { + const int mask = static_cast((1ull << num_bits) - 1); + bits[0] = static_cast(bits[0] & mask); + } + // Else: we wrote full bytes because num_bits is a power of two >= 8. + + return num_bytes; +} + +// ------------------------------ CompressBits (LoadMaskBits) +template , HWY_IF_NOT_LANE_SIZE_D(D, 1)> +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(D(), bits)); +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ MulEven (InterleaveEven) + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +namespace detail { +#define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +#undef HWY_SVE_MUL_EVEN +} // namespace detail +#endif + +template >> +HWY_API VFromD MulEven(const V a, const V b) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + return BitCast(DW(), detail::MulEvenNative(a, b)); +#else + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), detail::InterleaveEven(lo, hi)); +#endif +} + +HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveEven(lo, hi); +} + +HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveOdd(lo, hi); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd df32, + svuint16_t a, svuint16_t b, + const svfloat32_t sum0, + svfloat32_t& sum1) { + // TODO(janwas): svbfmlalb_f32 if __ARM_FEATURE_SVE_BF16. + const RebindToUnsigned du32; + // Using shift/and instead of Zip leads to the odd/even order that + // RearrangeToOddPlusEven prefers. 
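+  // A bf16 value equals its f32 with the low 16 mantissa bits cleared, so
+  // ShiftLeft<16> turns the even (lower) bf16 of each u32 into that f32,
+  // while the And with 0xFFFF0000 keeps the odd bf16 already in the upper
+  // half. sum0 accumulates products of even lanes, sum1 those of odd lanes.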
+ using VU32 = VFromD; + const VU32 odd = Set(du32, 0xFFFF0000u); + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +template +HWY_API svint32_t ReorderWidenMulAccumulate(Simd d32, + svint16_t a, svint16_t b, + const svint32_t sum0, + svint32_t& sum1) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + (void)d32; + sum1 = svmlalt_s32(sum1, a, b); + return svmlalb_s32(sum0, a, b); +#else + const svbool_t pg = detail::PTrue(d32); + // Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers. + // Fortunately SVE has sign-extension for the even lanes. + const svint32_t ae = svexth_s32_x(pg, BitCast(d32, a)); + const svint32_t be = svexth_s32_x(pg, BitCast(d32, b)); + const svint32_t ao = ShiftRight<16>(BitCast(d32, a)); + const svint32_t bo = ShiftRight<16>(BitCast(d32, b)); + sum1 = svmla_s32_x(pg, sum1, ao, bo); + return svmla_s32_x(pg, sum0, ae, be); +#endif +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // sum0 is the sum of bottom/even lanes and sum1 of top/odd lanes. + return Add(sum0, sum1); +} + +// ------------------------------ AESRound / CLMul + +#if defined(__ARM_FEATURE_SVE2_AES) || \ + ((HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128) && \ + HWY_HAVE_RUNTIME_DISPATCH) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) { + // It is not clear whether E and MC fuse like they did on NEON. + const svuint8_t zero = svdup_n_u8(0); + return Xor(svaesmc_u8(svaese_u8(state, zero)), round_key); +} + +HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) { + return Xor(svaese_u8(state, svdup_n_u8(0)), round_key); +} + +HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) { + return svpmullb_pair(a, b); +} + +HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) { + return svpmullt_pair(a, b); +} + +#endif // __ARM_FEATURE_SVE2_AES + +// ------------------------------ Lt128 + +namespace detail { +#define HWY_SVE_DUP(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, svbool_t m) { \ + return sv##OP##_b##BITS(m, m); \ + } + +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupEvenB, trn1) // actually for bool +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupOddB, trn2) // actually for bool +#undef HWY_SVE_DUP + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +HWY_INLINE svuint64_t Lt128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t ltHL = VecFromMask(d, Lt(a, b)); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + // Requires an extra IfThenElse because INSR, EXT, TRN2 are unpredicated. + const svuint64_t ltHx = IfThenElse(eqHx, DupEven(ltHL), ltHL); + // Duplicate upper lane into lower. 
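+  // Afterwards both 64-bit lanes of each 128-bit block hold the same
+  // all-ones/all-zeros value: whether a < b as an unsigned 128-bit integer.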
+ return DupOdd(ltHx); +} +#endif +} // namespace detail + +template +HWY_INLINE svbool_t Lt128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Lt128Vec(d, a, b)); +#else + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + const svbool_t ltHL = Lt(a, b); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + const svbool_t ltHx = svsel_b(eqHx, detail::DupEvenB(d, ltHL), ltHL); + // Duplicate upper lane into lower. + return detail::DupOddB(d, ltHx); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Lt128Upper + +template +HWY_INLINE svbool_t Lt128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t ltHL = Lt(a, b); + return detail::DupOddB(d, ltHL); +} + +// ------------------------------ Eq128, Ne128 + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +namespace detail { + +template +HWY_INLINE svuint64_t Eq128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t eqHL = VecFromMask(d, Eq(a, b)); + // Duplicate upper and lower. + const svuint64_t eqHH = DupOdd(eqHL); + const svuint64_t eqLL = DupEven(eqHL); + return And(eqLL, eqHH); +} + +template +HWY_INLINE svuint64_t Ne128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t neHL = VecFromMask(d, Ne(a, b)); + // Duplicate upper and lower. 
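+  // A 128-bit Ne is true if either half differs, hence the Or below (the
+  // Eq128 counterpart above uses And of the duplicated halves instead).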
+ const svuint64_t neHH = DupOdd(neHL); + const svuint64_t neLL = DupEven(neHL); + return Or(neLL, neHH); +} + +} // namespace detail +#endif + +template +HWY_INLINE svbool_t Eq128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Eq128Vec(d, a, b)); +#else + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t eqHL = Eq(a, b); + const svbool_t eqHH = detail::DupOddB(d, eqHL); + const svbool_t eqLL = detail::DupEvenB(d, eqHL); + return And(eqLL, eqHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +template +HWY_INLINE svbool_t Ne128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Ne128Vec(d, a, b)); +#else + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t neHL = Ne(a, b); + const svbool_t neHH = detail::DupOddB(d, neHL); + const svbool_t neLL = detail::DupEvenB(d, neHL); + return Or(neLL, neHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Eq128Upper, Ne128Upper + +template +HWY_INLINE svbool_t Eq128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t eqHL = Eq(a, b); + return detail::DupOddB(d, eqHL); +} + +template +HWY_INLINE svbool_t Ne128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t neHL = Ne(a, b); + return detail::DupOddB(d, neHL); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +#else + return IfThenElse(Lt128(d, a, b), a, b); +#endif +} + +template +HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +#else + return IfThenElse(Lt128(d, b, a), a, b); +#endif +} + +template +HWY_INLINE svuint64_t Min128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE svuint64_t Max128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== END MACROS +namespace detail { // for code folding +#undef HWY_IF_FLOAT_V +#undef HWY_IF_LANE_SIZE_V +#undef HWY_SVE_ALL_PTRUE +#undef HWY_SVE_D +#undef HWY_SVE_FOREACH +#undef HWY_SVE_FOREACH_F +#undef HWY_SVE_FOREACH_F16 +#undef HWY_SVE_FOREACH_F32 +#undef HWY_SVE_FOREACH_F64 +#undef HWY_SVE_FOREACH_I +#undef HWY_SVE_FOREACH_I08 +#undef HWY_SVE_FOREACH_I16 +#undef HWY_SVE_FOREACH_I32 +#undef HWY_SVE_FOREACH_I64 +#undef HWY_SVE_FOREACH_IF +#undef HWY_SVE_FOREACH_U +#undef HWY_SVE_FOREACH_U08 +#undef HWY_SVE_FOREACH_U16 +#undef HWY_SVE_FOREACH_U32 +#undef HWY_SVE_FOREACH_U64 +#undef HWY_SVE_FOREACH_UI +#undef HWY_SVE_FOREACH_UI08 +#undef HWY_SVE_FOREACH_UI16 +#undef HWY_SVE_FOREACH_UI32 +#undef HWY_SVE_FOREACH_UI64 +#undef HWY_SVE_FOREACH_UIF3264 +#undef HWY_SVE_PTRUE +#undef HWY_SVE_RETV_ARGPV +#undef HWY_SVE_RETV_ARGPVN +#undef HWY_SVE_RETV_ARGPVV +#undef HWY_SVE_RETV_ARGV +#undef HWY_SVE_RETV_ARGVN +#undef HWY_SVE_RETV_ARGVV +#undef HWY_SVE_RETV_ARGVVV +#undef HWY_SVE_T +#undef HWY_SVE_UNDEFINED +#undef HWY_SVE_V + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // 
namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/emu128-inl.h b/third_party/highway/hwy/ops/emu128-inl.h new file mode 100644 index 0000000000..7fb934def0 --- /dev/null +++ b/third_party/highway/hwy/ops/emu128-inl.h @@ -0,0 +1,2503 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include +#include +#include // std::abs, std::isnan + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +using Full128 = Simd; + +// (Wrapper class required for overloading comparison operators.) +template +struct Vec128 { + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + HWY_INLINE Vec128() = default; + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + // Behave like wasm128 (vectors can always hold 128 bits). generic_ops-inl.h + // relies on this for LoadInterleaved*. CAVEAT: this method of padding + // prevents using range for, especially in SumOfLanes, where it would be + // incorrect. Moving padding to another field would require handling the case + // where N = 16 / sizeof(T) (i.e. there is no padding), which is also awkward. + T raw[16 / sizeof(T)] = {}; +}; + +// 0 or FF..FF, same size as Vec128. +template +struct Mask128 { + using Raw = hwy::MakeUnsigned; + static HWY_INLINE Raw FromBool(bool b) { + return b ? static_cast(~Raw{0}) : 0; + } + + // Must match the size of Vec128. 
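+  // Lanes are all-zeros or all-ones, so MaskFromVec/VecFromMask can simply
+  // CopySameSize between Vec128 and Mask128.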
+ Raw bits[16 / sizeof(T)] = {}; +}; + +template +using DFromV = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +template +HWY_API Vec128 BitCast(Simd /* tag */, Vec128 v) { + Vec128 to; + CopySameSize(&v, &to); + return to; +} + +// ------------------------------ Set + +template +HWY_API Vec128 Zero(Simd /* tag */) { + Vec128 v; + ZeroBytes(v.raw); + return v; +} + +template +using VFromD = decltype(Zero(D())); + +template +HWY_API Vec128 Set(Simd /* tag */, const T2 t) { + Vec128 v; + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(t); + } + return v; +} + +template +HWY_API Vec128 Undefined(Simd d) { + return Zero(d); +} + +template +HWY_API Vec128 Iota(const Simd /* tag */, T2 first) { + Vec128 v; + for (size_t i = 0; i < N; ++i) { + v.raw[i] = + AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + } + return v; +} + +// ================================================== LOGICAL + +// ------------------------------ Not +template +HWY_API Vec128 Not(const Vec128 v) { + const Simd d; + const RebindToUnsigned du; + using TU = TFromD; + VFromD vu = BitCast(du, v); + for (size_t i = 0; i < N; ++i) { + vu.raw[i] = static_cast(~vu.raw[i]); + } + return BitCast(d, vu); +} + +// ------------------------------ And +template +HWY_API Vec128 And(const Vec128 a, const Vec128 b) { + const Simd d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] &= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +// ------------------------------ AndNot +template +HWY_API Vec128 AndNot(const Vec128 a, const Vec128 b) { + return And(Not(a), b); +} + +// ------------------------------ Or +template +HWY_API Vec128 Or(const Vec128 a, const Vec128 b) { + const Simd d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] |= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +// ------------------------------ Xor +template +HWY_API Vec128 Xor(const Vec128 a, const Vec128 b) { + const Simd d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] ^= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ Xor3 + +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(const Vec128 o, const Vec128 a1, + const Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +// ------------------------------ CopySign +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Simd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 
sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Simd()), sign)); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API Vec128 BroadcastSignBit(Vec128 v) { + // This is used inside ShiftRight, so we cannot implement in terms of it. + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[i] < 0 ? T(-1) : T(0); + } + return v; +} + +// ------------------------------ Mask + +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 mask) { + Mask128 to; + CopySameSize(&mask, &to); + return to; +} + +// v must be 0 or FF..FF. +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + Mask128 mask; + CopySameSize(&v, &mask); + return mask; +} + +template +Vec128 VecFromMask(const Mask128 mask) { + Vec128 v; + CopySameSize(&mask, &v); + return v; +} + +template +Vec128 VecFromMask(Simd /* tag */, const Mask128 mask) { + return VecFromMask(mask); +} + +template +HWY_API Mask128 FirstN(Simd /*tag*/, size_t n) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(i < n); + } + return m; +} + +// Returns mask ? yes : no. +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, const Vec128 no) { + return IfVecThenElse(VecFromMask(mask), yes, no); +} + +template +HWY_API Vec128 IfThenElseZero(const Mask128 mask, + const Vec128 yes) { + return IfVecThenElse(VecFromMask(mask), yes, Zero(Simd())); +} + +template +HWY_API Vec128 IfThenZeroElse(const Mask128 mask, + const Vec128 no) { + return IfVecThenElse(VecFromMask(mask), Zero(Simd()), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[i] < 0 ? yes.raw[i] : no.raw[i]; + } + return v; +} + +template +HWY_API Vec128 ZeroIfNegative(const Vec128 v) { + return IfNegativeThenElse(v, Zero(Simd()), v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec128 ShiftLeft(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) << kBits; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 ShiftRight(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). 
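+  // Example: int8_t(-7) >> 1 yields -4 (rounds toward negative infinity),
+  // which the #else branch below reproduces via unsigned shifts.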
+ for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> kBits); + } +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + for (size_t i = 0; i < N; ++i) { + const TU shifted = static_cast(static_cast(v.raw[i]) >> kBits); + const TU sign = v.raw[i] < 0 ? static_cast(~TU{0}) : 0; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); + const TU upper = static_cast(sign << sign_shift); + v.raw[i] = static_cast(shifted | upper); + } + } else { // T is unsigned + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> kBits); + } + } +#endif + return v; +} + +// ------------------------------ RotateRight (ShiftRight) + +namespace detail { + +// For partial specialization: kBits == 0 results in an invalid shift count +template +struct RotateRight { + template + HWY_INLINE Vec128 operator()(const Vec128 v) const { + return Or(ShiftRight(v), ShiftLeft(v)); + } +}; + +template <> +struct RotateRight<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) const { + return v; + } +}; + +} // namespace detail + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return detail::RotateRight()(v); +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(Vec128 v, int bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) << bits; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, int bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> bits); + } +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + for (size_t i = 0; i < N; ++i) { + const TU shifted = static_cast(static_cast(v.raw[i]) >> bits); + const TU sign = v.raw[i] < 0 ? static_cast(~TU{0}) : 0; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); + const TU upper = static_cast(sign << sign_shift); + v.raw[i] = static_cast(shifted | upper); + } + } else { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> bits); // unsigned, logical shift + } + } +#endif + return v; +} + +// ------------------------------ Shl + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) + << bits.raw[i]; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); + } +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + for (size_t i = 0; i < N; ++i) { + const TU shifted = + static_cast(static_cast(v.raw[i]) >> bits.raw[i]); + const TU sign = v.raw[i] < 0 ? 
static_cast(~TU{0}) : 0; + const size_t sign_shift = static_cast( + static_cast(sizeof(TU)) * 8 - 1 - bits.raw[i]); + const TU upper = static_cast(sign << sign_shift); + v.raw[i] = static_cast(shifted | upper); + } + } else { // T is unsigned + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); + } + } +#endif + return v; +} + +// ================================================== ARITHMETIC + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Add(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast(a.raw[i]); + const uint64_t b64 = static_cast(b.raw[i]); + a.raw[i] = static_cast((a64 + b64) & static_cast(~T(0))); + } + return a; +} +template +HWY_INLINE Vec128 Sub(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast(a.raw[i]); + const uint64_t b64 = static_cast(b.raw[i]); + a.raw[i] = static_cast((a64 - b64) & static_cast(~T(0))); + } + return a; +} + +template +HWY_INLINE Vec128 Add(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] += b.raw[i]; + } + return a; +} + +template +HWY_INLINE Vec128 Sub(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] -= b.raw[i]; + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 operator-(Vec128 a, const Vec128 b) { + return detail::Sub(hwy::IsFloatTag(), a, b); +} +template +HWY_API Vec128 operator+(Vec128 a, const Vec128 b) { + return detail::Add(hwy::IsFloatTag(), a, b); +} + +// ------------------------------ SumsOf8 + +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + Vec128 sums; + for (size_t i = 0; i < N; ++i) { + sums.raw[i / 8] += v.raw[i]; + } + return sums; +} + +// ------------------------------ SaturatedAdd +template +HWY_API Vec128 SaturatedAdd(Vec128 a, const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast( + HWY_MIN(HWY_MAX(hwy::LowestValue(), a.raw[i] + b.raw[i]), + hwy::HighestValue())); + } + return a; +} + +// ------------------------------ SaturatedSub +template +HWY_API Vec128 SaturatedSub(Vec128 a, const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast( + HWY_MIN(HWY_MAX(hwy::LowestValue(), a.raw[i] - b.raw[i]), + hwy::HighestValue())); + } + return a; +} + +// ------------------------------ AverageRound +template +HWY_API Vec128 AverageRound(Vec128 a, const Vec128 b) { + static_assert(!IsSigned(), "Only for unsigned"); + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast((a.raw[i] + b.raw[i] + 1) / 2); + } + return a; +} + +// ------------------------------ Abs + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Abs(SignedTag /*tag*/, Vec128 a) { + for (size_t i = 0; i < N; ++i) { + const T s = a.raw[i]; + const T min = hwy::LimitsMin(); + a.raw[i] = static_cast((s >= 0 || s == min) ? 
a.raw[i] : -s); + } + return a; +} + +template +HWY_INLINE Vec128 Abs(hwy::FloatTag /*tag*/, Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = std::abs(v.raw[i]); + } + return v; +} + +} // namespace detail + +template +HWY_API Vec128 Abs(Vec128 a) { + return detail::Abs(hwy::TypeTag(), a); +} + +// ------------------------------ Min/Max + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Min(hwy::NonFloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + return a; +} +template +HWY_INLINE Vec128 Max(hwy::NonFloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + return a; +} + +template +HWY_INLINE Vec128 Min(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + if (std::isnan(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (std::isnan(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + } + return a; +} +template +HWY_INLINE Vec128 Max(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + if (std::isnan(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (std::isnan(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 Min(Vec128 a, const Vec128 b) { + return detail::Min(hwy::IsFloatTag(), a, b); +} + +template +HWY_API Vec128 Max(Vec128 a, const Vec128 b) { + return detail::Max(hwy::IsFloatTag(), a, b); +} + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_API Vec128 Neg(hwy::NonFloatTag /*tag*/, Vec128 v) { + return Zero(Simd()) - v; +} + +template +HWY_API Vec128 Neg(hwy::FloatTag /*tag*/, Vec128 v) { + return Xor(v, SignBit(Simd())); +} + +} // namespace detail + +template +HWY_API Vec128 Neg(Vec128 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Mul/Div + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Mul(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] *= b.raw[i]; + } + return a; +} + +template +HWY_INLINE Vec128 Mul(SignedTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(static_cast(a.raw[i]) * + static_cast(b.raw[i])); + } + return a; +} + +template +HWY_INLINE Vec128 Mul(UnsignedTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(static_cast(a.raw[i]) * + static_cast(b.raw[i])); + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 operator*(Vec128 a, const Vec128 b) { + return detail::Mul(hwy::TypeTag(), a, b); +} + +template +HWY_API Vec128 operator/(Vec128 a, const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] /= b.raw[i]; + } + return a; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast((int32_t{a.raw[i]} * b.raw[i]) >> 16); + } + return a; +} +template +HWY_API Vec128 MulHigh(Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + // Cast to uint32_t first to prevent overflow. 
Otherwise the result of + // uint16_t * uint16_t is in "int" which may overflow. In practice the + // result is the same but this way it is also defined. + a.raw[i] = static_cast( + (static_cast(a.raw[i]) * static_cast(b.raw[i])) >> + 16); + } + return a; +} + +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast((2 * a.raw[i] * b.raw[i] + 32768) >> 16); + } + return a; +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const int64_t a64 = a.raw[i]; + mul.raw[i / 2] = a64 * b.raw[i]; + } + return mul; +} +template +HWY_API Vec128 MulEven(Vec128 a, + const Vec128 b) { + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const uint64_t a64 = a.raw[i]; + mul.raw[i / 2] = a64 * b.raw[i]; + } + return mul; +} + +template +HWY_API Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const int64_t a64 = a.raw[i + 1]; + mul.raw[i / 2] = a64 * b.raw[i + 1]; + } + return mul; +} +template +HWY_API Vec128 MulOdd(Vec128 a, + const Vec128 b) { + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const uint64_t a64 = a.raw[i + 1]; + mul.raw[i / 2] = a64 * b.raw[i + 1]; + } + return mul; +} + +template +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The result is arbitrary. + v.raw[i] = (std::abs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; + } + return v; +} + +template +HWY_API Vec128 AbsDiff(Vec128 a, const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec128 MulAdd(Vec128 mul, const Vec128 x, + const Vec128 add) { + return mul * x + add; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, const Vec128 x, + const Vec128 add) { + return add - mul * x; +} + +template +HWY_API Vec128 MulSub(Vec128 mul, const Vec128 x, + const Vec128 sub) { + return mul * x - sub; +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, const Vec128 x, + const Vec128 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + const float half = v.raw[i] * 0.5f; + uint32_t bits; + CopySameSize(&v.raw[i], &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &v.raw[i]); + // One Newton-Raphson iteration + v.raw[i] = v.raw[i] * (1.5f - (half * v.raw[i] * v.raw[i])); + } + return v; +} + +template +HWY_API Vec128 Sqrt(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = std::sqrt(v.raw[i]); + } + return v; +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec128 Round(Vec128 v) { + using TI = MakeSigned; + const Vec128 a = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(a.raw[i] < MantissaEnd())) { // Huge or NaN + continue; + } + const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw[i] + bias); + if (rounded == 0) { + v.raw[i] = v.raw[i] < 0 ? 
T{-0} : T{0}; + continue; + } + const T rounded_f = static_cast(rounded); + // Round to even + if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { + v.raw[i] = static_cast(rounded - (v.raw[i] < T(0) ? -1 : 1)); + continue; + } + v.raw[i] = rounded_f; + } + return v; +} + +// Round-to-nearest even. +template +HWY_API Vec128 NearestInt(const Vec128 v) { + using T = float; + using TI = int32_t; + + const Vec128 abs = Abs(v); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + const bool signbit = std::signbit(v.raw[i]); + + if (!(abs.raw[i] < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs.raw[i] <= static_cast(LimitsMax()))) { + ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); + continue; + } + ret.raw[i] = static_cast(v.raw[i]); + continue; + } + const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw[i] + bias); + if (rounded == 0) { + ret.raw[i] = 0; + continue; + } + const T rounded_f = static_cast(rounded); + // Round to even + if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { + ret.raw[i] = rounded - (signbit ? -1 : 1); + continue; + } + ret.raw[i] = rounded; + } + return ret; +} + +template +HWY_API Vec128 Trunc(Vec128 v) { + using TI = MakeSigned; + const Vec128 abs = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(abs.raw[i] <= MantissaEnd())) { // Huge or NaN + continue; + } + const TI truncated = static_cast(v.raw[i]); + if (truncated == 0) { + v.raw[i] = v.raw[i] < 0 ? -T{0} : T{0}; + continue; + } + v.raw[i] = static_cast(truncated); + } + return v; +} + +// Toward +infinity, aka ceiling +template +Vec128 Ceil(Vec128 v) { + constexpr int kMantissaBits = MantissaBits(); + using Bits = MakeUnsigned; + const Bits kExponentMask = MaxExponentField(); + const Bits kMantissaMask = MantissaMask(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool positive = v.raw[i] > Float(0.0); + + Bits bits; + CopySameSize(&v.raw[i], &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => 0 or 1. + if (exponent < 0) { + v.raw[i] = positive ? Float{1} : Float{-0.0}; + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &v.raw[i]); + } + return v; +} + +// Toward -infinity, aka floor +template +Vec128 Floor(Vec128 v) { + constexpr int kMantissaBits = MantissaBits(); + using Bits = MakeUnsigned; + const Bits kExponentMask = MaxExponentField(); + const Bits kMantissaMask = MantissaMask(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool negative = v.raw[i] < Float(0.0); + + Bits bits; + CopySameSize(&v.raw[i], &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => -1 or 0. + if (exponent < 0) { + v.raw[i] = negative ? 
Float(-1.0) : Float(0.0); + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &v.raw[i]); + } + return v; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(const Vec128 v) { + Mask128 ret; + for (size_t i = 0; i < N; ++i) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + MakeUnsigned bits; + CopySameSize(&v.raw[i], &bits); + bits += bits; + bits >>= 1; // clear sign bit + // NaN if all exponent bits are set and the mantissa is not zero. + ret.bits[i] = Mask128::FromBool(bits > ExponentMask()); + } + return ret; +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + using VI = VFromD; + using VU = VFromD; + const VU vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VI exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] == b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] != b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] < b.raw[i]); + } + return m; +} +template +HWY_API Mask128 operator>(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] > b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 operator<=(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] <= b.raw[i]); + } + return m; +} +template +HWY_API Mask128 operator>=(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] >= b.raw[i]); + } + return m; +} + +// ------------------------------ Lt128 + +// Only makes sense for full vectors of u64. 
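+// raw[0] is the low and raw[1] the high 64 bits of the 128-bit value, so the
+// comparison checks the high halves first and falls back to the low halves
+// on equality; both mask lanes receive the result.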
+HWY_API Mask128 Lt128(Simd /* tag */, + Vec128 a, const Vec128 b) { + const bool lt = + (a.raw[1] < b.raw[1]) || (a.raw[1] == b.raw[1] && a.raw[0] < b.raw[0]); + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); + return ret; +} + +HWY_API Mask128 Lt128Upper(Simd /* tag */, + Vec128 a, + const Vec128 b) { + const bool lt = a.raw[1] < b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); + return ret; +} + +// ------------------------------ Eq128 + +// Only makes sense for full vectors of u64. +HWY_API Mask128 Eq128(Simd /* tag */, + Vec128 a, const Vec128 b) { + const bool eq = a.raw[1] == b.raw[1] && a.raw[0] == b.raw[0]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); + return ret; +} + +HWY_API Mask128 Ne128(Simd /* tag */, + Vec128 a, const Vec128 b) { + const bool ne = a.raw[1] != b.raw[1] || a.raw[0] != b.raw[0]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); + return ret; +} + +HWY_API Mask128 Eq128Upper(Simd /* tag */, + Vec128 a, + const Vec128 b) { + const bool eq = a.raw[1] == b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); + return ret; +} + +HWY_API Mask128 Ne128Upper(Simd /* tag */, + Vec128 a, + const Vec128 b) { + const bool ne = a.raw[1] != b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); + return ret; +} + +// ------------------------------ Min128, Max128 (Lt128) + +template > +HWY_API V Min128(D d, const V a, const V b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template > +HWY_API V Max128(D d, const V a, const V b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template > +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template > +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Simd /* tag */, + const T* HWY_RESTRICT aligned) { + Vec128 v; + CopyBytes(aligned, v.raw); // copy from array + return v; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. +template +HWY_API Vec128 LoadDup128(Simd d, + const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +// ------------------------------ Store + +template +HWY_API void Store(const Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + CopyBytes(v.raw, aligned); // copy to array +} + +template +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(const Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + for (size_t i = 0; i < N; ++i) { + if (m.bits[i]) p[i] = v.raw[i]; + } +} + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +// We implement those here because scalar code is likely faster than emulation +// via shuffles. 
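+// When this flag is defined, generic_ops-inl.h skips its shuffle-based
+// LoadInterleaved*/StoreInterleaved* and the scalar loops below are used.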
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template +HWY_API void LoadInterleaved2(Simd d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); +} + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, + Vec128& v2) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + alignas(16) T buf2[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); +} + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, + Vec128& v2, Vec128& v3) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + alignas(16) T buf2[N]; + alignas(16) T buf3[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + buf3[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); + v3 = Load(d, buf3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template +HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + } +} + +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Simd /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + } +} + +template +HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, + const Vec128 v2, const Vec128 v3, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + *unaligned++ = v3.raw[i]; + } +} + +// ------------------------------ Stream + +template +HWY_API void Stream(const Vec128 v, Simd d, + T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ Scatter + +template +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, T* base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + for (size_t i = 0; i < N; ++i) { + uint8_t* const base8 = reinterpret_cast(base) + offset.raw[i]; + CopyBytes(&v.raw[i], base8); // copy to bytes + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT base, const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + for (size_t i = 0; i < N; ++i) { + base[index.raw[i]] = v.raw[i]; + } +} + +// ------------------------------ Gather + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, const T* base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + Vec128 v; + for (size_t i = 0; i < N; ++i) { + const uint8_t* base8 = + reinterpret_cast(base) + offset.raw[i]; + CopyBytes(base8, &v.raw[i]); // copy from bytes + } + return v; +} + +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + 
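+  // Equal sizes keep the index and data vectors at the same lane count on
+  // every target; each output lane i is loaded from base[index.raw[i]].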
Vec128 v; + for (size_t i = 0; i < N; ++i) { + v.raw[i] = base[index.raw[i]]; + } + return v; +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + Vec128 from) { + static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // For bits Y > X, floatX->floatY and intX->intY are always representable. + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, +// so we overload for FromT=double and ToT={float,int32_t}. +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + Vec128 from) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // Prevent ubsan errors when converting float to narrower integer/float + if (std::isinf(from.raw[i]) || + std::fabs(from.raw[i]) > static_cast(HighestValue())) { + ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue() + : HighestValue(); + continue; + } + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + Vec128 from) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + if (std::isinf(from.raw[i]) || + std::fabs(from.raw[i]) > static_cast(HighestValue())) { + ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue() + : HighestValue(); + continue; + } + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + Vec128 from) { + static_assert(!IsFloat(), "FromT=double are handled above"); + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw[i] = + HWY_MIN(HWY_MAX(LimitsMin(), from.raw[i]), LimitsMax()); + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + const Repartition du32; + const Vec128 b_in_lower = ShiftRight<16>(BitCast(du32, b)); + // Avoid OddEven - we want the upper half of `a` even on big-endian systems. 
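+  // A bf16 is the upper 16 bits of its f32: b_in_lower holds b's bf16 in the
+  // lower half of each u32 and the mask below keeps a's upper half, so each
+  // 32-bit lane packs {a_bf16 (hi), b_bf16 (lo)}.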
+ const Vec128 a_mask = Set(du32, 0xFFFF0000); + return BitCast(dbf16, IfVecThenElse(a_mask, BitCast(du32, a), b_in_lower)); +} + +template +HWY_API Vec128 ReorderDemote2To(Simd /*d16*/, + Vec128 a, + Vec128 b) { + const int16_t min = LimitsMin(); + const int16_t max = LimitsMax(); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(HWY_MIN(HWY_MAX(min, a.raw[i]), max)); + } + for (size_t i = 0; i < N; ++i) { + ret.raw[N + i] = static_cast(HWY_MIN(HWY_MAX(min, b.raw[i]), max)); + } + return ret; +} + +namespace detail { + +HWY_INLINE void StoreU16ToF16(const uint16_t val, + hwy::float16_t* HWY_RESTRICT to) { + CopySameSize(&val, to); +} + +HWY_INLINE uint16_t U16FromF16(const hwy::float16_t* HWY_RESTRICT from) { + uint16_t bits16; + CopySameSize(from, &bits16); + return bits16; +} + +} // namespace detail + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + const uint16_t bits16 = detail::U16FromF16(&v.raw[i]); + const uint32_t sign = static_cast(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + ret.raw[i] = sign ? -subnormal : subnormal; + continue; + } + + // Normalized: convert the representation directly (faster than + // ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + CopySameSize(&bits32, &ret.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = F32FromBF16(v.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + uint32_t bits32; + CopySameSize(&v.raw[i], &bits32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = HWY_MIN(static_cast(biased_exp32) - 127, 15); + + // Tiny or zero => zero. 
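+  // The smallest positive float16 subnormal is 2^-24; inputs of smaller
+  // magnitude cannot be represented and are flushed to zero below.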
+ if (exp < -24) { + ZeroBytes(&ret.raw[i]); + continue; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (exp < -14) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast(-14 - exp); + HWY_DASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = static_cast((1u << (10 - sub_exp)) + + (mantissa32 >> (13 + sub_exp))); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast(exp + 15); + HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + HWY_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast(bits16); // big-endian safe + detail::StoreU16ToF16(narrowed, &ret.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = BF16FromF32(v.raw[i]); + } + return ret; +} + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_API Vec128 ConvertTo(hwy::FloatTag /*tag*/, + Simd /* tag */, + Vec128 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // float## -> int##: return closest representable value. We cannot exactly + // represent LimitsMax in FromT, so use double. + const double f = static_cast(from.raw[i]); + if (std::isinf(from.raw[i]) || + std::fabs(f) > static_cast(LimitsMax())) { + ret.raw[i] = + std::signbit(from.raw[i]) ? LimitsMin() : LimitsMax(); + continue; + } + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 ConvertTo(hwy::NonFloatTag /*tag*/, + Simd /* tag */, + Vec128 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // int## -> float##: no check needed + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +} // namespace detail + +template +HWY_API Vec128 ConvertTo(Simd d, Vec128 from) { + return detail::ConvertTo(hwy::IsFloatTag(), d, from); +} + +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + return DemoteTo(Simd(), v); +} + +// ------------------------------ Truncations + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFFFFFFu); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +// ================================================== COMBINE + +template +HWY_API Vec128 LowerHalf(Vec128 
v) { + Vec128 ret; + CopyBytes(v.raw, ret.raw); + return ret; +} + +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return LowerHalf(v); +} + +template +HWY_API Vec128 UpperHalf(Simd /* tag */, + Vec128 v) { + Vec128 ret; + CopyBytes(&v.raw[N / 2], ret.raw); + return ret; +} + +template +HWY_API Vec128 ZeroExtendVector(Simd /* tag */, + Vec128 v) { + Vec128 ret; + CopyBytes(v.raw, ret.raw); + return ret; +} + +template +HWY_API Vec128 Combine(Simd /* tag */, Vec128 hi_half, + Vec128 lo_half) { + Vec128 ret; + CopyBytes(lo_half.raw, &ret.raw[0]); + CopyBytes(hi_half.raw, &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatLowerLower(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatUpperUpper(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + CopyBytes(&lo.raw[N / 2], &ret.raw[0]); + CopyBytes(&hi.raw[N / 2], &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatLowerUpper(Simd /* tag */, + const Vec128 hi, + const Vec128 lo) { + Vec128 ret; + CopyBytes(&lo.raw[N / 2], &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatUpperLower(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(&hi.raw[N / 2], &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[i] = lo.raw[2 * i]; + } + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[N / 2 + i] = hi.raw[2 * i]; + } + return ret; +} + +template +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[i] = lo.raw[2 * i + 1]; + } + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[N / 2 + i] = hi.raw[2 * i + 1]; + } + return ret; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Simd /* tag */, V hi, V lo) { + V ret; + const uint8_t* HWY_RESTRICT lo8 = + reinterpret_cast(lo.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(lo8 + kBytes, ret8); + CopyBytes(hi.raw, ret8 + sizeof(T) * N - kBytes); + return ret; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + Vec128 ret; + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + ZeroBytes(ret8); + CopyBytes(v.raw, ret8 + kBytes); + return ret; +} + +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + Vec128 ret; + const uint8_t* HWY_RESTRICT v8 = + reinterpret_cast(v.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(v8 + kBytes, ret8); + ZeroBytes(ret8 + sizeof(T) * N - kBytes); + return ret; +} + +// ------------------------------ ShiftRightLanes 
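+// (Example: with lanes {a, b, c, d} and kLanes == 1, the result is
+// {b, c, d, 0}: lanes move toward index 0 and the vacated upper lanes are
+// zeroed, mirroring ShiftRightBytes above.)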
+template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(const Vec128 v) { + return v.raw[0]; +} + +template +HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { + v.raw[i] = t; + return v; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + return v.raw[i]; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i + 1] = v.raw[i]; + } + return v; +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i] = v.raw[i + 1]; + } + return v; +} + +template +HWY_API Vec128 OddEven(Vec128 odd, Vec128 even) { + for (size_t i = 0; i < N; i += 2) { + odd.raw[i] = even.raw[i]; + } + return odd; +} + +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + MakeSigned raw[N]; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + Indices128 ret; + CopyBytes(vec.raw, ret.raw); + return ret; +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + return IndicesFromVec(d, LoadU(Simd(), idx)); +} + +template +HWY_API Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = v.raw[idx.raw[i]]; + } + return ret; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Simd /* tag */, + const Vec128 v) { + return v; +} + +// ------------------------------ Reverse + +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = v.raw[N - 1 - i]; + } + return ret; +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; i += 2) { + ret.raw[i + 0] = v.raw[i + 1]; + ret.raw[i + 1] = v.raw[i + 0]; + } + return ret; +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; i += 4) { + ret.raw[i + 0] = v.raw[i + 3]; + ret.raw[i + 1] = v.raw[i + 2]; + ret.raw[i + 2] = v.raw[i + 1]; + ret.raw[i + 3] = v.raw[i + 0]; + } + return ret; +} + +template +HWY_API Vec128 Reverse8(Simd /* tag */, const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; i += 8) { + ret.raw[i + 0] = v.raw[i + 7]; + ret.raw[i + 1] = v.raw[i + 6]; + ret.raw[i + 2] = v.raw[i + 5]; + ret.raw[i + 3] = v.raw[i + 4]; + ret.raw[i + 4] = v.raw[i + 3]; + ret.raw[i + 5] = v.raw[i + 2]; + ret.raw[i + 6] = v.raw[i + 1]; + ret.raw[i + 7] = v.raw[i + 0]; + } + return ret; +} + +// ================================================== BLOCKWISE + +// ------------------------------ Shuffle* + +// Swap 32-bit halves in 64-bit halves. 
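+// (Example: 32-bit lanes {a, b, c, d} become {b, a, d, c}.)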
+template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Reverse2(DFromV(), v); +} + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + Vec128 ret; + ret.raw[3] = v.raw[1]; + ret.raw[2] = v.raw[0]; + ret.raw[1] = v.raw[3]; + ret.raw[0] = v.raw[2]; + return ret; +} +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit"); + return Reverse2(DFromV(), v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + Vec128 ret; + ret.raw[3] = v.raw[0]; + ret.raw[2] = v.raw[3]; + ret.raw[1] = v.raw[2]; + ret.raw[0] = v.raw[1]; + return ret; +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + Vec128 ret; + ret.raw[3] = v.raw[2]; + ret.raw[2] = v.raw[1]; + ret.raw[1] = v.raw[0]; + ret.raw[0] = v.raw[3]; + return ret; +} + +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Reverse4(DFromV(), v); +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[kLane]; + } + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec128 TableLookupBytes(const Vec128 v, + const Vec128 indices) { + const uint8_t* HWY_RESTRICT v_bytes = + reinterpret_cast(v.raw); + const uint8_t* HWY_RESTRICT idx_bytes = + reinterpret_cast(indices.raw); + Vec128 ret; + uint8_t* HWY_RESTRICT ret_bytes = + reinterpret_cast(ret.raw); + for (size_t i = 0; i < NI * sizeof(TI); ++i) { + const size_t idx = idx_bytes[i]; + // Avoid out of bounds reads. + ret_bytes[i] = idx < sizeof(T) * N ? v_bytes[idx] : 0; + } + return ret; +} + +template +HWY_API Vec128 TableLookupBytesOr0(const Vec128 v, + const Vec128 indices) { + // Same as TableLookupBytes, which already returns 0 if out of bounds. + return TableLookupBytes(v, indices); +} + +// ------------------------------ InterleaveLower/InterleaveUpper + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[2 * i + 0] = a.raw[i]; + ret.raw[2 * i + 1] = b.raw[i]; + } + return ret; +} + +// Additional overload for the optional tag (also for 256/512). +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +template +HWY_API Vec128 InterleaveUpper(Simd /* tag */, + const Vec128 a, + const Vec128 b) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[2 * i + 0] = a.raw[N / 2 + i]; + ret.raw[2 * i + 1] = b.raw[N / 2 + i]; + } + return ret; +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
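+// (Example: for uint8_t inputs a = {a0, a1, ..} and b = {b0, b1, ..}, ZipLower
+// yields uint16_t lanes whose values on little-endian hosts are a0 | (b0 << 8),
+// a1 | (b1 << 8), and so on.)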
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== MASK + +template +HWY_API bool AllFalse(Simd /* tag */, const Mask128 mask) { + typename Mask128::Raw or_sum = 0; + for (size_t i = 0; i < N; ++i) { + or_sum |= mask.bits[i]; + } + return or_sum == 0; +} + +template +HWY_API bool AllTrue(Simd /* tag */, const Mask128 mask) { + constexpr uint64_t kAll = LimitsMax::Raw>(); + uint64_t and_sum = kAll; + for (size_t i = 0; i < N; ++i) { + and_sum &= mask.bits[i]; + } + return and_sum == kAll; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask128 LoadMaskBits(Simd /* tag */, + const uint8_t* HWY_RESTRICT bits) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + m.bits[i] = Mask128::FromBool((bits[idx_byte] & bit) != 0); + } + return m; +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(Simd /* tag */, const Mask128 mask, + uint8_t* bits) { + bits[0] = 0; + if (N > 8) bits[1] = 0; // N <= 16, so max two bytes + for (size_t i = 0; i < N; ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + if (mask.bits[i]) { + bits[idx_byte] = static_cast(bits[idx_byte] | bit); + } + } + return N > 8 ? 2 : 1; +} + +template +HWY_API size_t CountTrue(Simd /* tag */, const Mask128 mask) { + size_t count = 0; + for (size_t i = 0; i < N; ++i) { + count += mask.bits[i] != 0; + } + return count; +} + +template +HWY_API size_t FindKnownFirstTrue(Simd /* tag */, + const Mask128 mask) { + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i] != 0) return i; + } + HWY_DASSERT(false); + return 0; +} + +template +HWY_API intptr_t FindFirstTrue(Simd /* tag */, + const Mask128 mask) { + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i] != 0) return static_cast(i); + } + return intptr_t{-1}; +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +template +HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { + size_t count = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ CompressNot +template +HWY_API Vec128 CompressNot(Vec128 v, const Mask128 mask) { + size_t count = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(Simd(), bits)); +} + +// ------------------------------ CompressStore +template +HWY_API size_t 
CompressStore(Vec128 v, const Mask128 mask, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + size_t count = 0; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + unaligned[count++] = v.raw[i]; + } + } + return count; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec128 v, const Mask128 mask, + Simd d, + T* HWY_RESTRICT unaligned) { + return CompressStore(v, mask, d, unaligned); +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + const Mask128 mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { + const Rebind du32; + using VU32 = VFromD; + const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 + // Avoid ZipLower/Upper so this also works on big-endian systems. + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +template +HWY_API Vec128 ReorderWidenMulAccumulate( + Simd d32, Vec128 a, Vec128 b, + const Vec128 sum0, Vec128& sum1) { + using VI32 = VFromD; + // Manual sign extension requires two shifts for even lanes. + const VI32 ae = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, a))); + const VI32 be = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, b))); + const VI32 ao = ShiftRight<16>(BitCast(d32, a)); + const VI32 bo = ShiftRight<16>(BitCast(d32, b)); + sum1 = Add(Mul(ao, bo), sum1); + return Add(Mul(ae, be), sum0); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); +} + +// ================================================== REDUCTIONS + +template +HWY_API Vec128 SumOfLanes(Simd d, const Vec128 v) { + T sum = T{0}; + for (size_t i = 0; i < N; ++i) { + sum += v.raw[i]; + } + return Set(d, sum); +} +template +HWY_API Vec128 MinOfLanes(Simd d, const Vec128 v) { + T min = HighestValue(); + for (size_t i = 0; i < N; ++i) { + min = HWY_MIN(min, v.raw[i]); + } + return Set(d, min); +} +template +HWY_API Vec128 MaxOfLanes(Simd d, const Vec128 v) { + T max = LowestValue(); + for (size_t i = 0; i < N; ++i) { + max = HWY_MAX(max, v.raw[i]); + } + return Set(d, max); +} + +// ================================================== OPS WITH DEPENDENCIES + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +} + +HWY_INLINE Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + const Half> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/generic_ops-inl.h 
b/third_party/highway/hwy/ops/generic_ops-inl.h new file mode 100644 index 0000000000..5898518467 --- /dev/null +++ b/third_party/highway/hwy/ops/generic_ops-inl.h @@ -0,0 +1,1560 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Target-independent types/functions defined after target-specific ops. + +#include "hwy/base.h" + +// Define detail::Shuffle1230 etc, but only when viewing the current header; +// normally this is included via highway.h, which includes ops/*.h. +#if HWY_IDE && !defined(HWY_HIGHWAY_INCLUDED) +#include "hwy/ops/emu128-inl.h" +#endif // HWY_IDE + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The lane type of a vector type, e.g. float for Vec>. +template +using LaneType = decltype(GetLane(V())); + +// Vector type, e.g. Vec128 for CappedTag. Useful as the return +// type of functions that do not take a vector argument, or as an argument type +// if the function only has a template argument for D, or for explicit type +// names instead of auto. This may be a built-in type. +template +using Vec = decltype(Zero(D())); + +// Mask type. Useful as the return type of functions that do not take a mask +// argument, or as an argument type if the function only has a template argument +// for D, or for explicit type names instead of auto. +template +using Mask = decltype(MaskFromVec(Zero(D()))); + +// Returns the closest value to v within [lo, hi]. +template +HWY_API V Clamp(const V v, const V lo, const V hi) { + return Min(Max(lo, v), hi); +} + +// CombineShiftRightBytes (and -Lanes) are not available for the scalar target, +// and RVV has its own implementation of -Lanes. +#if HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV + +template > +HWY_API V CombineShiftRightLanes(D d, const V hi, const V lo) { + constexpr size_t kBytes = kLanes * sizeof(LaneType); + static_assert(kBytes < 16, "Shift count is per-block"); + return CombineShiftRightBytes(d, hi, lo); +} + +#endif + +// Returns lanes with the most significant bit set and all other bits zero. +template +HWY_API Vec SignBit(D d) { + const RebindToUnsigned du; + return BitCast(d, Set(du, SignMask>())); +} + +// Returns quiet NaN. +template +HWY_API Vec NaN(D d) { + const RebindToSigned di; + // LimitsMax sets all exponent and mantissa bits to 1. The exponent plus + // mantissa MSB (to indicate quiet) would be sufficient. + return BitCast(d, Set(di, LimitsMax>())); +} + +// Returns positive infinity. 
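+// (For float, MaxExponentTimes2 is presumably 0xFF000000u, i.e. the all-ones
+// exponent field doubled; max_x2 >> 1 then yields 0x7F800000u, the +inf bit
+// pattern: sign 0, exponent all-ones, mantissa 0.)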
+template +HWY_API Vec Inf(D d) { + const RebindToUnsigned du; + using T = TFromD; + using TU = TFromD; + const TU max_x2 = static_cast(MaxExponentTimes2()); + return BitCast(d, Set(du, max_x2 >> 1)); +} + +// ------------------------------ SafeFillN + +template > +HWY_API void SafeFillN(const size_t num, const T value, D d, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = value; + } +#else + BlendedStore(Set(d, value), FirstN(d, num), d, to); +#endif +} + +// ------------------------------ SafeCopyN + +template > +HWY_API void SafeCopyN(const size_t num, D d, const T* HWY_RESTRICT from, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = from[i]; + } +#else + const Mask mask = FirstN(d, num); + BlendedStore(MaskedLoad(mask, d, from), mask, d, to); +#endif +} + +// "Include guard": skip if native instructions are available. The generic +// implementation is currently shared between x86_* and wasm_*, and is too large +// to duplicate. + +#if (defined(HWY_NATIVE_LOAD_STORE_INTERLEAVED) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +// ------------------------------ LoadInterleaved2 + +template +HWY_API void LoadInterleaved2(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1) { + const V A = LoadU(d, unaligned + 0 * N); // v1[1] v0[1] v1[0] v0[0] + const V B = LoadU(d, unaligned + 1 * N); + v0 = ConcatEven(d, B, A); + v1 = ConcatOdd(d, B, A); +} + +template +HWY_API void LoadInterleaved2(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +// ------------------------------ LoadInterleaved3 (CombineShiftRightBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void LoadTransposedBlocks3(Simd d, + const T* HWY_RESTRICT unaligned, V& A, V& B, + V& C) { + A = LoadU(d, unaligned + 0 * N); + B = LoadU(d, unaligned + 1 * N); + C = LoadU(d, unaligned + 2 * N); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned du; + // Compact notation so these fit on one line: 12 := v1[2]. + V A; // 05 24 14 04 23 13 03 22 12 02 21 11 01 20 10 00 + V B; // 1a 0a 29 19 09 28 18 08 27 17 07 26 16 06 25 15 + V C; // 2f 1f 0f 2e 1e 0e 2d 1d 0d 2c 1c 0c 2b 1b 0b 2a + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. 
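+  // (Z == 0x80 below is an out-of-range index, so TableLookupBytesOr0 writes a
+  // zero byte there; the three partial vectors have disjoint non-zero bytes,
+  // hence Xor3 acts as a plain OR when combining them.)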
+ constexpr uint8_t Z = 0x80; + alignas(16) constexpr uint8_t kIdx_v0A[16] = {0, 3, 6, 9, 12, 15, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0B[16] = {Z, Z, Z, Z, Z, Z, 2, 5, + 8, 11, 14, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, Z, 1, 4, 7, 10, 13}; + alignas(16) constexpr uint8_t kIdx_v1A[16] = {1, 4, 7, 10, 13, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1B[16] = {Z, Z, Z, Z, Z, 0, 3, 6, + 9, 12, 15, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, Z, 2, 5, 8, 11, 14}; + alignas(16) constexpr uint8_t kIdx_v2A[16] = {2, 5, 8, 11, 14, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2B[16] = {Z, Z, Z, Z, Z, 1, 4, 7, + 10, 13, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, 0, 3, 6, 9, 12, 15}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 8-bit lanes x8 +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned du; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. 
+ constexpr uint8_t Z = 0x80; + alignas(16) constexpr uint8_t kIdx_v0A[16] = {0, 3, 6, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0B[16] = {Z, Z, Z, 1, 4, 7, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0C[16] = {Z, Z, Z, Z, Z, Z, 2, 5}; + alignas(16) constexpr uint8_t kIdx_v1A[16] = {1, 4, 7, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1B[16] = {Z, Z, Z, 2, 5, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1C[16] = {Z, Z, Z, Z, Z, 0, 3, 6}; + alignas(16) constexpr uint8_t kIdx_v2A[16] = {2, 5, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2B[16] = {Z, Z, 0, 3, 6, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2C[16] = {Z, Z, Z, Z, Z, 1, 4, 7}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 16-bit lanes x8 +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned du; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. Same as above, + // but each element of the array contains two byte indices for a lane. 
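+  // (For example, 0x0100 selects bytes 0 and 1, i.e. both bytes of 16-bit
+  // lane 0 on little-endian hosts, while Z == 0x8080 zeroes both bytes.)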
+ constexpr uint16_t Z = 0x8080; + alignas(16) constexpr uint16_t kIdx_v0A[8] = {0x0100, 0x0706, 0x0D0C, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v0B[8] = {Z, Z, Z, 0x0302, + 0x0908, 0x0F0E, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v0C[8] = {Z, Z, Z, Z, + Z, Z, 0x0504, 0x0B0A}; + alignas(16) constexpr uint16_t kIdx_v1A[8] = {0x0302, 0x0908, 0x0F0E, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v1B[8] = {Z, Z, Z, 0x0504, + 0x0B0A, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v1C[8] = {Z, Z, Z, Z, + Z, 0x0100, 0x0706, 0x0D0C}; + alignas(16) constexpr uint16_t kIdx_v2A[8] = {0x0504, 0x0B0A, Z, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v2B[8] = {Z, Z, 0x0100, 0x0706, + 0x0D0C, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v2C[8] = {Z, Z, Z, Z, + Z, 0x0302, 0x0908, 0x0F0E}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + V A; // v0[1] v2[0] v1[0] v0[0] + V B; // v1[2] v0[2] v2[1] v1[1] + V C; // v2[3] v1[3] v0[3] v2[2] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + + const V vxx_02_03_xx = OddEven(C, B); + v0 = detail::Shuffle1230(A, vxx_02_03_xx); + + // Shuffle2301 takes the upper/lower halves of the output from one input, so + // we cannot just combine 13 and 10 with 12 and 11 (similar to v0/v2). Use + // OddEven because it may have higher throughput than Shuffle. + const V vxx_xx_10_11 = OddEven(A, B); + const V v12_13_xx_xx = OddEven(B, C); + v1 = detail::Shuffle2301(vxx_xx_10_11, v12_13_xx_xx); + + const V vxx_20_21_xx = OddEven(B, A); + v2 = detail::Shuffle3012(vxx_20_21_xx, C); +} + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + V A; // v1[0] v0[0] + V B; // v0[1] v2[0] + V C; // v2[1] v1[1] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + v0 = OddEven(B, A); + v1 = CombineShiftRightBytes(d, C, A); + v2 = OddEven(C, B); +} + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +// ------------------------------ LoadInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. 
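+// (This overload simply issues four consecutive unaligned loads; the actual
+// de-interleaving happens in the LoadInterleaved4 overloads below.)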
+template +HWY_API void LoadTransposedBlocks4(Simd d, + const T* HWY_RESTRICT unaligned, V& A, V& B, + V& C, V& D) { + A = LoadU(d, unaligned + 0 * N); + B = LoadU(d, unaligned + 1 * N); + C = LoadU(d, unaligned + 2 * N); + D = LoadU(d, unaligned + 3 * N); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + const Repartition d64; + using V64 = VFromD; + // 16 lanes per block; the lowest four blocks are at the bottom of A,B,C,D. + // Here int[i] means the four interleaved values of the i-th 4-tuple and + // int[3..0] indicates four consecutive 4-tuples (0 = least-significant). + V A; // int[13..10] int[3..0] + V B; // int[17..14] int[7..4] + V C; // int[1b..18] int[b..8] + V D; // int[1f..1c] int[f..c] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + + // For brevity, the comments only list the lower block (upper = lower + 0x10) + const V v5140 = InterleaveLower(d, A, B); // int[5,1,4,0] + const V vd9c8 = InterleaveLower(d, C, D); // int[d,9,c,8] + const V v7362 = InterleaveUpper(d, A, B); // int[7,3,6,2] + const V vfbea = InterleaveUpper(d, C, D); // int[f,b,e,a] + + const V v6420 = InterleaveLower(d, v5140, v7362); // int[6,4,2,0] + const V veca8 = InterleaveLower(d, vd9c8, vfbea); // int[e,c,a,8] + const V v7531 = InterleaveUpper(d, v5140, v7362); // int[7,5,3,1] + const V vfdb9 = InterleaveUpper(d, vd9c8, vfbea); // int[f,d,b,9] + + const V64 v10L = BitCast(d64, InterleaveLower(d, v6420, v7531)); // v10[7..0] + const V64 v10U = BitCast(d64, InterleaveLower(d, veca8, vfdb9)); // v10[f..8] + const V64 v32L = BitCast(d64, InterleaveUpper(d, v6420, v7531)); // v32[7..0] + const V64 v32U = BitCast(d64, InterleaveUpper(d, veca8, vfdb9)); // v32[f..8] + + v0 = BitCast(d, InterleaveLower(d64, v10L, v10U)); + v1 = BitCast(d, InterleaveUpper(d64, v10L, v10U)); + v2 = BitCast(d, InterleaveLower(d64, v32L, v32U)); + v3 = BitCast(d, InterleaveUpper(d64, v32L, v32U)); +} + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + // In the last step, we interleave by half of the block size, which is usually + // 8 bytes but half that for 8-bit x8 vectors. + using TW = hwy::UnsignedFromSize; + const Repartition dw; + using VW = VFromD; + + // (Comments are for 256-bit vectors.) + // 8 lanes per block; the lowest four blocks are at the bottom of A,B,C,D. 
+ V A; // v3210[9]v3210[8] v3210[1]v3210[0] + V B; // v3210[b]v3210[a] v3210[3]v3210[2] + V C; // v3210[d]v3210[c] v3210[5]v3210[4] + V D; // v3210[f]v3210[e] v3210[7]v3210[6] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + + const V va820 = InterleaveLower(d, A, B); // v3210[a,8] v3210[2,0] + const V vec64 = InterleaveLower(d, C, D); // v3210[e,c] v3210[6,4] + const V vb931 = InterleaveUpper(d, A, B); // v3210[b,9] v3210[3,1] + const V vfd75 = InterleaveUpper(d, C, D); // v3210[f,d] v3210[7,5] + + const VW v10_b830 = // v10[b..8] v10[3..0] + BitCast(dw, InterleaveLower(d, va820, vb931)); + const VW v10_fc74 = // v10[f..c] v10[7..4] + BitCast(dw, InterleaveLower(d, vec64, vfd75)); + const VW v32_b830 = // v32[b..8] v32[3..0] + BitCast(dw, InterleaveUpper(d, va820, vb931)); + const VW v32_fc74 = // v32[f..c] v32[7..4] + BitCast(dw, InterleaveUpper(d, vec64, vfd75)); + + v0 = BitCast(d, InterleaveLower(dw, v10_b830, v10_fc74)); + v1 = BitCast(d, InterleaveUpper(dw, v10_b830, v10_fc74)); + v2 = BitCast(d, InterleaveLower(dw, v32_b830, v32_fc74)); + v3 = BitCast(d, InterleaveUpper(dw, v32_b830, v32_fc74)); +} + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + V A; // v3210[4] v3210[0] + V B; // v3210[5] v3210[1] + V C; // v3210[6] v3210[2] + V D; // v3210[7] v3210[3] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + const V v10_ev = InterleaveLower(d, A, C); // v1[6,4] v0[6,4] v1[2,0] v0[2,0] + const V v10_od = InterleaveLower(d, B, D); // v1[7,5] v0[7,5] v1[3,1] v0[3,1] + const V v32_ev = InterleaveUpper(d, A, C); // v3[6,4] v2[6,4] v3[2,0] v2[2,0] + const V v32_od = InterleaveUpper(d, B, D); // v3[7,5] v2[7,5] v3[3,1] v2[3,1] + + v0 = InterleaveLower(d, v10_ev, v10_od); + v1 = InterleaveUpper(d, v10_ev, v10_od); + v2 = InterleaveLower(d, v32_ev, v32_od); + v3 = InterleaveUpper(d, v32_ev, v32_od); +} + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + V A, B, C, D; + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + v0 = InterleaveLower(d, A, C); + v1 = InterleaveUpper(d, A, C); + v2 = InterleaveLower(d, B, D); + v3 = InterleaveUpper(d, B, D); +} + +// Any T x1 +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void StoreTransposedBlocks2(const V A, const V B, Simd d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); +} + +} // namespace detail + +// >= 128 bit vector +template +HWY_API void StoreInterleaved2(const V v0, const V v1, Simd d, + T* HWY_RESTRICT unaligned) { + const auto v10L = InterleaveLower(d, v0, v1); // .. v1[0] v0[0] + const auto v10U = InterleaveUpper(d, v0, v1); // .. 
v1[N/2] v0[N/2] + detail::StoreTransposedBlocks2(v10L, v10U, d, unaligned); +} + +// <= 64 bits +template +HWY_API void StoreInterleaved2(const V part0, const V part1, Simd d, + T* HWY_RESTRICT unaligned) { + const Twice d2; + const auto v0 = ZeroExtendVector(d2, part0); + const auto v1 = ZeroExtendVector(d2, part1); + const auto v10 = InterleaveLower(d2, v0, v1); + StoreU(v10, d2, unaligned); +} + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void StoreTransposedBlocks3(const V A, const V B, const V C, + Simd d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); + StoreU(C, d, unaligned + 2 * N); +} + +} // namespace detail + +// >= 128-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + using TU = TFromD; + const auto k5 = Set(du, TU{5}); + const auto k6 = Set(du, TU{6}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v0[5], v2[4],v1[4],v0[4] .. v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = LoadDup128(du, tbl_v0); + const auto shuf_A1 = LoadDup128(du, tbl_v1); // cannot reuse shuf_A0 (has 5) + const auto shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const V A = BitCast(d, A0 | A1 | A2); + + // B: v1[10],v0[10], v2[9],v1[9],v0[9] .. , v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_B0 = shuf_A2 + k6; // .A..9..8..7..6.. + const auto shuf_B1 = shuf_A0 + k5; // A..9..8..7..6..5 + const auto shuf_B2 = shuf_A1 + k5; // ..9..8..7..6..5. + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const V B = BitCast(d, B0 | B1 | B2); + + // C: v2[15],v1[15],v0[15], v2[11],v1[11],v0[11], v2[10] + const auto shuf_C0 = shuf_B2 + k6; // ..F..E..D..C..B. + const auto shuf_C1 = shuf_B0 + k5; // .F..E..D..C..B.. 
+ const auto shuf_C2 = shuf_B1 + k5; // F..E..D..C..B..A + const auto C0 = TableLookupBytesOr0(v0, shuf_C0); + const auto C1 = TableLookupBytesOr0(v1, shuf_C1); + const auto C2 = TableLookupBytesOr0(v2, shuf_C2); + const V C = BitCast(d, C0 | C1 | C2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd d, T* HWY_RESTRICT unaligned) { + const Repartition du8; + const auto k2 = Set(du8, uint8_t{2 * sizeof(T)}); + const auto k3 = Set(du8, uint8_t{3 * sizeof(T)}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. Note that these are byte + // indices for 16-bit lanes. + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A1 = LoadDup128(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const auto shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const auto shuf_A2 = LoadDup128(du8, tbl_v2); // ..1..0.. + + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); + const V A = BitCast(d, A0 | A1 | A2); + + // B: v0[5] v2[4],v1[4],v0[4], v2[3],v1[3],v0[3], v2[2] + const auto shuf_B0 = shuf_A1 + k3; // 5..4..3. + const auto shuf_B1 = shuf_A2 + k3; // ..4..3.. + const auto shuf_B2 = shuf_A0 + k2; // .4..3..2 + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const V B = BitCast(d, B0 | B1 | B2); + + // C: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_C0 = shuf_B1 + k3; // ..7..6.. + const auto shuf_C1 = shuf_B2 + k3; // .7..6..5 + const auto shuf_C2 = shuf_B0 + k2; // 7..6..5. + const auto C0 = TableLookupBytesOr0(v0, shuf_C0); + const auto C1 = TableLookupBytesOr0(v1, shuf_C1); + const auto C2 = TableLookupBytesOr0(v2, shuf_C2); + const V C = BitCast(d, C0 | C1 | C2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd d, T* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + + const V v10_v00 = InterleaveLower(d, v0, v1); + const V v01_v20 = OddEven(v0, v2); + // A: v0[1], v2[0],v1[0],v0[0] (<- lane 0) + const V A = BitCast( + d, InterleaveLower(dw, BitCast(dw, v10_v00), BitCast(dw, v01_v20))); + + const V v1_321 = ShiftRightLanes<1>(d, v1); + const V v0_32 = ShiftRightLanes<2>(d, v0); + const V v21_v11 = OddEven(v2, v1_321); + const V v12_v02 = OddEven(v1_321, v0_32); + // B: v1[2],v0[2], v2[1],v1[1] + const V B = BitCast( + d, InterleaveLower(dw, BitCast(dw, v21_v11), BitCast(dw, v12_v02))); + + // Notation refers to the upper 2 lanes of the vector for InterleaveUpper. 
+ const V v23_v13 = OddEven(v2, v1_321); + const V v03_v22 = OddEven(v0, v2); + // C: v2[3],v1[3],v0[3], v2[2] + const V C = BitCast( + d, InterleaveUpper(dw, BitCast(dw, v03_v22), BitCast(dw, v23_v13))); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd d, T* HWY_RESTRICT unaligned) { + const V A = InterleaveLower(d, v0, v1); + const V B = OddEven(v0, v2); + const V C = InterleaveUpper(d, v1, v2); + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// 64-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(const V part0, const V part1, const V part2, + Full64 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors for the shuffles and first result. + const Full128 du; + const Full128 d_full; + const auto k5 = Set(du, uint8_t{5}); + const auto k6 = Set(du, uint8_t{6}); + + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = Load(du, tbl_v0); + const auto shuf_A1 = Load(du, tbl_v1); // cannot reuse shuf_A0 (5 in MSB) + const auto shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned + 0 * N); + + // Second (HALF) vector: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_B0 = shuf_A2 + k6; // ..7..6.. + const auto shuf_B1 = shuf_A0 + k5; // .7..6..5 + const auto shuf_B2 = shuf_A1 + k5; // 7..6..5. + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const V B{(B0 | B1 | B2).raw}; + StoreU(B, d, unaligned + 1 * N); +} + +// 64-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(const Vec64 part0, const Vec64 part1, + const Vec64 part2, Full64 dh, + T* HWY_RESTRICT unaligned) { + const Full128 d; + const Full128 du8; + constexpr size_t N = 16 / sizeof(T); + const auto k2 = Set(du8, uint8_t{2 * sizeof(T)}); + const auto k3 = Set(du8, uint8_t{3 * sizeof(T)}); + + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + + // Interleave part (v0,v1,v2) to full (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. 
+ alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A1 = Load(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const auto shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const auto shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); + const Vec128 A = BitCast(d, A0 | A1 | A2); + StoreU(A, d, unaligned + 0 * N); + + // Second (HALF) vector: v2[3],v1[3],v0[3], v2[2] + const auto shuf_B0 = shuf_A1 + k3; // ..3. + const auto shuf_B1 = shuf_A2 + k3; // .3.. + const auto shuf_B2 = shuf_A0 + k2; // 3..2 + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const Vec128 B = BitCast(d, B0 | B1 | B2); + StoreU(Vec64{B.raw}, dh, unaligned + 1 * N); +} + +// 64-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(const Vec64 v0, const Vec64 v1, + const Vec64 v2, Full64 d, + T* HWY_RESTRICT unaligned) { + // (same code as 128-bit vector, 64-bit lanes) + constexpr size_t N = 2; + const Vec64 v10_v00 = InterleaveLower(d, v0, v1); + const Vec64 v01_v20 = OddEven(v0, v2); + const Vec64 v21_v11 = InterleaveUpper(d, v1, v2); + StoreU(v10_v00, d, unaligned + 0 * N); + StoreU(v01_v20, d, unaligned + 1 * N); + StoreU(v21_v11, d, unaligned + 2 * N); +} + +// 64-bit lanes are handled by the N=1 case below. + +// <= 32-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(const Vec128 part0, + const Vec128 part1, + const Vec128 part2, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128 du; + const Full128 d_full; + + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, + 0x80, 3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = Load(du, tbl_v0); + const auto shuf_A1 = CombineShiftRightBytes<15>(du, shuf_A0, shuf_A0); + const auto shuf_A2 = CombineShiftRightBytes<14>(du, shuf_A0, shuf_A0); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // ......3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // .....3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // ....3..2..1..0.. + const Vec128 A = BitCast(d_full, A0 | A1 | A2); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// 32-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(const Vec128 part0, + const Vec128 part1, + const Vec128 part2, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 4 / sizeof(T); + // Use full vectors for the shuffles and result. 
+ const Full128 du8; + const Full128 d_full; + + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A2 = // ..1..0.. + Load(du8, tbl_v2); + const auto shuf_A1 = // ...1..0. + CombineShiftRightBytes<2>(du8, shuf_A2, shuf_A2); + const auto shuf_A0 = // ....1..0 + CombineShiftRightBytes<4>(du8, shuf_A2, shuf_A2); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // ..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // .1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // 1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// Single-element vector, any lane size: just store directly +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Simd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +// ------------------------------ StoreInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void StoreTransposedBlocks4(const V A, const V B, const V C, const V D, + Simd d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); + StoreU(C, d, unaligned + 2 * N); + StoreU(D, d, unaligned + 3 * N); +} + +} // namespace detail + +// >= 128-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(const V v0, const V v1, const V v2, const V v3, + Simd d, T* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + const auto v10L = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32L = ZipLower(dw, v2, v3); + const auto v10U = ZipUpper(dw, v0, v1); + const auto v32U = ZipUpper(dw, v2, v3); + // The interleaved vectors are A, B, C, D. + const auto A = BitCast(d, InterleaveLower(dw, v10L, v32L)); // 3210 + const auto B = BitCast(d, InterleaveUpper(dw, v10L, v32L)); + const auto C = BitCast(d, InterleaveLower(dw, v10U, v32U)); + const auto D = BitCast(d, InterleaveUpper(dw, v10U, v32U)); + detail::StoreTransposedBlocks4(A, B, C, D, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved4(const V v0, const V v1, const V v2, const V v3, + Simd d, T* HWY_RESTRICT unaligned) { + // The interleaved vectors are A, B, C, D. + const auto A = InterleaveLower(d, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d, v2, v3); + const auto C = InterleaveUpper(d, v0, v1); + const auto D = InterleaveUpper(d, v2, v3); + detail::StoreTransposedBlocks4(A, B, C, D, d, unaligned); +} + +// 64-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(const Vec64 part0, const Vec64 part1, + const Vec64 part2, const Vec64 part3, + Full64 /*tag*/, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors to reduce the number of stores. 
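+  // (The four 64-bit parts are widened into two full 128-bit vectors, so two
+  // StoreU calls replace four.)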
+ const Full128 d_full; + const RepartitionToWide dw; + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + const Vec128 v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto A = BitCast(d_full, InterleaveLower(dw, v10, v32)); + const auto B = BitCast(d_full, InterleaveUpper(dw, v10, v32)); + StoreU(A, d_full, unaligned + 0 * N); + StoreU(B, d_full, unaligned + 1 * N); +} + +// 64-bit vector, 64-bit lane +template +HWY_API void StoreInterleaved4(const Vec64 part0, const Vec64 part1, + const Vec64 part2, const Vec64 part3, + Full64 /*tag*/, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors to reduce the number of stores. + const Full128 d_full; + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + const Vec128 v3{part3.raw}; + const auto A = InterleaveLower(d_full, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d_full, v2, v3); + StoreU(A, d_full, unaligned + 0 * N); + StoreU(B, d_full, unaligned + 1 * N); +} + +// <= 32-bit vectors +template +HWY_API void StoreInterleaved4(const Vec128 part0, + const Vec128 part1, + const Vec128 part2, + const Vec128 part3, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128 d_full; + const RepartitionToWide dw; + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + const Vec128 v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto v3210 = BitCast(d_full, InterleaveLower(dw, v10, v32)); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(v3210, d_full, buf); + CopyBytes<4 * N * sizeof(T)>(buf, unaligned); +} + +#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED + +// ------------------------------ AESRound + +// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +// Define for white-box testing, even if native instructions are available. +namespace detail { + +// Constant-time: computes inverse in GF(2^4) based on "Accelerating AES with +// Vector Permute Instructions" and the accompanying assembly language +// implementation: https://crypto.stanford.edu/vpaes/vpaes.tgz. See also Botan: +// https://botan.randombit.net/doxygen/aes__vperm_8cpp_source.html . +// +// A brute-force 256 byte table lookup can also be made constant-time, and +// possibly competitive on NEON, but this is more performance-portable +// especially for x86 and large vectors. +template // u8 +HWY_INLINE V SubBytes(V state) { + const DFromV du; + const auto mask = Set(du, uint8_t{0xF}); + + // Change polynomial basis to GF(2^4) + { + alignas(16) static constexpr uint8_t basisL[16] = { + 0x00, 0x70, 0x2A, 0x5A, 0x98, 0xE8, 0xB2, 0xC2, + 0x08, 0x78, 0x22, 0x52, 0x90, 0xE0, 0xBA, 0xCA}; + alignas(16) static constexpr uint8_t basisU[16] = { + 0x00, 0x4D, 0x7C, 0x31, 0x7D, 0x30, 0x01, 0x4C, + 0x81, 0xCC, 0xFD, 0xB0, 0xFC, 0xB1, 0x80, 0xCD}; + const auto sL = And(state, mask); + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto gf4L = TableLookupBytes(LoadDup128(du, basisL), sL); + const auto gf4U = TableLookupBytes(LoadDup128(du, basisU), sU); + state = Xor(gf4L, gf4U); + } + + // Inversion in GF(2^4). Elements 0 represent "infinity" (division by 0) and + // cause TableLookupBytesOr0 to return 0. 
+ alignas(16) static constexpr uint8_t kZetaInv[16] = { + 0x80, 7, 11, 15, 6, 10, 4, 1, 9, 8, 5, 2, 12, 14, 13, 3}; + alignas(16) static constexpr uint8_t kInv[16] = { + 0x80, 1, 8, 13, 15, 6, 5, 14, 2, 12, 11, 10, 9, 3, 7, 4}; + const auto tbl = LoadDup128(du, kInv); + const auto sL = And(state, mask); // L=low nibble, U=upper + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto sX = Xor(sU, sL); + const auto invL = TableLookupBytes(LoadDup128(du, kZetaInv), sL); + const auto invU = TableLookupBytes(tbl, sU); + const auto invX = TableLookupBytes(tbl, sX); + const auto outL = Xor(sX, TableLookupBytesOr0(tbl, Xor(invL, invU))); + const auto outU = Xor(sU, TableLookupBytesOr0(tbl, Xor(invL, invX))); + + // Linear skew (cannot bake 0x63 bias into the table because out* indices + // may have the infinity flag set). + alignas(16) static constexpr uint8_t kAffineL[16] = { + 0x00, 0xC7, 0xBD, 0x6F, 0x17, 0x6D, 0xD2, 0xD0, + 0x78, 0xA8, 0x02, 0xC5, 0x7A, 0xBF, 0xAA, 0x15}; + alignas(16) static constexpr uint8_t kAffineU[16] = { + 0x00, 0x6A, 0xBB, 0x5F, 0xA5, 0x74, 0xE4, 0xCF, + 0xFA, 0x35, 0x2B, 0x41, 0xD1, 0x90, 0x1E, 0x8E}; + const auto affL = TableLookupBytesOr0(LoadDup128(du, kAffineL), outL); + const auto affU = TableLookupBytesOr0(LoadDup128(du, kAffineU), outU); + return Xor(Xor(affL, affU), Set(du, uint8_t{0x63})); +} + +} // namespace detail + +#endif // HWY_TARGET != HWY_SCALAR + +// "Include guard": skip if native AES instructions are available. +#if (defined(HWY_NATIVE_AES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR + +namespace detail { + +template // u8 +HWY_API V ShiftRows(const V state) { + const DFromV du; + alignas(16) static constexpr uint8_t kShiftRow[16] = { + 0, 5, 10, 15, // transposed: state is column major + 4, 9, 14, 3, // + 8, 13, 2, 7, // + 12, 1, 6, 11}; + const auto shift_row = LoadDup128(du, kShiftRow); + return TableLookupBytes(state, shift_row); +} + +template // u8 +HWY_API V MixColumns(const V state) { + const DFromV du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 2 3 1 1 // Let s := state*1, d := state*2, t := state*3. + // 1 2 3 1 // d are on diagonal, no permutation needed. + // 1 1 2 3 // t1230 indicates column indices of threes for the 4 rows. + // 3 1 1 2 // We also need to compute s2301 and s3012 (=1230 o 2301). + alignas(16) static constexpr uint8_t k2301[16] = { + 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}; + alignas(16) static constexpr uint8_t k1230[16] = { + 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12}; + const RebindToSigned di; // can only do signed comparisons + const auto msb = Lt(BitCast(di, state), Zero(di)); + const auto overflow = BitCast(du, IfThenElseZero(msb, Set(di, int8_t{0x1B}))); + const auto d = Xor(Add(state, state), overflow); // = state*2 in GF(2^8). 
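+  // (This is the usual AES "xtime" step: doubling is a left shift, XORing in
+  // the reduction polynomial 0x1B whenever the top bit overflows, e.g.
+  // 0x80 * 2 == 0x1B in GF(2^8).)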
+ const auto s2301 = TableLookupBytes(state, LoadDup128(du, k2301)); + const auto d_s2301 = Xor(d, s2301); + const auto t_s2301 = Xor(state, d_s2301); // t(s*3) = XOR-sum {s, d(s*2)} + const auto t1230_s3012 = TableLookupBytes(t_s2301, LoadDup128(du, k1230)); + return Xor(d_s2301, t1230_s3012); // XOR-sum of 4 terms +} + +} // namespace detail + +template // u8 +HWY_API V AESRound(V state, const V round_key) { + // Intel docs swap the first two steps, but it does not matter because + // ShiftRows is a permutation and SubBytes is independent of lane index. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = detail::MixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template // u8 +HWY_API V AESLastRound(V state, const V round_key) { + // LIke AESRound, but without MixColumns. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +// Constant-time implementation inspired by +// https://www.bearssl.org/constanttime.html, but about half the cost because we +// use 64x64 multiplies and 128-bit XORs. +template +HWY_API V CLMulLower(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulEven(a0, b0), MulEven(a1, b3)); + auto m1 = Xor(MulEven(a0, b1), MulEven(a1, b0)); + auto m2 = Xor(MulEven(a0, b2), MulEven(a1, b1)); + auto m3 = Xor(MulEven(a0, b3), MulEven(a1, b2)); + m0 = Xor(m0, Xor(MulEven(a2, b2), MulEven(a3, b1))); + m1 = Xor(m1, Xor(MulEven(a2, b3), MulEven(a3, b2))); + m2 = Xor(m2, Xor(MulEven(a2, b0), MulEven(a3, b3))); + m3 = Xor(m3, Xor(MulEven(a2, b1), MulEven(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +template +HWY_API V CLMulUpper(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulOdd(a0, b0), MulOdd(a1, b3)); + auto m1 = Xor(MulOdd(a0, b1), MulOdd(a1, b0)); + auto m2 = Xor(MulOdd(a0, b2), MulOdd(a1, b1)); + auto m3 = Xor(MulOdd(a0, b3), MulOdd(a1, b2)); + m0 = Xor(m0, Xor(MulOdd(a2, b2), MulOdd(a3, b1))); + m1 = Xor(m1, Xor(MulOdd(a2, b3), MulOdd(a3, b2))); + m2 = Xor(m2, Xor(MulOdd(a2, b0), MulOdd(a3, b3))); + m3 = Xor(m3, Xor(MulOdd(a2, b1), MulOdd(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +#endif // HWY_NATIVE_AES +#endif // HWY_TARGET != HWY_SCALAR + +// "Include guard": skip if native POPCNT-related instructions are available. 
+#if (defined(HWY_NATIVE_POPCNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +#undef HWY_MIN_POW2_FOR_128 +#if HWY_TARGET == HWY_RVV +#define HWY_MIN_POW2_FOR_128 1 +#else +// All other targets except HWY_SCALAR (which is excluded by HWY_IF_GE128_D) +// guarantee 128 bits anyway. +#define HWY_MIN_POW2_FOR_128 0 +#endif + +// This algorithm requires vectors to be at least 16 bytes, which is the case +// for LMUL >= 2. If not, use the fallback below. +template , HWY_IF_LANE_SIZE_D(D, 1), + HWY_IF_GE128_D(D), HWY_IF_POW2_GE(D, HWY_MIN_POW2_FOR_128)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint8_t>(), "V must be u8"); + const D d; + HWY_ALIGN constexpr uint8_t kLookup[16] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + }; + const auto lo = And(v, Set(d, uint8_t{0xF})); + const auto hi = ShiftRight<4>(v); + const auto lookup = LoadDup128(d, kLookup); + return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo)); +} + +// RVV has a specialization that avoids the Set(). +#if HWY_TARGET != HWY_RVV +// Slower fallback for capped vectors. +template , HWY_IF_LANE_SIZE_D(D, 1), + HWY_IF_LT128_D(D)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint8_t>(), "V must be u8"); + const D d; + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + const V k33 = Set(d, uint8_t{0x33}); + v = Sub(v, And(ShiftRight<1>(v), Set(d, uint8_t{0x55}))); + v = Add(And(ShiftRight<2>(v), k33), And(v, k33)); + return And(Add(v, ShiftRight<4>(v)), Set(d, uint8_t{0x0F})); +} +#endif // HWY_TARGET != HWY_RVV + +template , HWY_IF_LANE_SIZE_D(D, 2)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint16_t>(), "V must be u16"); + const D d; + const Repartition d8; + const auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); + return Add(ShiftRight<8>(vals), And(vals, Set(d, uint16_t{0xFF}))); +} + +template , HWY_IF_LANE_SIZE_D(D, 4)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint32_t>(), "V must be u32"); + const D d; + Repartition d16; + auto vals = BitCast(d, PopulationCount(BitCast(d16, v))); + return Add(ShiftRight<16>(vals), And(vals, Set(d, uint32_t{0xFF}))); +} + +#if HWY_HAVE_INTEGER64 +template , HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint64_t>(), "V must be u64"); + const D d; + Repartition d32; + auto vals = BitCast(d, PopulationCount(BitCast(d32, v))); + return Add(ShiftRight<32>(vals), And(vals, Set(d, 0xFFULL))); +} +#endif + +#endif // HWY_NATIVE_POPCNT + +template , HWY_IF_LANE_SIZE_D(D, 8), + HWY_IF_LT128_D(D), HWY_IF_FLOAT_D(D)> +HWY_API V operator*(V x, V y) { + return Set(D(), GetLane(x) * GetLane(y)); +} + +template , HWY_IF_LANE_SIZE_D(D, 8), + HWY_IF_LT128_D(D), HWY_IF_NOT_FLOAT_D(D)> +HWY_API V operator*(V x, V y) { + const DFromV d; + using T = TFromD; + using TU = MakeUnsigned; + const TU xu = static_cast(GetLane(x)); + const TU yu = static_cast(GetLane(y)); + return Set(d, static_cast(xu * yu)); +} + +// "Include guard": skip if native 64-bit mul instructions are available. 
+#if (defined(HWY_NATIVE_I64MULLO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +template , typename T = LaneType, + HWY_IF_LANE_SIZE(T, 8), HWY_IF_UNSIGNED(T), HWY_IF_GE128_D(D64)> +HWY_API V operator*(V x, V y) { + RepartitionToNarrow d32; + auto x32 = BitCast(d32, x); + auto y32 = BitCast(d32, y); + auto lolo = BitCast(d32, MulEven(x32, y32)); + auto lohi = BitCast(d32, MulEven(x32, BitCast(d32, ShiftRight<32>(y)))); + auto hilo = BitCast(d32, MulEven(BitCast(d32, ShiftRight<32>(x)), y32)); + auto hi = BitCast(d32, ShiftLeft<32>(BitCast(D64{}, lohi + hilo))); + return BitCast(D64{}, lolo + hi); +} +template , typename T = LaneType, + HWY_IF_LANE_SIZE(T, 8), HWY_IF_SIGNED(T), HWY_IF_GE128_D(DI64)> +HWY_API V operator*(V x, V y) { + RebindToUnsigned du64; + return BitCast(DI64{}, BitCast(du64, x) * BitCast(du64, y)); +} + +#endif // HWY_NATIVE_I64MULLO + +// "Include guard": skip if native 8-bit compress instructions are available. +#if (defined(HWY_NATIVE_COMPRESS8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +template +HWY_API size_t CompressBitsStore(V v, const uint8_t* HWY_RESTRICT bits, D d, + T* unaligned) { + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + const Simd d8; + T* HWY_RESTRICT pos = unaligned; + + HWY_ALIGN constexpr T table[2048] = { + 0, 1, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 1, 0, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 2, 0, 1, 3, 4, 5, 6, 7, /**/ 0, 2, 1, 3, 4, 5, 6, 7, // + 1, 2, 0, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 3, 0, 1, 2, 4, 5, 6, 7, /**/ 0, 3, 1, 2, 4, 5, 6, 7, // + 1, 3, 0, 2, 4, 5, 6, 7, /**/ 0, 1, 3, 2, 4, 5, 6, 7, // + 2, 3, 0, 1, 4, 5, 6, 7, /**/ 0, 2, 3, 1, 4, 5, 6, 7, // + 1, 2, 3, 0, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 4, 0, 1, 2, 3, 5, 6, 7, /**/ 0, 4, 1, 2, 3, 5, 6, 7, // + 1, 4, 0, 2, 3, 5, 6, 7, /**/ 0, 1, 4, 2, 3, 5, 6, 7, // + 2, 4, 0, 1, 3, 5, 6, 7, /**/ 0, 2, 4, 1, 3, 5, 6, 7, // + 1, 2, 4, 0, 3, 5, 6, 7, /**/ 0, 1, 2, 4, 3, 5, 6, 7, // + 3, 4, 0, 1, 2, 5, 6, 7, /**/ 0, 3, 4, 1, 2, 5, 6, 7, // + 1, 3, 4, 0, 2, 5, 6, 7, /**/ 0, 1, 3, 4, 2, 5, 6, 7, // + 2, 3, 4, 0, 1, 5, 6, 7, /**/ 0, 2, 3, 4, 1, 5, 6, 7, // + 1, 2, 3, 4, 0, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 5, 0, 1, 2, 3, 4, 6, 7, /**/ 0, 5, 1, 2, 3, 4, 6, 7, // + 1, 5, 0, 2, 3, 4, 6, 7, /**/ 0, 1, 5, 2, 3, 4, 6, 7, // + 2, 5, 0, 1, 3, 4, 6, 7, /**/ 0, 2, 5, 1, 3, 4, 6, 7, // + 1, 2, 5, 0, 3, 4, 6, 7, /**/ 0, 1, 2, 5, 3, 4, 6, 7, // + 3, 5, 0, 1, 2, 4, 6, 7, /**/ 0, 3, 5, 1, 2, 4, 6, 7, // + 1, 3, 5, 0, 2, 4, 6, 7, /**/ 0, 1, 3, 5, 2, 4, 6, 7, // + 2, 3, 5, 0, 1, 4, 6, 7, /**/ 0, 2, 3, 5, 1, 4, 6, 7, // + 1, 2, 3, 5, 0, 4, 6, 7, /**/ 0, 1, 2, 3, 5, 4, 6, 7, // + 4, 5, 0, 1, 2, 3, 6, 7, /**/ 0, 4, 5, 1, 2, 3, 6, 7, // + 1, 4, 5, 0, 2, 3, 6, 7, /**/ 0, 1, 4, 5, 2, 3, 6, 7, // + 2, 4, 5, 0, 1, 3, 6, 7, /**/ 0, 2, 4, 5, 1, 3, 6, 7, // + 1, 2, 4, 5, 0, 3, 6, 7, /**/ 0, 1, 2, 4, 5, 3, 6, 7, // + 3, 4, 5, 0, 1, 2, 6, 7, /**/ 0, 3, 4, 5, 1, 2, 6, 7, // + 1, 3, 4, 5, 0, 2, 6, 7, /**/ 0, 1, 3, 4, 5, 2, 6, 7, // + 2, 3, 4, 5, 0, 1, 6, 7, /**/ 0, 2, 3, 4, 5, 1, 6, 7, // + 1, 2, 3, 4, 5, 0, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 6, 0, 1, 2, 3, 4, 5, 7, /**/ 0, 6, 1, 2, 3, 4, 5, 7, // + 1, 6, 0, 2, 3, 4, 5, 7, /**/ 0, 1, 6, 2, 3, 4, 5, 7, // + 2, 6, 0, 1, 3, 4, 5, 7, /**/ 0, 2, 6, 1, 3, 4, 5, 7, // + 1, 2, 6, 0, 3, 4, 5, 7, /**/ 0, 1, 2, 6, 3, 4, 5, 7, // + 
3, 6, 0, 1, 2, 4, 5, 7, /**/ 0, 3, 6, 1, 2, 4, 5, 7, // + 1, 3, 6, 0, 2, 4, 5, 7, /**/ 0, 1, 3, 6, 2, 4, 5, 7, // + 2, 3, 6, 0, 1, 4, 5, 7, /**/ 0, 2, 3, 6, 1, 4, 5, 7, // + 1, 2, 3, 6, 0, 4, 5, 7, /**/ 0, 1, 2, 3, 6, 4, 5, 7, // + 4, 6, 0, 1, 2, 3, 5, 7, /**/ 0, 4, 6, 1, 2, 3, 5, 7, // + 1, 4, 6, 0, 2, 3, 5, 7, /**/ 0, 1, 4, 6, 2, 3, 5, 7, // + 2, 4, 6, 0, 1, 3, 5, 7, /**/ 0, 2, 4, 6, 1, 3, 5, 7, // + 1, 2, 4, 6, 0, 3, 5, 7, /**/ 0, 1, 2, 4, 6, 3, 5, 7, // + 3, 4, 6, 0, 1, 2, 5, 7, /**/ 0, 3, 4, 6, 1, 2, 5, 7, // + 1, 3, 4, 6, 0, 2, 5, 7, /**/ 0, 1, 3, 4, 6, 2, 5, 7, // + 2, 3, 4, 6, 0, 1, 5, 7, /**/ 0, 2, 3, 4, 6, 1, 5, 7, // + 1, 2, 3, 4, 6, 0, 5, 7, /**/ 0, 1, 2, 3, 4, 6, 5, 7, // + 5, 6, 0, 1, 2, 3, 4, 7, /**/ 0, 5, 6, 1, 2, 3, 4, 7, // + 1, 5, 6, 0, 2, 3, 4, 7, /**/ 0, 1, 5, 6, 2, 3, 4, 7, // + 2, 5, 6, 0, 1, 3, 4, 7, /**/ 0, 2, 5, 6, 1, 3, 4, 7, // + 1, 2, 5, 6, 0, 3, 4, 7, /**/ 0, 1, 2, 5, 6, 3, 4, 7, // + 3, 5, 6, 0, 1, 2, 4, 7, /**/ 0, 3, 5, 6, 1, 2, 4, 7, // + 1, 3, 5, 6, 0, 2, 4, 7, /**/ 0, 1, 3, 5, 6, 2, 4, 7, // + 2, 3, 5, 6, 0, 1, 4, 7, /**/ 0, 2, 3, 5, 6, 1, 4, 7, // + 1, 2, 3, 5, 6, 0, 4, 7, /**/ 0, 1, 2, 3, 5, 6, 4, 7, // + 4, 5, 6, 0, 1, 2, 3, 7, /**/ 0, 4, 5, 6, 1, 2, 3, 7, // + 1, 4, 5, 6, 0, 2, 3, 7, /**/ 0, 1, 4, 5, 6, 2, 3, 7, // + 2, 4, 5, 6, 0, 1, 3, 7, /**/ 0, 2, 4, 5, 6, 1, 3, 7, // + 1, 2, 4, 5, 6, 0, 3, 7, /**/ 0, 1, 2, 4, 5, 6, 3, 7, // + 3, 4, 5, 6, 0, 1, 2, 7, /**/ 0, 3, 4, 5, 6, 1, 2, 7, // + 1, 3, 4, 5, 6, 0, 2, 7, /**/ 0, 1, 3, 4, 5, 6, 2, 7, // + 2, 3, 4, 5, 6, 0, 1, 7, /**/ 0, 2, 3, 4, 5, 6, 1, 7, // + 1, 2, 3, 4, 5, 6, 0, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 7, 0, 1, 2, 3, 4, 5, 6, /**/ 0, 7, 1, 2, 3, 4, 5, 6, // + 1, 7, 0, 2, 3, 4, 5, 6, /**/ 0, 1, 7, 2, 3, 4, 5, 6, // + 2, 7, 0, 1, 3, 4, 5, 6, /**/ 0, 2, 7, 1, 3, 4, 5, 6, // + 1, 2, 7, 0, 3, 4, 5, 6, /**/ 0, 1, 2, 7, 3, 4, 5, 6, // + 3, 7, 0, 1, 2, 4, 5, 6, /**/ 0, 3, 7, 1, 2, 4, 5, 6, // + 1, 3, 7, 0, 2, 4, 5, 6, /**/ 0, 1, 3, 7, 2, 4, 5, 6, // + 2, 3, 7, 0, 1, 4, 5, 6, /**/ 0, 2, 3, 7, 1, 4, 5, 6, // + 1, 2, 3, 7, 0, 4, 5, 6, /**/ 0, 1, 2, 3, 7, 4, 5, 6, // + 4, 7, 0, 1, 2, 3, 5, 6, /**/ 0, 4, 7, 1, 2, 3, 5, 6, // + 1, 4, 7, 0, 2, 3, 5, 6, /**/ 0, 1, 4, 7, 2, 3, 5, 6, // + 2, 4, 7, 0, 1, 3, 5, 6, /**/ 0, 2, 4, 7, 1, 3, 5, 6, // + 1, 2, 4, 7, 0, 3, 5, 6, /**/ 0, 1, 2, 4, 7, 3, 5, 6, // + 3, 4, 7, 0, 1, 2, 5, 6, /**/ 0, 3, 4, 7, 1, 2, 5, 6, // + 1, 3, 4, 7, 0, 2, 5, 6, /**/ 0, 1, 3, 4, 7, 2, 5, 6, // + 2, 3, 4, 7, 0, 1, 5, 6, /**/ 0, 2, 3, 4, 7, 1, 5, 6, // + 1, 2, 3, 4, 7, 0, 5, 6, /**/ 0, 1, 2, 3, 4, 7, 5, 6, // + 5, 7, 0, 1, 2, 3, 4, 6, /**/ 0, 5, 7, 1, 2, 3, 4, 6, // + 1, 5, 7, 0, 2, 3, 4, 6, /**/ 0, 1, 5, 7, 2, 3, 4, 6, // + 2, 5, 7, 0, 1, 3, 4, 6, /**/ 0, 2, 5, 7, 1, 3, 4, 6, // + 1, 2, 5, 7, 0, 3, 4, 6, /**/ 0, 1, 2, 5, 7, 3, 4, 6, // + 3, 5, 7, 0, 1, 2, 4, 6, /**/ 0, 3, 5, 7, 1, 2, 4, 6, // + 1, 3, 5, 7, 0, 2, 4, 6, /**/ 0, 1, 3, 5, 7, 2, 4, 6, // + 2, 3, 5, 7, 0, 1, 4, 6, /**/ 0, 2, 3, 5, 7, 1, 4, 6, // + 1, 2, 3, 5, 7, 0, 4, 6, /**/ 0, 1, 2, 3, 5, 7, 4, 6, // + 4, 5, 7, 0, 1, 2, 3, 6, /**/ 0, 4, 5, 7, 1, 2, 3, 6, // + 1, 4, 5, 7, 0, 2, 3, 6, /**/ 0, 1, 4, 5, 7, 2, 3, 6, // + 2, 4, 5, 7, 0, 1, 3, 6, /**/ 0, 2, 4, 5, 7, 1, 3, 6, // + 1, 2, 4, 5, 7, 0, 3, 6, /**/ 0, 1, 2, 4, 5, 7, 3, 6, // + 3, 4, 5, 7, 0, 1, 2, 6, /**/ 0, 3, 4, 5, 7, 1, 2, 6, // + 1, 3, 4, 5, 7, 0, 2, 6, /**/ 0, 1, 3, 4, 5, 7, 2, 6, // + 2, 3, 4, 5, 7, 0, 1, 6, /**/ 0, 2, 3, 4, 5, 7, 1, 6, // + 1, 2, 3, 4, 5, 7, 0, 6, /**/ 0, 1, 2, 3, 4, 5, 7, 6, // + 6, 7, 0, 1, 2, 3, 4, 5, /**/ 0, 6, 7, 1, 2, 3, 4, 5, // + 1, 6, 7, 0, 2, 
3, 4, 5, /**/ 0, 1, 6, 7, 2, 3, 4, 5, // + 2, 6, 7, 0, 1, 3, 4, 5, /**/ 0, 2, 6, 7, 1, 3, 4, 5, // + 1, 2, 6, 7, 0, 3, 4, 5, /**/ 0, 1, 2, 6, 7, 3, 4, 5, // + 3, 6, 7, 0, 1, 2, 4, 5, /**/ 0, 3, 6, 7, 1, 2, 4, 5, // + 1, 3, 6, 7, 0, 2, 4, 5, /**/ 0, 1, 3, 6, 7, 2, 4, 5, // + 2, 3, 6, 7, 0, 1, 4, 5, /**/ 0, 2, 3, 6, 7, 1, 4, 5, // + 1, 2, 3, 6, 7, 0, 4, 5, /**/ 0, 1, 2, 3, 6, 7, 4, 5, // + 4, 6, 7, 0, 1, 2, 3, 5, /**/ 0, 4, 6, 7, 1, 2, 3, 5, // + 1, 4, 6, 7, 0, 2, 3, 5, /**/ 0, 1, 4, 6, 7, 2, 3, 5, // + 2, 4, 6, 7, 0, 1, 3, 5, /**/ 0, 2, 4, 6, 7, 1, 3, 5, // + 1, 2, 4, 6, 7, 0, 3, 5, /**/ 0, 1, 2, 4, 6, 7, 3, 5, // + 3, 4, 6, 7, 0, 1, 2, 5, /**/ 0, 3, 4, 6, 7, 1, 2, 5, // + 1, 3, 4, 6, 7, 0, 2, 5, /**/ 0, 1, 3, 4, 6, 7, 2, 5, // + 2, 3, 4, 6, 7, 0, 1, 5, /**/ 0, 2, 3, 4, 6, 7, 1, 5, // + 1, 2, 3, 4, 6, 7, 0, 5, /**/ 0, 1, 2, 3, 4, 6, 7, 5, // + 5, 6, 7, 0, 1, 2, 3, 4, /**/ 0, 5, 6, 7, 1, 2, 3, 4, // + 1, 5, 6, 7, 0, 2, 3, 4, /**/ 0, 1, 5, 6, 7, 2, 3, 4, // + 2, 5, 6, 7, 0, 1, 3, 4, /**/ 0, 2, 5, 6, 7, 1, 3, 4, // + 1, 2, 5, 6, 7, 0, 3, 4, /**/ 0, 1, 2, 5, 6, 7, 3, 4, // + 3, 5, 6, 7, 0, 1, 2, 4, /**/ 0, 3, 5, 6, 7, 1, 2, 4, // + 1, 3, 5, 6, 7, 0, 2, 4, /**/ 0, 1, 3, 5, 6, 7, 2, 4, // + 2, 3, 5, 6, 7, 0, 1, 4, /**/ 0, 2, 3, 5, 6, 7, 1, 4, // + 1, 2, 3, 5, 6, 7, 0, 4, /**/ 0, 1, 2, 3, 5, 6, 7, 4, // + 4, 5, 6, 7, 0, 1, 2, 3, /**/ 0, 4, 5, 6, 7, 1, 2, 3, // + 1, 4, 5, 6, 7, 0, 2, 3, /**/ 0, 1, 4, 5, 6, 7, 2, 3, // + 2, 4, 5, 6, 7, 0, 1, 3, /**/ 0, 2, 4, 5, 6, 7, 1, 3, // + 1, 2, 4, 5, 6, 7, 0, 3, /**/ 0, 1, 2, 4, 5, 6, 7, 3, // + 3, 4, 5, 6, 7, 0, 1, 2, /**/ 0, 3, 4, 5, 6, 7, 1, 2, // + 1, 3, 4, 5, 6, 7, 0, 2, /**/ 0, 1, 3, 4, 5, 6, 7, 2, // + 2, 3, 4, 5, 6, 7, 0, 1, /**/ 0, 2, 3, 4, 5, 6, 7, 1, // + 1, 2, 3, 4, 5, 6, 7, 0, /**/ 0, 1, 2, 3, 4, 5, 6, 7}; + + for (size_t i = 0; i < Lanes(d); i += 8) { + // Each byte worth of bits is the index of one of 256 8-byte ranges, and its + // population count determines how far to advance the write position. + const size_t bits8 = bits[i / 8]; + const auto indices = Load(d8, table + bits8 * 8); + const auto compressed = TableLookupBytes(LoadU(d8, lanes + i), indices); + StoreU(compressed, d8, pos); + pos += PopCount(bits8); + } + return static_cast(pos - unaligned); +} + +template +HWY_API size_t CompressStore(V v, M mask, D d, T* HWY_RESTRICT unaligned) { + uint8_t bits[HWY_MAX(size_t{8}, MaxLanes(d) / 8)]; + (void)StoreMaskBits(d, mask, bits); + return CompressBitsStore(v, bits, d, unaligned); +} + +template +HWY_API size_t CompressBlendedStore(V v, M mask, D d, + T* HWY_RESTRICT unaligned) { + HWY_ALIGN T buf[MaxLanes(d)]; + const size_t bytes = CompressStore(v, mask, d, buf); + BlendedStore(Load(d, buf), FirstN(d, bytes), d, unaligned); + return bytes; +} + +// For reasons unknown, HWY_IF_LANE_SIZE_V is a compile error in SVE. 
+template , HWY_IF_LANE_SIZE(T, 1)> +HWY_API V Compress(V v, const M mask) { + const DFromV d; + HWY_ALIGN T lanes[MaxLanes(d)]; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +template , HWY_IF_LANE_SIZE(T, 1)> +HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + const DFromV d; + HWY_ALIGN T lanes[MaxLanes(d)]; + (void)CompressBitsStore(v, bits, d, lanes); + return Load(d, lanes); +} + +template , HWY_IF_LANE_SIZE(T, 1)> +HWY_API V CompressNot(V v, M mask) { + return Compress(v, Not(mask)); +} + +#endif // HWY_NATIVE_COMPRESS8 + +// ================================================== Operator wrapper + +// These targets currently cannot define operators and have already defined +// (only) the corresponding functions such as Add. +#if HWY_TARGET != HWY_RVV && HWY_TARGET != HWY_SVE && \ + HWY_TARGET != HWY_SVE2 && HWY_TARGET != HWY_SVE_256 && \ + HWY_TARGET != HWY_SVE2_128 + +template +HWY_API V Add(V a, V b) { + return a + b; +} +template +HWY_API V Sub(V a, V b) { + return a - b; +} + +template +HWY_API V Mul(V a, V b) { + return a * b; +} +template +HWY_API V Div(V a, V b) { + return a / b; +} + +template +V Shl(V a, V b) { + return a << b; +} +template +V Shr(V a, V b) { + return a >> b; +} + +template +HWY_API auto Eq(V a, V b) -> decltype(a == b) { + return a == b; +} +template +HWY_API auto Ne(V a, V b) -> decltype(a == b) { + return a != b; +} +template +HWY_API auto Lt(V a, V b) -> decltype(a == b) { + return a < b; +} + +template +HWY_API auto Gt(V a, V b) -> decltype(a == b) { + return a > b; +} +template +HWY_API auto Ge(V a, V b) -> decltype(a == b) { + return a >= b; +} + +template +HWY_API auto Le(V a, V b) -> decltype(a == b) { + return a <= b; +} + +#endif // HWY_TARGET for operators + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/rvv-inl.h b/third_party/highway/hwy/ops/rvv-inl.h new file mode 100644 index 0000000000..502611282c --- /dev/null +++ b/third_party/highway/hwy/ops/rvv-inl.h @@ -0,0 +1,3451 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RISC-V V vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t>::type; + +template +using TFromV = TFromD>; + +// Enables the overload if Pow2 is in [min, max]. +#define HWY_RVV_IF_POW2_IN(D, min, max) \ + hwy::EnableIf<(min) <= Pow2(D()) && Pow2(D()) <= (max)>* = nullptr + +template +constexpr size_t MLenFromD(Simd /* tag */) { + // Returns divisor = type bits / LMUL. Folding *8 into the ScaleByPower + // argument enables fractional LMUL < 1. 
Limit to 64 because that is the + // largest value for which vbool##_t are defined. + return HWY_MIN(64, sizeof(T) * 8 * 8 / detail::ScaleByPower(8, kPow2)); +} + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// For all mask sizes MLEN: (1/Nth of a register, one bit per lane) +// The first two arguments are SEW and SHIFT such that SEW >> SHIFT = MLEN. +#define HWY_RVV_FOREACH_B(X_MACRO, NAME, OP) \ + X_MACRO(64, 0, 64, NAME, OP) \ + X_MACRO(32, 0, 32, NAME, OP) \ + X_MACRO(16, 0, 16, NAME, OP) \ + X_MACRO(8, 0, 8, NAME, OP) \ + X_MACRO(8, 1, 4, NAME, OP) \ + X_MACRO(8, 2, 2, NAME, OP) \ + X_MACRO(8, 3, 1, NAME, OP) + +// For given SEW, iterate over one of LMULS: _TRUNC, _EXT, _ALL. This allows +// reusing type lists such as HWY_RVV_FOREACH_U for _ALL (the usual case) or +// _EXT (for Combine). To achieve this, we HWY_CONCAT with the LMULS suffix. +// +// Precompute SEW/LMUL => MLEN to allow token-pasting the result. For the same +// reason, also pass the double-width and half SEW and LMUL (suffixed D and H, +// respectively). "__" means there is no corresponding LMUL (e.g. LMULD for m8). +// Args: BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP + +// LMULS = _TRUNC: truncatable (not the smallest LMUL) +#define HWY_RVV_FOREACH_08_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH. 
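+// For example (an illustrative sketch, not a definition from this header):
+// the SEW=16, LMUL=m1 row below has SEWH=8 and LMULH=mf2, so a demoting op
+// maps vuint16m1_t to vuint8mf2_t, i.e. same lane count, half the element
+// width. In user code this corresponds to something like:
+//
+//   const ScalableTag<uint16_t> d16;          // SEW=16, LMUL=m1
+//   const Rebind<uint8_t, decltype(d16)> d8;  // same lane count => LMUL=mf2
+//   const auto narrow = DemoteTo(d8, Set(d16, uint16_t{300}));  // 300 -> 255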
+#define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _LE2: <= 2 +#define HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) + +// LMULS = _EXT: not the largest LMUL +#define HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, 
m2, 2, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) + +// LMULS = _ALL (2^MinPow2() <= LMUL <= 8) +#define HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// 'Virtual' LMUL. This upholds the Highway guarantee that vectors are at least +// 128 bit and LowerHalf is defined whenever there are at least 2 lanes, even +// though RISC-V LMUL must be at least SEW/64 (notice that this rules out +// LMUL=1/2 for SEW=64). To bridge the gap, we add overloads for kPow2 equal to +// one less than should be supported, with all other parameters (vector type +// etc.) unchanged. For D with the lowest kPow2 ('virtual LMUL'), Lanes() +// returns half of what it usually would. +// +// Notice that we can only add overloads whenever there is a D argument: those +// are unique with respect to non-virtual-LMUL overloads because their kPow2 +// template argument differs. Otherwise, there is no actual vuint64mf2_t, and +// defining another overload with the same LMUL would be an error. Thus we have +// a separate _VIRT category for HWY_RVV_FOREACH*, and the common case is +// _ALL_VIRT (meaning the regular LMUL plus the VIRT overloads), used in most +// functions that take a D. 
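+//
+// For example (an illustrative sketch of the effect on user code, assuming a
+// tag with the lowest supported kPow2 such as ScalableTag<uint64_t, -1>):
+// RVV has no vuint64mf2_t, so the _VIRT overloads below reuse the m1 vector
+// type and only halve the lane count reported by Lanes():
+//
+//   const ScalableTag<uint64_t, 0> d_m1;     // real LMUL=1 -> vuint64m1_t
+//   const ScalableTag<uint64_t, -1> d_half;  // virtual LMUL=1/2, also m1
+//   HWY_ASSERT(Lanes(d_half) == Lanes(d_m1) / 2);
+//   const auto v = Set(d_half, uint64_t{1});  // ops dispatch as usual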
+ +#define HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -3, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -2, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, -1, /*MLEN=*/64, NAME, OP) + +// ALL + VIRT +#define HWY_RVV_FOREACH_08_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// LE2 + VIRT +#define HWY_RVV_FOREACH_08_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// EXT + VIRT +#define HWY_RVV_FOREACH_08_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// DEMOTE + VIRT +#define HWY_RVV_FOREACH_08_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, 
BASE, CHAR, NAME, OP) + +// SEW for unsigned: +#define HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, uint, u, NAME, OP) + +// SEW for signed: +#define HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, int, i, NAME, OP) + +// SEW for float: +#if HWY_HAVE_FLOAT16 +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, float, f, NAME, OP) +#else +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) +#endif +#define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, float, f, NAME, OP) +#define HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, float, f, NAME, OP) + +// Commonly used type/SEW groups: +#define HWY_RVV_FOREACH_UI08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) + +// For all combinations of SEW: +#define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + 
HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) + +// Commonly used type categories: +#define HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) + +// Assemble types for use in x-macros +#define HWY_RVV_T(BASE, SEW) BASE##SEW##_t +#define HWY_RVV_D(BASE, SEW, N, SHIFT) Simd +#define HWY_RVV_V(BASE, SEW, LMUL) v##BASE##SEW##LMUL##_t +#define HWY_RVV_M(MLEN) vbool##MLEN##_t + +} // namespace detail + +// Until we have full intrinsic support for fractional LMUL, mixed-precision +// code can use LMUL 1..8 (adequate unless they need many registers). +#define HWY_SPECIALIZE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using Lane = HWY_RVV_T(BASE, SEW); \ + using type = ScalableTag; \ + }; + +HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _, _ALL) +#undef HWY_SPECIALIZE + +// ------------------------------ Lanes + +// WARNING: we want to query VLMAX/sizeof(T), but this actually changes VL! +// vlenb is not exposed through intrinsics and vreadvl is not VLMAX. +#define HWY_RVV_LANES(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + size_t actual = v##OP##SEW##LMUL(); \ + /* Common case of full vectors: avoid any extra instructions. */ \ + /* actual includes LMUL, so do not shift again. */ \ + if (detail::IsFull(d)) return actual; \ + /* Check for virtual LMUL, e.g. "uint16mf8_t" (not provided by */ \ + /* intrinsics). In this case the actual LMUL is 1/4, so divide by */ \ + /* another factor of two. */ \ + if (detail::ScaleByPower(128 / SEW, SHIFT) == 1) actual >>= 1; \ + return HWY_MIN(actual, N); \ + } + +HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e, _ALL_VIRT) +#undef HWY_RVV_LANES + +template +HWY_API size_t Lanes(Simd /* tag*/) { + return Lanes(Simd()); +} + +// ------------------------------ Common x-macros + +// Last argument to most intrinsics. Use when the op has no d arg of its own, +// which means there is no user-specified cap. +#define HWY_RVV_AVL(SEW, SHIFT) \ + Lanes(ScalableTag()) + +// vector = f(vector), e.g. Not +#define HWY_RVV_RETV_ARGV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, scalar), e.g. detail::AddS +#define HWY_RVV_RETV_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return v##OP##_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, vector), e.g. 
Add +#define HWY_RVV_RETV_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(mask) +#define HWY_RVV_RETM_ARGM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) m) { \ + return vm##OP##_m_b##MLEN(m, ~0ull); \ + } + +// ================================================== INIT + +// ------------------------------ Set + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_T(BASE, SEW) arg) { \ + return v##OP##_##CHAR##SEW##LMUL(arg, Lanes(d)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SET, Set, mv_v_x, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f, _ALL_VIRT) +#undef HWY_RVV_SET + +// Treat bfloat16_t as uint16_t (using the previously defined Set overloads); +// required for Zero and VFromD. +template +decltype(Set(Simd(), 0)) Set(Simd d, + bfloat16_t arg) { + return Set(RebindToUnsigned(), arg.bits); +} + +template +using VFromD = decltype(Set(D(), TFromD())); + +// ------------------------------ Zero + +template +HWY_API VFromD Zero(D d) { + // Cast to support bfloat16_t. + const RebindToUnsigned du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ Undefined + +// RVV vundefined is 'poisoned' such that even XORing a _variable_ initialized +// by it gives unpredictable results. It should only be used for maskoff, so +// keep it internal. For the Highway op, just use Zero (single instruction). +namespace detail { +#define HWY_RVV_UNDEFINED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) /* tag */) { \ + return v##OP##_##CHAR##SEW##LMUL(); /* no AVL */ \ + } + +HWY_RVV_FOREACH(HWY_RVV_UNDEFINED, Undefined, undefined, _ALL) +#undef HWY_RVV_UNDEFINED +} // namespace detail + +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +// ------------------------------ BitCast + +namespace detail { + +// Halves LMUL. (Use LMUL arg for the source so we can use _TRUNC.) +#define HWY_RVV_TRUNC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH(v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +#undef HWY_RVV_TRUNC + +// Doubles LMUL to `d2` (the arg is only necessary for _VIRT). +#define HWY_RVV_EXT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULD) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULD(v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT, Ext, lmul_ext, _EXT) +#undef HWY_RVV_EXT + +// For virtual LMUL e.g. 'uint32mf4_t', the return type should be mf2, which is +// the same as the actual input type. 
+#define HWY_RVV_EXT_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v; \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT_VIRT, Ext, lmul_ext, _VIRT) +#undef HWY_RVV_EXT_VIRT + +// For BitCastToByte, the D arg is only to prevent duplicate definitions caused +// by _ALL_VIRT. + +// There is no reinterpret from u8 <-> u8, so just return. +#define HWY_RVV_CAST_U8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + vuint8##LMUL##_t v) { \ + return v; \ + } \ + template \ + HWY_API vuint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v; \ + } + +// For i8, need a single reinterpret (HWY_RVV_CAST_IF does two). +#define HWY_RVV_CAST_I8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + vint8##LMUL##_t v) { \ + return vreinterpret_v_i8##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API vint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return vreinterpret_v_u8##LMUL##_i8##LMUL(v); \ + } + +// Separate u/i because clang only provides signed <-> unsigned reinterpret for +// the same SEW. +#define HWY_RVV_CAST_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + v##OP##_v_u8##LMUL##_u##SEW##LMUL(v)); \ + } + +// Additional versions for virtual LMUL using LMULH for byte vectors. 
+#define HWY_RVV_CAST_VIRT_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v2); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_VIRT_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v))); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + v##OP##_v_u8##LMUL##_u##SEW##LMUL(v2)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_CAST_U8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I08(HWY_RVV_CAST_I8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_U, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_VIRT_U, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) + +#undef HWY_RVV_CAST_U8 +#undef HWY_RVV_CAST_I8 +#undef HWY_RVV_CAST_U +#undef HWY_RVV_CAST_IF +#undef HWY_RVV_CAST_VIRT_U +#undef HWY_RVV_CAST_VIRT_IF + +template +HWY_INLINE VFromD> BitCastFromByte( + Simd /* d */, VFromD> v) { + return BitCastFromByte(Simd(), v); +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(d, v)); +} + +namespace detail { + +template >> +HWY_INLINE VFromD BitCastToUnsigned(V v) { + return BitCast(DU(), v); +} + +} // namespace detail + +// ------------------------------ Iota + +namespace detail { + +#define HWY_RVV_IOTA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + return v##OP##_##CHAR##SEW##LMUL(Lanes(d)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_IOTA, Iota0, id_v, _ALL_VIRT) +#undef HWY_RVV_IOTA + +template > +HWY_INLINE VFromD Iota0(const D /*d*/) { + return BitCastToUnsigned(Iota0(DU())); +} + +} // namespace detail + +// ================================================== LOGICAL + +// ------------------------------ Not + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGV, Not, not, _ALL) + +template +HWY_API V Not(const V v) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Not(BitCast(DU(), v))); +} + +// ------------------------------ And + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AndS, and_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, And, and, _ALL) + +template +HWY_API V And(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return 
BitCast(DF(), And(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Or + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Or, or, _ALL) + +template +HWY_API V Or(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Or(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Xor + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, XorS, xor_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Xor, xor, _ALL) + +template +HWY_API V Xor(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Xor(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ AndNot +template +HWY_API V AndNot(const V not_a, const V b) { + return And(Not(not_a), b); +} + +// ------------------------------ Xor3 +template +HWY_API V Xor3(V x1, V x2, V x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 +template +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ CopySign + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, CopySign, fsgnj, _ALL) + +template +HWY_API V CopySignToAbs(const V abs, const V sign) { + // RVV can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Add + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AddS, add_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, AddS, fadd_vf, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, ReverseSubS, rsub_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, ReverseSubS, frsub_vf, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Add, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Add, fadd, _ALL) + +// ------------------------------ Sub +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub, _ALL) + +// ------------------------------ SaturatedAdd + +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) + +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) + +// ------------------------------ SaturatedSub + +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) + +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) + +// ------------------------------ AverageRound + +// TODO(janwas): check vxrm rounding mode +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, AverageRound, aaddu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, AverageRound, aaddu, _ALL) + +// ------------------------------ ShiftLeft[Same] + +// Intrinsics do not define .vi forms, so use .vx instead. 
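+// The macro below therefore emits both the compile-time-count form and the
+// runtime-count ("Same") form from the same .vx intrinsic. A usage sketch
+// (illustrative only):
+//
+//   const auto x4 = ShiftLeft<2>(v);         // count known at compile time
+//   const auto xn = ShiftLeftSame(v, bits);  // count in a runtime variable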
+#define HWY_RVV_SHIFT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(v, kBits, HWY_RVV_AVL(SEW, SHIFT)); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(v, static_cast(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SHIFT, ShiftLeft, sll, _ALL) + +// ------------------------------ ShiftRight[Same] + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT, ShiftRight, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra, _ALL) + +#undef HWY_RVV_SHIFT + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template +HWY_API VFromD>> SumsOf8(const VU8 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = detail::AndS(BitCast(du16, v), 0xFF); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return detail::AndS(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), 0xFFFFull); +} + +// ------------------------------ RotateRight +template +HWY_API V RotateRight(const V v) { + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shl +#define HWY_RVV_SHIFT_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, bits, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shl, sll, _ALL) + +#define HWY_RVV_SHIFT_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, detail::BitCastToUnsigned(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shl, sll, _ALL) + +// ------------------------------ Shr + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shr, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra, _ALL) + +#undef HWY_RVV_SHIFT_II +#undef HWY_RVV_SHIFT_VV + +// ------------------------------ Min + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Min, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Min, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Min, fmin, _ALL) + +// ------------------------------ Max + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MaxS, maxu_vx, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MaxS, max_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MaxS, fmax_vf, _ALL) + +} // namespace detail + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Max, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Max, max, _ALL) 
+HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Max, fmax, _ALL) + +// ------------------------------ Mul + +HWY_RVV_FOREACH_UI163264(HWY_RVV_RETV_ARGVV, Mul, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul, _ALL) + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// ------------------------------ MulHigh + +// Only for internal use (Highway only promises MulHigh for 16-bit inputs). +// Used by MulEven; vwmul does not work for m8. +namespace detail { +HWY_RVV_FOREACH_I32(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) +HWY_RVV_FOREACH_U32(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +HWY_RVV_FOREACH_U64(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) + +// ------------------------------ MulFixedPoint15 +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulFixedPoint15, smul, _ALL) + +// ------------------------------ Div +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv, _ALL) + +// ------------------------------ ApproximateReciprocal +HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocal, frec7, _ALL) + +// ------------------------------ Sqrt +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, Sqrt, fsqrt, _ALL) + +// ------------------------------ ApproximateReciprocalSqrt +HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocalSqrt, frsqrt7, _ALL) + +// ------------------------------ MulAdd +// Note: op is still named vv, not vvv. +#define HWY_RVV_FMA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) mul, HWY_RVV_V(BASE, SEW, LMUL) x, \ + HWY_RVV_V(BASE, SEW, LMUL) add) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(add, mul, x, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulAdd, fmacc, _ALL) + +// ------------------------------ NegMulAdd +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulAdd, fnmsac, _ALL) + +// ------------------------------ MulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulSub, fmsac, _ALL) + +// ------------------------------ NegMulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL) + +#undef HWY_RVV_FMA + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. The XX in +// vboolXX_t is a power of two divisor for vector bits. SLEN 8 / LMUL 1 = 1/8th +// of all bits; SLEN 8 / LMUL 4 = half of all bits. 
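+//
+// For example (an illustrative sketch): with ScalableTag<int32_t> (SEW=32,
+// LMUL=1), comparisons return vbool32_t, i.e. one mask bit per 32-bit lane.
+// Such masks then drive IfThenElse/IfThenElseZero, defined further below:
+//
+//   const ScalableTag<int32_t> d;
+//   const auto v = Iota(d, 0);                 // 0, 1, 2, ...
+//   const auto m = Lt(v, Set(d, int32_t{4}));  // true in lanes 0..3
+//   const auto clamped = IfThenElse(m, v, Set(d, int32_t{4}));  // min(v, 4)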
+ +// mask = f(vector, vector) +#define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN(a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(vector, scalar) +#define HWY_RVV_RETM_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return v##OP##_##CHAR##SEW##LMUL##_b##MLEN(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// ------------------------------ Eq +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, EqS, mseq_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, EqS, mfeq_vf, _ALL) +} // namespace detail + +// ------------------------------ Ne +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, NeS, msne_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, NeS, mfne_vf, _ALL) +} // namespace detail + +// ------------------------------ Lt +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Lt, msltu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt, _ALL) + +namespace detail { +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, LtS, mslt_vx, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVS, LtS, msltu_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, LtS, mflt_vf, _ALL) +} // namespace detail + +// ------------------------------ Le +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle, _ALL) + +#undef HWY_RVV_RETM_ARGVV +#undef HWY_RVV_RETM_ARGVS + +// ------------------------------ Gt/Ge + +template +HWY_API auto Ge(const V a, const V b) -> decltype(Le(a, b)) { + return Le(b, a); +} + +template +HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) { + return Lt(b, a); +} + +// ------------------------------ TestBit +template +HWY_API auto TestBit(const V a, const V bit) -> decltype(Eq(a, bit)) { + return detail::NeS(And(a, bit), 0); +} + +// ------------------------------ Not +// NOLINTNEXTLINE +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, Not, not ) + +// ------------------------------ And + +// mask = f(mask_a, mask_b) (note arg2,arg1 order!) 
+#define HWY_RVV_RETM_ARGMM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) a, HWY_RVV_M(MLEN) b) { \ + return vm##OP##_mm_b##MLEN(b, a, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, And, and) + +// ------------------------------ AndNot +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, AndNot, andn) + +// ------------------------------ Or +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Or, or) + +// ------------------------------ Xor +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Xor, xor) + +// ------------------------------ ExclusiveNeither +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, ExclusiveNeither, xnor) + +#undef HWY_RVV_RETM_ARGMM + +// ------------------------------ IfThenElse +#define HWY_RVV_IF_THEN_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) yes, \ + HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return v##OP##_vvm_##CHAR##SEW##LMUL(no, yes, m, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_IF_THEN_ELSE, IfThenElse, merge, _ALL) + +#undef HWY_RVV_IF_THEN_ELSE + +// ------------------------------ IfThenElseZero +template +HWY_API V IfThenElseZero(const M mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV())); +} + +// ------------------------------ IfThenZeroElse + +#define HWY_RVV_IF_THEN_ZERO_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return v##OP##_##CHAR##SEW##LMUL(no, 0, m, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, merge_vxm, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, fmerge_vfm, _ALL) + +#undef HWY_RVV_IF_THEN_ZERO_ELSE + +// ------------------------------ MaskFromVec + +template +HWY_API auto MaskFromVec(const V v) -> decltype(Eq(v, v)) { + return detail::NeS(v, 0); +} + +template +using MFromD = decltype(MaskFromVec(Zero(D()))); + +template +HWY_API MFromD RebindMask(const D /*d*/, const MFrom mask) { + // No need to check lane size/LMUL are the same: if not, casting MFrom to + // MFromD would fail. 
+ return mask; +} + +// ------------------------------ VecFromMask + +namespace detail { +#define HWY_RVV_VEC_FROM_MASK(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_M(MLEN) m) { \ + return v##OP##_##CHAR##SEW##LMUL##_m(m, v0, v0, 1, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_VEC_FROM_MASK, SubS, sub_vx, _ALL) +#undef HWY_RVV_VEC_FROM_MASK +} // namespace detail + +template +HWY_API VFromD VecFromMask(const D d, MFromD mask) { + return detail::SubS(Zero(d), mask); +} + +template +HWY_API VFromD VecFromMask(const D d, MFromD mask) { + return BitCast(d, VecFromMask(RebindToUnsigned(), mask)); +} + +// ------------------------------ IfVecThenElse (MaskFromVec) + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ ZeroIfNegative +template +HWY_API V ZeroIfNegative(const V v) { + return IfThenZeroElse(detail::LtS(v, 0), v); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + MFromD m = + MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +// ------------------------------ FindFirstTrue + +#define HWY_RVV_FIND_FIRST_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API intptr_t FindFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return vfirst_m_b##MLEN(m, Lanes(d)); \ + } \ + template \ + HWY_API size_t FindKnownFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return static_cast(vfirst_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_FIND_FIRST_TRUE, , _) +#undef HWY_RVV_FIND_FIRST_TRUE + +// ------------------------------ AllFalse +template +HWY_API bool AllFalse(D d, MFromD m) { + return FindFirstTrue(d, m) < 0; +} + +// ------------------------------ AllTrue + +#define HWY_RVV_ALL_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API bool AllTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return AllFalse(d, vmnot_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_ALL_TRUE, _, _) +#undef HWY_RVV_ALL_TRUE + +// ------------------------------ CountTrue + +#define HWY_RVV_COUNT_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t CountTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return vcpop_m_b##MLEN(m, Lanes(d)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) +#undef HWY_RVV_COUNT_TRUE + +// ================================================== MEMORY + +// ------------------------------ Load + +#define HWY_RVV_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le, _ALL_VIRT) +#undef HWY_RVV_LOAD + +// There is no native BF16, treat as uint16_t. 
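Background for the bfloat16 overloads that follow: a bf16 value is simply the upper 16 bits of an IEEE binary32, which is why it can be loaded and stored through the u16 path here and widened later by a 16-bit left shift (see the bf16 PromoteTo/DemoteTo further below). A scalar sketch of that bit-level relationship (standard C++; truncating conversion for simplicity, whereas an actual demotion may round):

#include <cassert>
#include <cstdint>
#include <cstring>

uint16_t BF16FromF32(float f) {  // keep the upper 16 bits
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

float F32FromBF16(uint16_t b) {  // place them back in the upper half
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  assert(F32FromBF16(BF16FromF32(1.5f)) == 1.5f);  // 1.5 is exact in bf16
  return 0;
}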
+template +HWY_API VFromD> Load( + Simd d, const bfloat16_t* HWY_RESTRICT p) { + return Load(RebindToUnsigned(), + reinterpret_cast(p)); +} + +template +HWY_API void Store(VFromD> v, + Simd d, bfloat16_t* HWY_RESTRICT p) { + Store(v, RebindToUnsigned(), + reinterpret_cast(p)); +} + +// ------------------------------ LoadU + +// RVV only requires lane alignment, not natural alignment of the entire vector. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ MaskedLoad + +#define HWY_RVV_MASKED_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, Zero(d), p, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le, _ALL_VIRT) +#undef HWY_RVV_MASKED_LOAD + +// ------------------------------ Store + +#define HWY_RVV_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_STORE, Store, se, _ALL_VIRT) +#undef HWY_RVV_STORE + +// ------------------------------ BlendedStore + +#define HWY_RVV_BLENDED_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, p, v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_BLENDED_STORE, BlendedStore, se, _ALL_VIRT) +#undef HWY_RVV_BLENDED_STORE + +namespace detail { + +#define HWY_RVV_STOREN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(size_t count, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, count); \ + } +HWY_RVV_FOREACH(HWY_RVV_STOREN, StoreN, se, _ALL_VIRT) +#undef HWY_RVV_STOREN + +} // namespace detail + +// ------------------------------ StoreU + +// RVV only requires lane alignment, not natural alignment of the entire vector. 
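A usage sketch for the load/store ops in this section, in Highway's static-dispatch style (assumptions: "hwy/highway.h" is on the include path, and AddOne plus the hn alias are user code, not part of this header). Because RVV only needs lane alignment, Load/LoadU and Store/StoreU (defined next) are interchangeable here; MaskedLoad zeroes masked-off lanes, which pairs with BlendedStore to handle the loop remainder without a scalar tail:

#include <cstddef>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// out[i] = in[i] + 1 for i < n.
void AddOne(const float* HWY_RESTRICT in, float* HWY_RESTRICT out, size_t n) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  size_t i = 0;
  for (; i + N <= n; i += N) {
    hn::StoreU(hn::Add(hn::LoadU(d, in + i), hn::Set(d, 1.0f)), d, out + i);
  }
  if (i < n) {
    const auto m = hn::FirstN(d, n - i);          // remaining lanes only
    const auto v = hn::MaskedLoad(m, d, in + i);  // zero where !m
    hn::BlendedStore(hn::Add(v, hn::Set(d, 1.0f)), m, d, out + i);
  }
}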
+template +HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ Stream +template +HWY_API void Stream(const V v, D d, T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ ScatterOffset + +#define HWY_RVV_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + base, detail::BitCastToUnsigned(offset), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sux, _ALL_VIRT) +#undef HWY_RVV_SCATTER + +// ------------------------------ ScatterIndex + +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + const VFromD> index) { + return ScatterOffset(v, d, base, ShiftLeft<2>(index)); +} + +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + const VFromD> index) { + return ScatterOffset(v, d, base, ShiftLeft<3>(index)); +} + +// ------------------------------ GatherOffset + +#define HWY_RVV_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + base, detail::BitCastToUnsigned(offset), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lux, _ALL_VIRT) +#undef HWY_RVV_GATHER + +// ------------------------------ GatherIndex + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, + const VFromD> index) { + return GatherOffset(d, base, ShiftLeft<2>(index)); +} + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, + const VFromD> index) { + return GatherOffset(d, base, ShiftLeft<3>(index)); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +#define HWY_RVV_LOAD2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, unaligned, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD2, LoadInterleaved2, lseg2, _LE2_VIRT) +#undef HWY_RVV_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_RVV_LOAD3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, &v2, unaligned, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. 
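The "limited to 8 registers" comments refer to the RVV rule that a segment load/store occupies NFIELDS x LMUL vector registers and that this product must not exceed 8, which is why these wrappers stop at LMUL=2 (already the maximum for 3 or 4 fields). A standalone sketch of the budget check (plain C++; SegmentFits is an illustrative name):

#include <cassert>

constexpr bool SegmentFits(int nfields, int lmul) {
  return nfields * lmul <= 8;  // RVV: EMUL * NFIELDS <= 8
}

static_assert(SegmentFits(4, 2), "LoadInterleaved4 at LMUL=2 uses exactly 8");
static_assert(!SegmentFits(3, 4), "3 fields at LMUL=4 would need 12");
static_assert(SegmentFits(2, 2), "2 fields at LMUL=2 is well within budget");

int main() { return 0; }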
+HWY_RVV_FOREACH(HWY_RVV_LOAD3, LoadInterleaved3, lseg3, _LE2_VIRT) +#undef HWY_RVV_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_RVV_LOAD4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT aligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2, HWY_RVV_V(BASE, SEW, LMUL) & v3) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, &v2, &v3, aligned, \ + Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD4, LoadInterleaved4, lseg4, _LE2_VIRT) +#undef HWY_RVV_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_RVV_STORE2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, \ + HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(unaligned, v0, v1, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE2, StoreInterleaved2, sseg2, _LE2_VIRT) +#undef HWY_RVV_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_RVV_STORE3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(unaligned, v0, v1, v2, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE3, StoreInterleaved3, sseg3, _LE2_VIRT) +#undef HWY_RVV_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_RVV_STORE4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT aligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(aligned, v0, v1, v2, v3, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE4, StoreInterleaved4, sseg4, _LE2_VIRT) +#undef HWY_RVV_STORE4 + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// SEW is for the input so we can use F16 (no-op if not supported). 
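The promotions defined next double the lane width: vzext zero-extends unsigned sources, vsext sign-extends signed ones, and (as implemented a bit further below) promoting an unsigned source to a wider signed type also zero-extends. A scalar model of those per-lane semantics (plain C++):

#include <cassert>
#include <cstdint>

// Per-lane behaviour of the integer PromoteTo variants.
uint16_t PromoteU8ToU16(uint8_t x) { return x; }                       // zero-extend
int16_t PromoteI8ToI16(int8_t x) { return x; }                         // sign-extend
int16_t PromoteU8ToI16(uint8_t x) { return static_cast<int16_t>(x); }  // zero-extend
uint32_t PromoteU8ToU32(uint8_t x) { return x; }  // the 4x (vzext_vf4) case

int main() {
  assert(PromoteU8ToU16(0xFFu) == 0x00FF);
  assert(PromoteI8ToI16(int8_t{-1}) == -1);  // 0xFF becomes 0xFFFF
  assert(PromoteU8ToI16(0xFFu) == 255);      // stays positive
  assert(PromoteU8ToU32(0xFFu) == 255u);
  return 0;
}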
+#define HWY_RVV_PROMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWD##LMULD(v, Lanes(d)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U16(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I08(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_F16(HWY_RVV_PROMOTE, PromoteTo, vfwcvt_f_f_v_, _EXT_VIRT) +HWY_RVV_FOREACH_F32(HWY_RVV_PROMOTE, PromoteTo, vfwcvt_f_f_v_, _EXT_VIRT) +#undef HWY_RVV_PROMOTE + +// The above X-macro cannot handle 4x promotion nor type switching. +// TODO(janwas): use BASE2 arg to allow the latter. +#define HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, LMUL, LMUL_IN, \ + SHIFT, ADD) \ + template \ + HWY_API HWY_RVV_V(BASE, BITS, LMUL) \ + PromoteTo(HWY_RVV_D(BASE, BITS, N, SHIFT + ADD) d, \ + HWY_RVV_V(BASE_IN, BITS_IN, LMUL_IN) v) { \ + return OP##CHAR##BITS##LMUL(v, Lanes(d)); \ + } + +#define HWY_RVV_PROMOTE_X2(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -2, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, m1, 0, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m2, 1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m4, 2, 1) + +#define HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, mf2, mf8, -3, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf4, -2, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf2, -1, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m1, 0, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m2, 1, 2) + +HWY_RVV_PROMOTE_X4(vzext_vf4_, uint, u, 32, uint, 8) +HWY_RVV_PROMOTE_X4(vsext_vf4_, int, i, 32, int, 8) + +// i32 to f64 +HWY_RVV_PROMOTE_X2(vfwcvt_f_x_v_, float, f, 64, int, 32) + +#undef HWY_RVV_PROMOTE_X4 +#undef HWY_RVV_PROMOTE_X2 +#undef HWY_RVV_PROMOTE + +// Unsigned to signed: cast for unsigned promote. +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + const RebindToSigned di32; + const Rebind du16; + return BitCast(d, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ DemoteTo U + +// SEW is for the source so we can use _DEMOTE. 
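The demotions that follow halve the lane width with saturation: vnclipu clamps to the unsigned destination range, and the signed-to-unsigned path first clamps negative inputs to zero (matching x86 packus, as noted below). A scalar model of one case (plain C++):

#include <cassert>
#include <cstdint>

// Per-lane model of DemoteTo(u8 from i16): clamp to [0, 255].
uint8_t DemoteI16ToU8(int16_t x) {
  if (x < 0) return 0;      // negatives clamp to zero
  if (x > 255) return 255;  // saturate, do not wrap
  return static_cast<uint8_t>(x);
}

int main() {
  assert(DemoteI16ToU8(-5) == 0);
  assert(DemoteI16ToU8(300) == 255);
  assert(DemoteI16ToU8(42) == 42);
  return 0;
}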
+#define HWY_RVV_DEMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWH##LMULH(v, 0, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME##Shr16( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWH##LMULH(v, 16, Lanes(d)); \ + } + +// Unsigned -> unsigned (also used for bf16) +namespace detail { +HWY_RVV_FOREACH_U16(HWY_RVV_DEMOTE, DemoteTo, vnclipu_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE, DemoteTo, vnclipu_wx_, _DEMOTE_VIRT) +} // namespace detail + +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE_I_TO_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(uint, SEWH, LMULH) NAME( \ + HWY_RVV_D(uint, SEWH, N, SHIFT - 1) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + /* First clamp negative numbers to zero to match x86 packus. */ \ + return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); \ + } +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +#undef HWY_RVV_DEMOTE_I_TO_U + +template +HWY_API vuint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return vnclipu_wx_u8mf8(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return vnclipu_wx_u8mf4(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return vnclipu_wx_u8mf2(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return vnclipu_wx_u8m1(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return vnclipu_wx_u8m2(DemoteTo(Simd(), v), 0, Lanes(d)); +} + +HWY_API vuint8mf8_t U8FromU32(const vuint32mf2_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf8(vnclipu_wx_u16mf4(v, 0, avl), 0, avl); +} +HWY_API vuint8mf4_t U8FromU32(const vuint32m1_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf4(vnclipu_wx_u16mf2(v, 0, avl), 0, avl); +} +HWY_API vuint8mf2_t U8FromU32(const vuint32m2_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf2(vnclipu_wx_u16m1(v, 0, avl), 0, avl); +} +HWY_API vuint8m1_t U8FromU32(const vuint32m4_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8m1(vnclipu_wx_u16m2(v, 0, avl), 0, avl); +} +HWY_API vuint8m2_t U8FromU32(const vuint32m8_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8m2(vnclipu_wx_u16m4(v, 0, avl), 0, avl); +} + +// ------------------------------ Truncations + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFF, avl); + const vuint32mf2_t v2 = vnclipu_wx_u32mf2(v1, 0, avl); + const vuint16mf4_t v3 = vnclipu_wx_u16mf4(v2, 0, avl); + return vnclipu_wx_u8mf8(v3, 0, avl); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFF, avl); + const vuint32m1_t v2 = vnclipu_wx_u32m1(v1, 0, avl); + const vuint16mf2_t v3 = vnclipu_wx_u16mf2(v2, 0, avl); + return vnclipu_wx_u8mf4(v3, 0, avl); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const 
VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFF, avl); + const vuint32m2_t v2 = vnclipu_wx_u32m2(v1, 0, avl); + const vuint16m1_t v3 = vnclipu_wx_u16m1(v2, 0, avl); + return vnclipu_wx_u8mf2(v3, 0, avl); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFF, avl); + const vuint32m4_t v2 = vnclipu_wx_u32m4(v1, 0, avl); + const vuint16m2_t v3 = vnclipu_wx_u16m2(v2, 0, avl); + return vnclipu_wx_u8m1(v3, 0, avl); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFFFF, avl); + const vuint32mf2_t v2 = vnclipu_wx_u32mf2(v1, 0, avl); + return vnclipu_wx_u16mf4(v2, 0, avl); +} + +template +HWY_API vuint16mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFFFF, avl); + const vuint32m1_t v2 = vnclipu_wx_u32m1(v1, 0, avl); + return vnclipu_wx_u16mf2(v2, 0, avl); +} + +template +HWY_API vuint16m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFFFF, avl); + const vuint32m2_t v2 = vnclipu_wx_u32m2(v1, 0, avl); + return vnclipu_wx_u16m1(v2, 0, avl); +} + +template +HWY_API vuint16m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFFFF, avl); + const vuint32m4_t v2 = vnclipu_wx_u32m4(v1, 0, avl); + return vnclipu_wx_u16m2(v2, 0, avl); +} + +template +HWY_API vuint32mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32mf2(v1, 0, avl); +} + +template +HWY_API vuint32m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m1(v1, 0, avl); +} + +template +HWY_API vuint32m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m2(v1, 0, avl); +} + +template +HWY_API vuint32m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m4(v1, 0, avl); +} + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = vand(v, 0xFF, avl); + const vuint16mf4_t v2 = vnclipu_wx_u16mf4(v1, 0, avl); + return vnclipu_wx_u8mf8(v2, 0, avl); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = vand(v, 0xFF, avl); + const vuint16mf2_t v2 = vnclipu_wx_u16mf2(v1, 0, avl); + return vnclipu_wx_u8mf4(v2, 0, avl); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = vand(v, 0xFF, avl); + const vuint16m1_t v2 = vnclipu_wx_u16m1(v1, 0, avl); + return vnclipu_wx_u8mf2(v2, 0, avl); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = vand(v, 0xFF, avl); + const vuint16m2_t v2 = vnclipu_wx_u16m2(v1, 0, avl); + return vnclipu_wx_u8m1(v2, 0, avl); +} + +template +HWY_API vuint8m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = vand(v, 0xFF, avl); + const vuint16m4_t v2 = vnclipu_wx_u16m4(v1, 0, avl); + return vnclipu_wx_u8m2(v2, 0, 
avl); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16mf4(v1, 0, avl); +} + +template +HWY_API vuint16mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16mf2(v1, 0, avl); +} + +template +HWY_API vuint16m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m1(v1, 0, avl); +} + +template +HWY_API vuint16m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m2(v1, 0, avl); +} + +template +HWY_API vuint16m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m4(v1, 0, avl); +} + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16mf4_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf8(v1, 0, avl); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16mf2_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf4(v1, 0, avl); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m1_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf2(v1, 0, avl); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m2_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m1(v1, 0, avl); +} + +template +HWY_API vuint8m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m4_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m2(v1, 0, avl); +} + +template +HWY_API vuint8m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m8_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m4(v1, 0, avl); +} + +// ------------------------------ DemoteTo I + +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE, DemoteTo, vnclip_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE, DemoteTo, vnclip_wx_, _DEMOTE_VIRT) + +template +HWY_API vint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} + +#undef HWY_RVV_DEMOTE + +// ------------------------------ DemoteTo F + +// SEW is for the source so we can use _DEMOTE. 
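In contrast to the saturating DemoteTo above, the TruncateTo overloads keep only the low bits of each lane (note the vand with 0xFF / 0xFFFF / 0xFFFFFFFF before narrowing), so out-of-range values wrap instead of clamping. A scalar comparison of the two behaviours (plain C++):

#include <cassert>
#include <cstdint>

uint8_t TruncateU16ToU8(uint16_t x) { return static_cast<uint8_t>(x & 0xFF); }

int8_t DemoteI16ToI8(int16_t x) {  // signed demotion saturates (vnclip)
  if (x < -128) return -128;
  if (x > 127) return 127;
  return static_cast<int8_t>(x);
}

int main() {
  assert(TruncateU16ToU8(0x1234) == 0x34);  // keeps the low byte
  assert(DemoteI16ToI8(0x1234) == 127);     // clamps to the i8 maximum
  return 0;
}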
+#define HWY_RVV_DEMOTE_F(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##SEWH##LMULH(v, Lanes(d)); \ + } + +#if HWY_HAVE_FLOAT16 +HWY_RVV_FOREACH_F32(HWY_RVV_DEMOTE_F, DemoteTo, vfncvt_rod_f_f_w_f, + _DEMOTE_VIRT) +#endif +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteTo, vfncvt_rod_f_f_w_f, + _DEMOTE_VIRT) +#undef HWY_RVV_DEMOTE_F + +// TODO(janwas): add BASE2 arg to allow generating this via DEMOTE_F. +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32m1_t DemoteTo(Simd d, const vfloat64m2_t v) { + return vfncvt_rtz_x_f_w_i32m1(v, Lanes(d)); +} +template +HWY_API vint32m2_t DemoteTo(Simd d, const vfloat64m4_t v) { + return vfncvt_rtz_x_f_w_i32m2(v, Lanes(d)); +} +template +HWY_API vint32m4_t DemoteTo(Simd d, const vfloat64m8_t v) { + return vfncvt_rtz_x_f_w_i32m4(v, Lanes(d)); +} + +template +HWY_API VFromD> DemoteTo( + Simd d, VFromD> v) { + const RebindToUnsigned du16; + const Rebind du32; + return detail::DemoteToShr16(du16, BitCast(du32, v)); +} + +// ------------------------------ ConvertTo F + +#define HWY_RVV_CONVERT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + return vfcvt_f_x_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(uint, SEW, LMUL) v) {\ + return vfcvt_f_xu_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + /* Truncates (rounds toward zero). */ \ + template \ + HWY_API HWY_RVV_V(int, SEW, LMUL) ConvertTo(HWY_RVV_D(int, SEW, N, SHIFT) d, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return vfcvt_rtz_x_f_v_i##SEW##LMUL(v, Lanes(d)); \ + } \ +// API only requires f32 but we provide f64 for internal use. +HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _, _ALL_VIRT) +#undef HWY_RVV_CONVERT + +// Uses default rounding mode. Must be separate because there is no D arg. +#define HWY_RVV_NEAREST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) NearestInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return vfcvt_x_f_v_i##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_NEAREST, _, _, _ALL) +#undef HWY_RVV_NEAREST + +// ================================================== COMBINE + +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. +template +size_t LanesPerBlock(Simd d) { + size_t lpb = 16 / sizeof(T); + if (IsFull(d)) return lpb; + // Also honor the user-specified (constexpr) N limit. + lpb = HWY_MIN(lpb, N); + // No fraction, we're done. + if (kPow2 >= 0) return lpb; + // Fractional LMUL: Lanes(d) may be smaller than lpb, so honor that. 
+ return HWY_MIN(lpb, Lanes(d)); +} + +template +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return AndS(iota0, static_cast(~(LanesPerBlock(d) - 1))); +} + +template +HWY_INLINE MFromD FirstNPerBlock(D /* tag */) { + const RebindToUnsigned du; + const RebindToSigned di; + using TU = TFromD; + const auto idx_mod = AndS(Iota0(du), static_cast(LanesPerBlock(du) - 1)); + return LtS(BitCast(di, idx_mod), static_cast>(kLanes)); +} + +// vector = f(vector, vector, size_t) +#define HWY_RVV_SLIDE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dst, HWY_RVV_V(BASE, SEW, LMUL) src, \ + size_t lanes) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(dst, src, lanes, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideUp, slideup, _ALL) +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideDown, slidedown, _ALL) + +#undef HWY_RVV_SLIDE + +} // namespace detail + +// ------------------------------ ConcatUpperLower +template +HWY_API V ConcatUpperLower(D d, const V hi, const V lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API V ConcatLowerLower(D d, const V hi, const V lo) { + return detail::SlideUp(lo, hi, Lanes(d) / 2); +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) { + // Move upper half into lower + const auto lo_down = detail::SlideDown(lo, lo, Lanes(d) / 2); + return ConcatUpperLower(d, hi, lo_down); +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { + // Move half of both inputs to the other half + const auto hi_up = detail::SlideUp(hi, hi, Lanes(d) / 2); + const auto lo_down = detail::SlideDown(lo, lo, Lanes(d) / 2); + return ConcatUpperLower(d, hi_up, lo_down); +} + +// ------------------------------ Combine +template +HWY_API VFromD Combine(D2 d2, const V hi, const V lo) { + return detail::SlideUp(detail::Ext(d2, lo), detail::Ext(d2, hi), + Lanes(d2) / 2); +} + +// ------------------------------ ZeroExtendVector + +template +HWY_API VFromD ZeroExtendVector(D2 d2, const V lo) { + return Combine(d2, Xor(lo, lo), lo); +} + +// ------------------------------ Lower/UpperHalf + +namespace detail { + +// RVV may only support LMUL >= SEW/64; returns whether that holds for D. Note +// that SEW = sizeof(T)*8 and LMUL = 1 << Pow2(). +template +constexpr bool IsSupportedLMUL(D d) { + return (size_t{1} << (Pow2(d) + 3)) >= sizeof(TFromD); +} + +} // namespace detail + +// If IsSupportedLMUL, just 'truncate' i.e. halve LMUL. +template * = nullptr> +HWY_API VFromD LowerHalf(const DH /* tag */, const VFromD> v) { + return detail::Trunc(v); +} + +// Otherwise, there is no corresponding intrinsic type (e.g. vuint64mf2_t), and +// the hardware may set "vill" if we attempt such an LMUL. However, the V +// extension on application processors requires Zvl128b, i.e. VLEN >= 128, so it +// still makes sense to have half of an SEW=64 vector. We instead just return +// the vector, and rely on the kPow2 in DH to halve the return value of Lanes(). 
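To make the IsSupportedLMUL condition above concrete: RVV intrinsic types only exist when LMUL >= SEW/64, i.e. 8 * LMUL >= sizeof(T) in bytes, so for example vuint8mf8_t exists but vuint64mf2_t does not. The same check written as a standalone sketch (plain C++; kPow2 is log2(LMUL) as elsewhere in this header):

#include <cstddef>
#include <cstdint>

template <typename T, int kPow2>  // kPow2 = log2(LMUL), may be negative
constexpr bool IsSupportedLMULSketch() {
  return (size_t{1} << (kPow2 + 3)) >= sizeof(T);  // 8 * LMUL >= sizeof(T)
}

static_assert(IsSupportedLMULSketch<uint64_t, 0>(), "vuint64m1_t exists");
static_assert(!IsSupportedLMULSketch<uint64_t, -1>(), "no vuint64mf2_t");
static_assert(IsSupportedLMULSketch<uint8_t, -3>(), "vuint8mf8_t exists");

int main() { return 0; }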
+template * = nullptr> +HWY_API V LowerHalf(const DH /* tag */, const V v) { + return v; +} + +// Same, but without D arg +template +HWY_API VFromD>> LowerHalf(const V v) { + return LowerHalf(Half>(), v); +} + +template +HWY_API VFromD UpperHalf(const DH d2, const VFromD> v) { + return LowerHalf(d2, detail::SlideDown(v, v, Lanes(d2))); +} + +// ================================================== SWIZZLE + +namespace detail { +// Special instruction for 1 lane is presumably faster? +#define HWY_RVV_SLIDE1(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_##CHAR##SEW##LMUL(v, 0, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI3264(HWY_RVV_SLIDE1, Slide1Up, slide1up_vx, _ALL) +HWY_RVV_FOREACH_F3264(HWY_RVV_SLIDE1, Slide1Up, fslide1up_vf, _ALL) +HWY_RVV_FOREACH_UI3264(HWY_RVV_SLIDE1, Slide1Down, slide1down_vx, _ALL) +HWY_RVV_FOREACH_F3264(HWY_RVV_SLIDE1, Slide1Down, fslide1down_vf, _ALL) +#undef HWY_RVV_SLIDE1 +} // namespace detail + +// ------------------------------ GetLane + +#define HWY_RVV_GET_LANE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_T(BASE, SEW) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_s_##CHAR##SEW##LMUL##_##CHAR##SEW(v); /* no AVL */ \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_GET_LANE, GetLane, mv_x, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_GET_LANE, GetLane, fmv_f, _ALL) +#undef HWY_RVV_GET_LANE + +// ------------------------------ ExtractLane +template +HWY_API TFromV ExtractLane(const V v, size_t i) { + return GetLane(detail::SlideDown(v, v, i)); +} + +// ------------------------------ InsertLane + +template +HWY_API V InsertLane(const V v, size_t i, TFromV t) { + const DFromV d; + const RebindToUnsigned du; // Iota0 is unsigned only + using TU = TFromD; + const auto is_i = detail::EqS(detail::Iota0(du), static_cast(i)); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +namespace detail { +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetOnlyFirst, sof) +} // namespace detail + +// For 8-bit lanes, Iota0 might overflow. 
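Why the separate 8-bit InsertLane overload below: with u8 lanes and a large VLEN*LMUL there can be more than 256 lanes, so Iota0 no longer produces distinct indices in the lane type. The alternative builds a one-hot mask without any index vector: slide a vector of ones up by i (so exactly the lanes >= i compare equal to one), then keep only the first true lane via detail::SetOnlyFirst. An array-based model of that mask construction (plain C++):

#include <array>
#include <cassert>
#include <cstddef>

template <size_t N>
std::array<bool, N> OneHot(size_t i) {
  // Eq(SlideUp(zero, one, i), one): true exactly for lanes >= i.
  std::array<bool, N> ge_i{};
  for (size_t lane = 0; lane < N; ++lane) ge_i[lane] = (lane >= i);
  // SetOnlyFirst: keep only the first true lane.
  std::array<bool, N> only_first{};
  bool seen = false;
  for (size_t lane = 0; lane < N; ++lane) {
    only_first[lane] = ge_i[lane] && !seen;
    seen = seen || ge_i[lane];
  }
  return only_first;
}

int main() {
  const auto m = OneHot<8>(3);
  for (size_t lane = 0; lane < 8; ++lane) assert(m[lane] == (lane == 3));
  return 0;
}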
+template +HWY_API V InsertLane(const V v, size_t i, TFromV t) { + const DFromV d; + const auto zero = Zero(d); + const auto one = Set(d, 1); + const auto ge_i = Eq(detail::SlideUp(zero, one, i), one); + const auto is_i = detail::SetOnlyFirst(ge_i); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// ------------------------------ OddEven +template +HWY_API V OddEven(const V a, const V b) { + const RebindToUnsigned> du; // Iota0 is unsigned only + const auto is_even = detail::EqS(detail::AndS(detail::Iota0(du), 1), 0); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ DupEven (OddEven) +template +HWY_API V DupEven(const V v) { + const V up = detail::Slide1Up(v); + return OddEven(up, v); +} + +// ------------------------------ DupOdd (OddEven) +template +HWY_API V DupOdd(const V v) { + const V down = detail::Slide1Down(v); + return OddEven(v, down); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API V OddEvenBlocks(const V a, const V b) { + const RebindToUnsigned> du; // Iota0 is unsigned only + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV)); + const auto idx_block = ShiftRight(detail::Iota0(du)); + const auto is_even = detail::EqS(detail::AndS(idx_block, 1), 0); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; + const size_t lpb = detail::LanesPerBlock(d); + const V down = detail::SlideDown(v, v, lpb); + const V up = detail::SlideUp(v, v, lpb); + return OddEvenBlocks(up, down); +} + +// ------------------------------ TableLookupLanes + +template +HWY_API VFromD> IndicesFromVec(D d, VI vec) { + static_assert(sizeof(TFromD) == sizeof(TFromV), "Index != lane"); + const RebindToUnsigned du; // instead of : avoids unused d. + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllTrue(du, detail::LtS(indices, Lanes(d)))); +#endif + return indices; +} + +template +HWY_API VFromD> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +// <32bit are not part of Highway API, but used in Broadcast. This limits VLMAX +// to 2048! We could instead use vrgatherei16. 
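For reference, the per-lane semantics of the vrgather-based TableLookupLanes defined next are simply out[i] = v[idx[i]]; the Reverse, ConcatOdd and ConcatEven ops later in this section are all built by choosing suitable index vectors. A scalar model (plain C++):

#include <array>
#include <cassert>
#include <cstddef>

template <typename T, size_t N>
std::array<T, N> TableLookupLanesScalar(const std::array<T, N>& v,
                                        const std::array<size_t, N>& idx) {
  std::array<T, N> out{};
  for (size_t i = 0; i < N; ++i) out[i] = v[idx[i]];  // gather from v
  return out;
}

int main() {
  const std::array<int, 4> v{10, 20, 30, 40};
  const std::array<size_t, 4> reverse{3, 2, 1, 0};  // as used by Reverse
  const auto r = TableLookupLanesScalar(v, reverse);
  assert(r[0] == 40 && r[1] == 30 && r[2] == 20 && r[3] == 10);
  return 0;
}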
+#define HWY_RVV_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, idx, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather, _ALL) +#undef HWY_RVV_TABLE + +// ------------------------------ ConcatOdd (TableLookupLanes) +template +HWY_API V ConcatOdd(D d, const V hi, const V lo) { + const RebindToUnsigned du; // Iota0 is unsigned only + const auto iota = detail::Iota0(du); + const auto idx = detail::AddS(Add(iota, iota), 1); + const auto lo_odd = TableLookupLanes(lo, idx); + const auto hi_odd = TableLookupLanes(hi, idx); + return detail::SlideUp(lo_odd, hi_odd, Lanes(d) / 2); +} + +// ------------------------------ ConcatEven (TableLookupLanes) +template +HWY_API V ConcatEven(D d, const V hi, const V lo) { + const RebindToUnsigned du; // Iota0 is unsigned only + const auto iota = detail::Iota0(du); + const auto idx = Add(iota, iota); + const auto lo_even = TableLookupLanes(lo, idx); + const auto hi_even = TableLookupLanes(hi, idx); + return detail::SlideUp(lo_even, hi_even, Lanes(d) / 2); +} + +// ------------------------------ Reverse (TableLookupLanes) +template +HWY_API VFromD Reverse(D /* tag */, VFromD v) { + const RebindToUnsigned du; + using TU = TFromD; + const size_t N = Lanes(du); + const auto idx = + detail::ReverseSubS(detail::Iota0(du), static_cast(N - 1)); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse2 (RotateRight, OddEven) + +// Shifting and adding requires fewer instructions than blending, but casting to +// u32 only works for LMUL in [1/2, 8]. +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} +// For LMUL < 1/2, we can extend and then truncate. +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Twice d2; + const Twice d4; + const Repartition du32; + const auto vx = detail::Ext(d4, detail::Ext(d2, v)); + const auto rx = BitCast(d4, RotateRight<16>(BitCast(du32, vx))); + return detail::Trunc(detail::Trunc(rx)); +} + +// Shifting and adding requires fewer instructions than blending, but casting to +// u64 does not work for LMUL < 1. +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Repartition du64; + return BitCast(d, RotateRight<32>(BitCast(du64, v))); +} + +// For fractions, we can extend and then truncate. 
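The Reverse2 overloads above, and the fractional-LMUL variant that follows, rely on one observation: viewing a pair of adjacent narrow lanes as a single double-width lane, rotating that lane by half its width swaps the pair. A scalar sketch for 16-bit lanes inside a u32 (plain C++, little-endian lane order as on RVV):

#include <cassert>
#include <cstdint>

uint32_t RotateRight16(uint32_t v) { return (v >> 16) | (v << 16); }

int main() {
  // Lanes {0x1111, 0x2222} packed little-endian into one u32.
  const uint32_t pair = 0x22221111u;
  assert(RotateRight16(pair) == 0x11112222u);  // the two lanes swapped
  return 0;
}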
+template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Twice d2; + const Twice d4; + const Repartition du64; + const auto vx = detail::Ext(d4, detail::Ext(d2, v)); + const auto rx = BitCast(d4, RotateRight<32>(BitCast(du64, vx))); + return detail::Trunc(detail::Trunc(rx)); +} + +template , HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API V Reverse2(D /* tag */, const V v) { + const V up = detail::Slide1Up(v); + const V down = detail::Slide1Down(v); + return OddEven(up, down); +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 3); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 7); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { + const Repartition du64; + const size_t N = Lanes(du64); + const auto rev = + detail::ReverseSubS(detail::Iota0(du64), static_cast(N - 1)); + // Swap lo/hi u64 within each block + const auto idx = detail::XorS(rev, 1); + return BitCast(d, TableLookupLanes(BitCast(du64, v), idx)); +} + +// ------------------------------ Compress + +// RVV supports all lane types natively. +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +template +struct CompressIsPartition { + enum { value = 0 }; +}; + +#define HWY_RVV_COMPRESS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) mask) { \ + return v##OP##_vm_##CHAR##SEW##LMUL(v, v, mask, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_COMPRESS, Compress, compress, _ALL) +#undef HWY_RVV_COMPRESS + +// ------------------------------ CompressNot +template +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +template +HWY_API V CompressBlocksNot(V v, const M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + detail::StoreN(count, Compress(v, mask), d, unaligned); + return count; +} + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes +template > +HWY_API V CombineShiftRightBytes(const D d, const V hi, V lo) { + const Repartition d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); + const auto hi_up = detail::SlideUp(hi8, hi8, 16 - kBytes); + const auto lo_down = detail::SlideDown(lo8, lo8, kBytes); + const auto is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +} + +// ------------------------------ CombineShiftRightLanes +template > +HWY_API V CombineShiftRightLanes(const D d, 
const V hi, V lo) { + constexpr size_t kLanesUp = 16 / sizeof(TFromV) - kLanes; + const auto hi_up = detail::SlideUp(hi, hi, kLanesUp); + const auto lo_down = detail::SlideDown(lo, lo, kLanes); + const auto is_lo = detail::FirstNPerBlock(d); + return IfThenElse(is_lo, lo_down, hi_up); +} + +// ------------------------------ Shuffle2301 (ShiftLeft) +template +HWY_API V Shuffle2301(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const Repartition du64; + const auto v64 = BitCast(du64, v); + return BitCast(d, Or(ShiftRight<32>(v64), ShiftLeft<32>(v64))); +} + +// ------------------------------ Shuffle2103 +template +HWY_API V Shuffle2103(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<3>(d, v, v); +} + +// ------------------------------ Shuffle0321 +template +HWY_API V Shuffle0321(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle1032 +template +HWY_API V Shuffle1032(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<2>(d, v, v); +} + +// ------------------------------ Shuffle01 +template +HWY_API V Shuffle01(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle0123 +template +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ TableLookupBytes + +// Extends or truncates a vector to match the given d. +namespace detail { + +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + const Simd dh; + const Simd dhh; + return Ext(d, Ext(dh, Ext(dhh, v))); +} +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + const Simd dh; + return Ext(d, Ext(dh, v)); +} +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + return Ext(d, v); +} + +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD v) + -> VFromD { + return v; +} + +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + return Trunc(v); +} +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + return Trunc(Trunc(v)); +} +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + return Trunc(Trunc(Trunc(v))); +} + +} // namespace detail + +template +HWY_API VI TableLookupBytes(const VT vt, const VI vi) { + const DFromV dt; // T=table, I=index. + const DFromV di; + const Repartition dt8; + const Repartition di8; + // Required for producing half-vectors with table lookups from a full vector. + // If we instead run at the LMUL of the index vector, lookups into the table + // would be truncated. Thus we run at the larger of the two LMULs and truncate + // the result vector to the original index LMUL. + constexpr int kPow2T = Pow2(dt8); + constexpr int kPow2I = Pow2(di8); + const Simd dm8; // m=max + const auto vmt = detail::ChangeLMUL(dm8, BitCast(dt8, vt)); + const auto vmi = detail::ChangeLMUL(dm8, BitCast(di8, vi)); + auto offsets = detail::OffsetsOf128BitBlocks(dm8, detail::Iota0(dm8)); + // If the table is shorter, wrap around offsets so they do not reference + // undefined lanes in the newly extended vmt. 
+ if (kPow2T < kPow2I) { + offsets = detail::AndS(offsets, static_cast(Lanes(dt8) - 1)); + } + const auto out = TableLookupLanes(vmt, Add(vmi, offsets)); + return BitCast(di, detail::ChangeLMUL(di8, out)); +} + +template +HWY_API VI TableLookupBytesOr0(const VT vt, const VI idx) { + const DFromV di; + const Repartition di8; + const auto idx8 = BitCast(di8, idx); + const auto lookup = TableLookupBytes(vt, idx8); + return BitCast(di, IfThenZeroElse(detail::LtS(idx8, 0), lookup)); +} + +// ------------------------------ Broadcast +template +HWY_API V Broadcast(const V v) { + const DFromV d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + auto idx = detail::OffsetsOf128BitBlocks(d, detail::Iota0(d)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + return TableLookupLanes(v, idx); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API V ShiftLeftLanes(const D d, const V v) { + const RebindToSigned di; + using TI = TFromD; + const auto shifted = detail::SlideUp(v, v, kLanes); + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(di)), + static_cast(detail::LanesPerBlock(di) - 1)); + const auto clear = detail::LtS(idx_mod, static_cast(kLanes)); + return IfThenZeroElse(clear, shifted); +} + +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, const VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); +} + +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftRightLanes +template >> +HWY_API V ShiftRightLanes(const Simd d, V v) { + const RebindToSigned di; + using TI = TFromD; + // For partial vectors, clear upper lanes so we shift in zeros. 
+ if (N <= 16 / sizeof(T)) { + v = IfThenElseZero(FirstN(d, N), v); + } + + const auto shifted = detail::SlideDown(v, v, kLanes); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + const size_t lpb = detail::LanesPerBlock(di); + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(di)), static_cast(lpb - 1)); + const auto keep = detail::LtS(idx_mod, static_cast(lpb - kLanes)); + return IfThenElseZero(keep, shifted); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); +} + +// ------------------------------ InterleaveLower + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const RebindToUnsigned du; + using TU = TFromD; + const auto i = detail::Iota0(du); + const auto idx_mod = ShiftRight<1>( + detail::AndS(i, static_cast(detail::LanesPerBlock(du) - 1))); + const auto idx = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); + const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); + return IfThenElse(is_even, TableLookupLanes(a, idx), + TableLookupLanes(b, idx)); +} + +template +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV(), a, b); +} + +// ------------------------------ InterleaveUpper + +template +HWY_API V InterleaveUpper(const D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const RebindToUnsigned du; + using TU = TFromD; + const size_t lpb = detail::LanesPerBlock(du); + const auto i = detail::Iota0(du); + const auto idx_mod = ShiftRight<1>(detail::AndS(i, static_cast(lpb - 1))); + const auto idx_lower = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); + const auto idx = detail::AddS(idx_lower, static_cast(lpb / 2)); + const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); + return IfThenElse(is_even, TableLookupLanes(a, idx), + TableLookupLanes(b, idx)); +} + +// ------------------------------ ZipLower + +template >> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} + +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} + +// ------------------------------ ZipUpper +template +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== REDUCE + +// vector = f(vector, zero_m1) +#define HWY_RVV_REDUCE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(D d, HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, m1) v0) { \ + return Set(d, GetLane(v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ + v0, v, v0, Lanes(d)))); \ + } + +// ------------------------------ SumOfLanes + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum, _ALL) +} // namespace detail + +template +HWY_API VFromD SumOfLanes(D d, const VFromD v) { + const auto v0 = Zero(ScalableTag>()); // always m1 + return detail::RedSum(d, v, v0); +} + +// ------------------------------ MinOfLanes +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, 
RedMin, redmin, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin, _ALL) +} // namespace detail + +template +HWY_API VFromD MinOfLanes(D d, const VFromD v) { + using T = TFromD; + const ScalableTag d1; // always m1 + const auto neutral = Set(d1, HighestValue()); + return detail::RedMin(d, v, neutral); +} + +// ------------------------------ MaxOfLanes +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax, _ALL) +} // namespace detail + +template +HWY_API VFromD MaxOfLanes(D d, const VFromD v) { + using T = TFromD; + const ScalableTag d1; // always m1 + const auto neutral = Set(d1, LowestValue()); + return detail::RedMax(d, v, neutral); +} + +#undef HWY_RVV_REDUCE + +// ================================================== Ops with dependencies + +// ------------------------------ PopulationCount (ShiftRight) + +// Handles LMUL >= 2 or capped vectors, which generic_ops-inl cannot. +template , HWY_IF_LANE_SIZE_D(D, 1), + hwy::EnableIf* = nullptr> +HWY_API V PopulationCount(V v) { + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + v = Sub(v, detail::AndS(ShiftRight<1>(v), 0x55)); + v = Add(detail::AndS(ShiftRight<2>(v), 0x33), detail::AndS(v, 0x33)); + return detail::AndS(Add(v, ShiftRight<4>(v)), 0x0F); +} + +// ------------------------------ LoadDup128 + +template +HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { + const VFromD loaded = Load(d, p); + // idx must be unsigned for TableLookupLanes. + using TU = MakeUnsigned>; + const TU mask = static_cast(detail::LanesPerBlock(d) - 1); + // Broadcast the first block. + const VFromD> idx = detail::AndS(detail::Iota0(d), mask); + return TableLookupLanes(loaded, idx); +} + +// ------------------------------ LoadMaskBits + +// Support all combinations of T and SHIFT(LMUL) without explicit overloads for +// each. First overload for MLEN=1..64. +namespace detail { + +// Maps D to MLEN (wrapped in SizeTag), such that #mask_bits = VLEN/MLEN. MLEN +// increases with lane size and decreases for increasing LMUL. Cap at 64, the +// largest supported by HWY_RVV_FOREACH_B (and intrinsics), for virtual LMUL +// e.g. vuint16mf8_t: (8*2 << 3) == 128. +template +using MaskTag = hwy::SizeTag), -Pow2(D())))>; + +#define HWY_RVV_LOAD_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_INLINE HWY_RVV_M(MLEN) \ + NAME(hwy::SizeTag /* tag */, const uint8_t* bits, size_t N) { \ + return OP##_v_b##MLEN(bits, N); \ + } +HWY_RVV_FOREACH_B(HWY_RVV_LOAD_MASK_BITS, LoadMaskBits, vlm) +#undef HWY_RVV_LOAD_MASK_BITS +} // namespace detail + +template > +HWY_API auto LoadMaskBits(D d, const uint8_t* bits) + -> decltype(detail::LoadMaskBits(MT(), bits, Lanes(d))) { + return detail::LoadMaskBits(MT(), bits, Lanes(d)); +} + +// ------------------------------ StoreMaskBits +#define HWY_RVV_STORE_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(D d, HWY_RVV_M(MLEN) m, uint8_t* bits) { \ + const size_t N = Lanes(d); \ + OP##_v_b##MLEN(bits, m, N); \ + /* Non-full byte, need to clear the undefined upper bits. */ \ + /* Use MaxLanes and sizeof(T) to move some checks to compile-time. 
*/ \ + constexpr bool kLessThan8 = \ + detail::ScaleByPower(16 / sizeof(TFromD), Pow2(d)) < 8; \ + if (MaxLanes(d) < 8 || (kLessThan8 && N < 8)) { \ + const int mask = (1 << N) - 1; \ + bits[0] = static_cast(bits[0] & mask); \ + } \ + return (N + 7) / 8; \ + } +HWY_RVV_FOREACH_B(HWY_RVV_STORE_MASK_BITS, StoreMaskBits, vsm) +#undef HWY_RVV_STORE_MASK_BITS + +// ------------------------------ CompressBits, CompressBitsStore (LoadMaskBits) + +template +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ FirstN (Iota0, Lt, RebindMask, SlideUp) + +// Disallow for 8-bit because Iota is likely to overflow. +template +HWY_API MFromD FirstN(const D d, const size_t n) { + const RebindToSigned di; + using TI = TFromD; + return RebindMask( + d, detail::LtS(BitCast(di, detail::Iota0(d)), static_cast(n))); +} + +template +HWY_API MFromD FirstN(const D d, const size_t n) { + const auto zero = Zero(d); + const auto one = Set(d, 1); + return Eq(detail::SlideUp(one, zero, n), one); +} + +// ------------------------------ Neg (Sub) + +template +HWY_API V Neg(const V v) { + return detail::ReverseSubS(v, 0); +} + +// vector = f(vector), but argument is repeated +#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn, _ALL) + +// ------------------------------ Abs (Max, Neg) + +template +HWY_API V Abs(const V v) { + return Max(v, Neg(v)); +} + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Abs, fsgnjx, _ALL) + +#undef HWY_RVV_RETV_ARGV2 + +// ------------------------------ AbsDiff (Abs, Sub) +template +HWY_API V AbsDiff(const V a, const V b) { + return Abs(Sub(a, b)); +} + +// ------------------------------ Round (NearestInt, ConvertTo, CopySign) + +// IEEE-754 roundToIntegralTiesToEven returns floating-point, but we do not have +// a dedicated instruction for that. Rounding to integer and converting back to +// float is correct except when the input magnitude is large, in which case the +// input was already an integer (because mantissa >> exponent is zero). 
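+// For example, in binary32 any |v| >= 2^23 already holds an integer value
+// (its ULP is at least 1), so returning v unchanged is exact; for |v| < 2^23
+// the round-trip through the lane-sized integer below is lossless.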
+ +namespace detail { +enum RoundingModes { kNear, kTrunc, kDown, kUp }; + +template +HWY_INLINE auto UseInt(const V v) -> decltype(MaskFromVec(v)) { + return detail::LtS(Abs(v), MantissaEnd>()); +} + +} // namespace detail + +template +HWY_API V Round(const V v) { + const DFromV df; + + const auto integer = NearestInt(v); // round using current mode + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Trunc (ConvertTo) +template +HWY_API V Trunc(const V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Ceil +template +HWY_API V Ceil(const V v) { + asm volatile("fsrm %0" ::"r"(detail::kUp)); + const auto ret = Round(v); + asm volatile("fsrm %0" ::"r"(detail::kNear)); + return ret; +} + +// ------------------------------ Floor +template +HWY_API V Floor(const V v) { + asm volatile("fsrm %0" ::"r"(detail::kDown)); + const auto ret = Round(v); + asm volatile("fsrm %0" ::"r"(detail::kNear)); + return ret; +} + +// ------------------------------ Floating-point classification (Ne) + +// vfclass does not help because it would require 3 instructions (to AND and +// then compare the bits), whereas these are just 1-3 integer instructions. + +template +HWY_API MFromD> IsNaN(const V v) { + return Ne(v, v); +} + +template > +HWY_API MFromD IsInf(const V v) { + const D d; + const RebindToSigned di; + using T = TFromD; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, detail::EqS(Add(vi, vi), hwy::MaxExponentTimes2())); +} + +// Returns whether normal/subnormal/zero. +template > +HWY_API MFromD IsFinite(const V v) { + const D d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + using T = TFromD; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtS(exp, hwy::MaxExponentField())); +} + +// ------------------------------ Iota (ConvertTo) + +template +HWY_API VFromD Iota(const D d, TFromD first) { + return detail::AddS(detail::Iota0(d), first); +} + +template +HWY_API VFromD Iota(const D d, TFromD first) { + const RebindToUnsigned du; + return detail::AddS(BitCast(d, detail::Iota0(du)), first); +} + +template +HWY_API VFromD Iota(const D d, TFromD first) { + const RebindToUnsigned du; + const RebindToSigned di; + return detail::AddS(ConvertTo(d, BitCast(di, detail::Iota0(du))), first); +} + +// ------------------------------ MulEven/Odd (Mul, OddEven) + +template , + class DW = RepartitionToWide> +HWY_API VFromD MulEven(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return BitCast(DW(), OddEven(detail::Slide1Up(hi), lo)); +} + +// There is no 64x64 vwmul. 
+template +HWY_INLINE V MulEven(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return OddEven(detail::Slide1Up(hi), lo); +} + +template +HWY_INLINE V MulOdd(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return OddEven(hi, detail::Slide1Down(lo)); +} + +// ------------------------------ ReorderDemote2To (OddEven, Combine) + +template +HWY_API VFromD> ReorderDemote2To( + Simd dbf16, + VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; + const RebindToUnsigned> du32; + const VFromD b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +// If LMUL is not the max, Combine first to avoid another DemoteTo. +template * = nullptr, + class D32 = RepartitionToWide>> +HWY_API VFromD> ReorderDemote2To( + Simd d16, VFromD a, VFromD b) { + const Twice d32t; + const VFromD ab = Combine(d32t, a, b); + return DemoteTo(d16, ab); +} + +// Max LMUL: must DemoteTo first, then Combine. +template >>> +HWY_API VFromD> ReorderDemote2To(Simd d16, + V32 a, V32 b) { + const Half d16h; + const VFromD a16 = DemoteTo(d16h, a); + const VFromD b16 = DemoteTo(d16h, b); + return Combine(d16, a16, b16); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +namespace detail { + +// Non-overloaded wrapper function so we can define DF32 in template args. +template < + size_t N, int kPow2, class DF32 = Simd, + class VF32 = VFromD, + class DU16 = RepartitionToNarrow>>> +HWY_API VF32 ReorderWidenMulAccumulateBF16(Simd df32, + VFromD a, VFromD b, + const VF32 sum0, VF32& sum1) { + const RebindToUnsigned du32; + using VU32 = VFromD; + const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 + // Using shift/and instead of Zip leads to the odd/even order that + // RearrangeToOddPlusEven prefers. + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +#define HWY_RVV_WIDEN_MACC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEWD, LMULD) sum, \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return OP##CHAR##SEWD##LMULD(sum, a, b, Lanes(d)); \ + } + +HWY_RVV_FOREACH_I16(HWY_RVV_WIDEN_MACC, WidenMulAcc, vwmacc_vv_, _EXT_VIRT) +#undef HWY_RVV_WIDEN_MACC + +// If LMUL is not the max, we can WidenMul first (3 instructions). +template * = nullptr, + class D32 = Simd, class V32 = VFromD, + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateI16(Simd d32, + VFromD a, VFromD b, + const V32 sum0, V32& sum1) { + const Twice d32t; + using V32T = VFromD; + V32T sum = Combine(d32t, sum1, sum0); + sum = detail::WidenMulAcc(d32t, sum, a, b); + sum1 = UpperHalf(d32, sum); + return LowerHalf(d32, sum); +} + +// Max LMUL: must LowerHalf first (4 instructions). 
+template , class V32 = VFromD, + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateI16(Simd d32, + VFromD a, VFromD b, + const V32 sum0, V32& sum1) { + const Half d16h; + using V16H = VFromD; + const V16H a0 = LowerHalf(d16h, a); + const V16H a1 = UpperHalf(d16h, a); + const V16H b0 = LowerHalf(d16h, b); + const V16H b1 = UpperHalf(d16h, b); + sum1 = detail::WidenMulAcc(d32, sum1, a1, b1); + return detail::WidenMulAcc(d32, sum0, a0, b0); +} + +} // namespace detail + +template +HWY_API VW ReorderWidenMulAccumulate(Simd d32, VN a, VN b, + const VW sum0, VW& sum1) { + return detail::ReorderWidenMulAccumulateBF16(d32, a, b, sum0, sum1); +} + +template +HWY_API VW ReorderWidenMulAccumulate(Simd d32, VN a, VN b, + const VW sum0, VW& sum1) { + return detail::ReorderWidenMulAccumulateI16(d32, a, b, sum0, sum1); +} + +// ------------------------------ RearrangeToOddPlusEven + +template // vint32_t* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // vwmacc doubles LMUL, so we require a pairwise sum here. This op is + // expected to be less frequent than ReorderWidenMulAccumulate, hence it's + // preferable to do the extra work here rather than do manual odd/even + // extraction there. + const DFromV di32; + const RebindToUnsigned du32; + const Twice di32x2; + const RepartitionToWide di64x2; + const RebindToUnsigned du64x2; + const auto combined = BitCast(di64x2, Combine(di32x2, sum1, sum0)); + // Isolate odd/even int32 in int64 lanes. + const auto even = ShiftRight<32>(ShiftLeft<32>(combined)); // sign extend + const auto odd = ShiftRight<32>(combined); + return BitCast(di32, TruncateTo(du32, BitCast(du64x2, Add(even, odd)))); +} + +// For max LMUL, we cannot Combine again and instead manually unroll. +HWY_API vint32m8_t RearrangeToOddPlusEven(vint32m8_t sum0, vint32m8_t sum1) { + const DFromV d; + const Half dh; + const vint32m4_t lo = + RearrangeToOddPlusEven(LowerHalf(sum0), UpperHalf(dh, sum0)); + const vint32m4_t hi = + RearrangeToOddPlusEven(LowerHalf(sum1), UpperHalf(dh, sum1)); + return Combine(d, hi, lo); +} + +template // vfloat* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); // invariant already holds +} + +// ------------------------------ Lt128 +template +HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Truth table of Eq and Compare for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // Shift leftward so L can influence H. + const VFromD ltLx = detail::Slide1Up(ltHL); + const VFromD vecHx = OrAnd(ltHL, eqHL, ltLx); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(vecHx, detail::Slide1Down(vecHx))); +} + +// ------------------------------ Lt128Upper +template +HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // Replicate H to its neighbor. 
+ return MaskFromVec(OddEven(ltHL, detail::Slide1Down(ltHL))); +} + +// ------------------------------ Eq128 +template +HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD eqLH = Reverse2(d, eqHL); + return MaskFromVec(And(eqHL, eqLH)); +} + +// ------------------------------ Eq128Upper +template +HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(eqHL, detail::Slide1Down(eqHL))); +} + +// ------------------------------ Ne128 +template +HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + const VFromD neLH = Reverse2(d, neHL); + return MaskFromVec(Or(neHL, neLH)); +} + +// ------------------------------ Ne128Upper +template +HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(neHL, detail::Slide1Down(neHL))); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE VFromD Min128(D /* tag */, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD minHL = Min(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, a, b); + // The upper lane is just minHL; if they are equal, we also need to use the + // actual min of the lower lanes. + return OddEven(minHL, IfThenElse(eqXH, minHL, lo)); +} + +template +HWY_INLINE VFromD Max128(D /* tag */, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD maxHL = Max(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, b, a); + // The upper lane is just maxHL; if they are equal, we also need to use the + // actual min of the lower lanes. 
+ return OddEven(maxHL, IfThenElse(eqXH, maxHL, lo)); +} + +template +HWY_INLINE VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== END MACROS +namespace detail { // for code folding +#undef HWY_RVV_AVL +#undef HWY_RVV_D +#undef HWY_RVV_FOREACH +#undef HWY_RVV_FOREACH_08_ALL +#undef HWY_RVV_FOREACH_08_ALL_VIRT +#undef HWY_RVV_FOREACH_08_DEMOTE +#undef HWY_RVV_FOREACH_08_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_08_EXT +#undef HWY_RVV_FOREACH_08_EXT_VIRT +#undef HWY_RVV_FOREACH_08_TRUNC +#undef HWY_RVV_FOREACH_08_VIRT +#undef HWY_RVV_FOREACH_16_ALL +#undef HWY_RVV_FOREACH_16_ALL_VIRT +#undef HWY_RVV_FOREACH_16_DEMOTE +#undef HWY_RVV_FOREACH_16_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_16_EXT +#undef HWY_RVV_FOREACH_16_EXT_VIRT +#undef HWY_RVV_FOREACH_16_TRUNC +#undef HWY_RVV_FOREACH_16_VIRT +#undef HWY_RVV_FOREACH_32_ALL +#undef HWY_RVV_FOREACH_32_ALL_VIRT +#undef HWY_RVV_FOREACH_32_DEMOTE +#undef HWY_RVV_FOREACH_32_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_32_EXT +#undef HWY_RVV_FOREACH_32_EXT_VIRT +#undef HWY_RVV_FOREACH_32_TRUNC +#undef HWY_RVV_FOREACH_32_VIRT +#undef HWY_RVV_FOREACH_64_ALL +#undef HWY_RVV_FOREACH_64_ALL_VIRT +#undef HWY_RVV_FOREACH_64_DEMOTE +#undef HWY_RVV_FOREACH_64_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_64_EXT +#undef HWY_RVV_FOREACH_64_EXT_VIRT +#undef HWY_RVV_FOREACH_64_TRUNC +#undef HWY_RVV_FOREACH_64_VIRT +#undef HWY_RVV_FOREACH_B +#undef HWY_RVV_FOREACH_F +#undef HWY_RVV_FOREACH_F16 +#undef HWY_RVV_FOREACH_F32 +#undef HWY_RVV_FOREACH_F3264 +#undef HWY_RVV_FOREACH_F64 +#undef HWY_RVV_FOREACH_I +#undef HWY_RVV_FOREACH_I08 +#undef HWY_RVV_FOREACH_I16 +#undef HWY_RVV_FOREACH_I163264 +#undef HWY_RVV_FOREACH_I32 +#undef HWY_RVV_FOREACH_I64 +#undef HWY_RVV_FOREACH_U +#undef HWY_RVV_FOREACH_U08 +#undef HWY_RVV_FOREACH_U16 +#undef HWY_RVV_FOREACH_U163264 +#undef HWY_RVV_FOREACH_U32 +#undef HWY_RVV_FOREACH_U64 +#undef HWY_RVV_FOREACH_UI +#undef HWY_RVV_FOREACH_UI08 +#undef HWY_RVV_FOREACH_UI16 +#undef HWY_RVV_FOREACH_UI163264 +#undef HWY_RVV_FOREACH_UI32 +#undef HWY_RVV_FOREACH_UI3264 +#undef HWY_RVV_FOREACH_UI64 +#undef HWY_RVV_M +#undef HWY_RVV_RETM_ARGM +#undef HWY_RVV_RETV_ARGV +#undef HWY_RVV_RETV_ARGVS +#undef HWY_RVV_RETV_ARGVV +#undef HWY_RVV_T +#undef HWY_RVV_V +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/scalar-inl.h b/third_party/highway/hwy/ops/scalar-inl.h new file mode 100644 index 0000000000..c28f7b510f --- /dev/null +++ b/third_party/highway/hwy/ops/scalar-inl.h @@ -0,0 +1,1626 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. 
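+//
+// On this target every op degenerates to its single-lane form; an informal
+// sketch:
+//   Sisd<float> d;
+//   Vec1<float> v = Set(d, 1.5f);
+//   const float sum = GetLane(v + v);  // 3.0f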
+ +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Single instruction, single data. +template +using Sisd = Simd; + +// (Wrapper class required for overloading comparison operators.) +template +struct Vec1 { + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 1; // only for DFromV + + HWY_INLINE Vec1() = default; + Vec1(const Vec1&) = default; + Vec1& operator=(const Vec1&) = default; + HWY_INLINE explicit Vec1(const T t) : raw(t) {} + + HWY_INLINE Vec1& operator*=(const Vec1 other) { + return *this = (*this * other); + } + HWY_INLINE Vec1& operator/=(const Vec1 other) { + return *this = (*this / other); + } + HWY_INLINE Vec1& operator+=(const Vec1 other) { + return *this = (*this + other); + } + HWY_INLINE Vec1& operator-=(const Vec1 other) { + return *this = (*this - other); + } + HWY_INLINE Vec1& operator&=(const Vec1 other) { + return *this = (*this & other); + } + HWY_INLINE Vec1& operator|=(const Vec1 other) { + return *this = (*this | other); + } + HWY_INLINE Vec1& operator^=(const Vec1 other) { + return *this = (*this ^ other); + } + + T raw; +}; + +// 0 or FF..FF, same size as Vec1. +template +class Mask1 { + using Raw = hwy::MakeUnsigned; + + public: + static HWY_INLINE Mask1 FromBool(bool b) { + Mask1 mask; + mask.bits = b ? static_cast(~Raw{0}) : 0; + return mask; + } + + Raw bits; +}; + +template +using DFromV = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +template +HWY_API Vec1 BitCast(Sisd /* tag */, Vec1 v) { + static_assert(sizeof(T) <= sizeof(FromT), "Promoting is undefined"); + T to; + CopyBytes(&v.raw, &to); // not same size - ok to shrink + return Vec1(to); +} + +// ------------------------------ Set + +template +HWY_API Vec1 Zero(Sisd /* tag */) { + return Vec1(T(0)); +} + +template +HWY_API Vec1 Set(Sisd /* tag */, const T2 t) { + return Vec1(static_cast(t)); +} + +template +HWY_API Vec1 Undefined(Sisd d) { + return Zero(d); +} + +template +HWY_API Vec1 Iota(const Sisd /* tag */, const T2 first) { + return Vec1(static_cast(first)); +} + +template +using VFromD = decltype(Zero(D())); + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec1 Not(const Vec1 v) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, v).raw))); +} + +// ------------------------------ And + +template +HWY_API Vec1 And(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw & BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator&(const Vec1 a, const Vec1 b) { + return And(a, b); +} + +// ------------------------------ AndNot + +template +HWY_API Vec1 AndNot(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, a).raw & + BitCast(du, b).raw))); +} + +// ------------------------------ Or + +template +HWY_API Vec1 Or(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw | BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator|(const Vec1 a, const Vec1 b) { + return Or(a, b); +} + +// ------------------------------ Xor + +template +HWY_API Vec1 Xor(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw ^ 
BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator^(const Vec1 a, const Vec1 b) { + return Xor(a, b); +} + +// ------------------------------ Xor3 + +template +HWY_API Vec1 Xor3(Vec1 x1, Vec1 x2, Vec1 x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template +HWY_API Vec1 Or3(Vec1 o1, Vec1 o2, Vec1 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec1 OrAnd(const Vec1 o, const Vec1 a1, const Vec1 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec1 IfVecThenElse(Vec1 mask, Vec1 yes, Vec1 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ CopySign + +template +HWY_API Vec1 CopySign(const Vec1 magn, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Sisd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec1 CopySignToAbs(const Vec1 abs, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Sisd()), sign)); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec1 BroadcastSignBit(const Vec1 v) { + // This is used inside ShiftRight, so we cannot implement in terms of it. + return v.raw < 0 ? Vec1(T(-1)) : Vec1(0); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +template +HWY_API Vec1 PopulationCount(Vec1 v) { + return Vec1(static_cast(PopCount(v.raw))); +} + +// ------------------------------ Mask + +template +HWY_API Mask1 RebindMask(Sisd /*tag*/, Mask1 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask1{m.bits}; +} + +// v must be 0 or FF..FF. +template +HWY_API Mask1 MaskFromVec(const Vec1 v) { + Mask1 mask; + CopySameSize(&v, &mask); + return mask; +} + +template +Vec1 VecFromMask(const Mask1 mask) { + Vec1 v; + CopySameSize(&mask, &v); + return v; +} + +template +Vec1 VecFromMask(Sisd /* tag */, const Mask1 mask) { + Vec1 v; + CopySameSize(&mask, &v); + return v; +} + +template +HWY_API Mask1 FirstN(Sisd /*tag*/, size_t n) { + return Mask1::FromBool(n != 0); +} + +// Returns mask ? yes : no. +template +HWY_API Vec1 IfThenElse(const Mask1 mask, const Vec1 yes, + const Vec1 no) { + return mask.bits ? yes : no; +} + +template +HWY_API Vec1 IfThenElseZero(const Mask1 mask, const Vec1 yes) { + return mask.bits ? yes : Vec1(0); +} + +template +HWY_API Vec1 IfThenZeroElse(const Mask1 mask, const Vec1 no) { + return mask.bits ? Vec1(0) : no; +} + +template +HWY_API Vec1 IfNegativeThenElse(Vec1 v, Vec1 yes, Vec1 no) { + return v.raw < 0 ? yes : no; +} + +template +HWY_API Vec1 ZeroIfNegative(const Vec1 v) { + return v.raw < 0 ? 
Vec1(0) : v; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask1 Not(const Mask1 m) { + return MaskFromVec(Not(VecFromMask(Sisd(), m))); +} + +template +HWY_API Mask1 And(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 AndNot(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Or(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Xor(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 ExclusiveNeither(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeft(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1( + static_cast(static_cast>(v.raw) << kBits)); +} + +template +HWY_API Vec1 ShiftRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1(static_cast(v.raw >> kBits)); +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + const Sisd du; + const TU shifted = static_cast(BitCast(du, v).raw >> kBits); + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); + const TU upper = static_cast(sign << sign_shift); + return BitCast(Sisd(), Vec1(shifted | upper)); + } else { // T is unsigned + return Vec1(static_cast(v.raw >> kBits)); + } +#endif +} + +// ------------------------------ RotateRight (ShiftRight) + +namespace detail { + +// For partial specialization: kBits == 0 results in an invalid shift count +template +struct RotateRight { + template + HWY_INLINE Vec1 operator()(const Vec1 v) const { + return Or(ShiftRight(v), ShiftLeft(v)); + } +}; + +template <> +struct RotateRight<0> { + template + HWY_INLINE Vec1 operator()(const Vec1 v) const { + return v; + } +}; + +} // namespace detail + +template +HWY_API Vec1 RotateRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return detail::RotateRight()(v); +} + +// ------------------------------ ShiftLeftSame (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeftSame(const Vec1 v, int bits) { + return Vec1( + static_cast(static_cast>(v.raw) << bits)); +} + +template +HWY_API Vec1 ShiftRightSame(const Vec1 v, int bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1(static_cast(v.raw >> bits)); +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. 
+ using TU = hwy::MakeUnsigned; + const Sisd du; + const TU shifted = static_cast(BitCast(du, v).raw >> bits); + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); + const TU upper = static_cast(sign << sign_shift); + return BitCast(Sisd(), Vec1(shifted | upper)); + } else { // T is unsigned + return Vec1(static_cast(v.raw >> bits)); + } +#endif +} + +// ------------------------------ Shl + +// Single-lane => same as ShiftLeftSame except for the argument type. +template +HWY_API Vec1 operator<<(const Vec1 v, const Vec1 bits) { + return ShiftLeftSame(v, static_cast(bits.raw)); +} + +template +HWY_API Vec1 operator>>(const Vec1 v, const Vec1 bits) { + return ShiftRightSame(v, static_cast(bits.raw)); +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec1 operator+(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 + b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} + +template +HWY_API Vec1 operator-(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 - b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} + +// ------------------------------ SumsOf8 + +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 255))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedAdd(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw + b.raw), 127))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-32768, a.raw + b.raw), 32767))); +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. 
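+// For example, SaturatedSub of uint8_t lanes 5 and 9 yields 0 rather than
+// wrapping, and of int8_t lanes -100 and 100 yields -128.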
+ +// Unsigned +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 255))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedSub(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw - b.raw), 127))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-32768, a.raw - b.raw), 32767))); +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +HWY_API Vec1 AverageRound(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +} +HWY_API Vec1 AverageRound(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +} + +// ------------------------------ Absolute value + +template +HWY_API Vec1 Abs(const Vec1 a) { + const T i = a.raw; + if (i >= 0 || i == hwy::LimitsMin()) return a; + return Vec1(static_cast(-i & T{-1})); +} +HWY_API Vec1 Abs(Vec1 a) { + int32_t i; + CopyBytes(&a.raw, &i); + i &= 0x7FFFFFFF; + CopyBytes(&i, &a.raw); + return a; +} +HWY_API Vec1 Abs(Vec1 a) { + int64_t i; + CopyBytes(&a.raw, &i); + i &= 0x7FFFFFFFFFFFFFFFL; + CopyBytes(&i, &a.raw); + return a; +} + +// ------------------------------ Min/Max + +// may be unavailable, so implement our own. +namespace detail { + +static inline float Abs(float f) { + uint32_t i; + CopyBytes<4>(&f, &i); + i &= 0x7FFFFFFFu; + CopyBytes<4>(&i, &f); + return f; +} +static inline double Abs(double f) { + uint64_t i; + CopyBytes<8>(&f, &i); + i &= 0x7FFFFFFFFFFFFFFFull; + CopyBytes<8>(&i, &f); + return f; +} + +static inline bool SignBit(float f) { + uint32_t i; + CopyBytes<4>(&f, &i); + return (i >> 31) != 0; +} +static inline bool SignBit(double f) { + uint64_t i; + CopyBytes<8>(&f, &i); + return (i >> 63) != 0; +} + +} // namespace detail + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + if (isnan(a.raw)) return b; + if (isnan(b.raw)) return a; + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + if (isnan(a.raw)) return b; + if (isnan(b.raw)) return a; + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +// ------------------------------ Floating-point negate + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Xor(v, SignBit(Sisd())); +} + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Zero(Sisd()) - v; +} + +// ------------------------------ mul/div + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(double{a.raw} * b.raw)); +} + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(static_cast(a.raw) * + static_cast(b.raw))); +} + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(static_cast(a.raw) * + static_cast(b.raw))); +} + +template +HWY_API Vec1 operator/(const Vec1 a, const Vec1 b) { + return Vec1(a.raw / b.raw); +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((a.raw * b.raw) >> 16)); +} +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + // Cast to uint32_t first to prevent overflow. 
Otherwise the result of + // uint16_t * uint16_t is in "int" which may overflow. In practice the result + // is the same but this way it is also defined. + return Vec1(static_cast( + (static_cast(a.raw) * static_cast(b.raw)) >> 16)); +} + +HWY_API Vec1 MulFixedPoint15(Vec1 a, Vec1 b) { + return Vec1(static_cast((2 * a.raw * b.raw + 32768) >> 16)); +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +HWY_API Vec1 MulEven(const Vec1 a, const Vec1 b) { + const int64_t a64 = a.raw; + return Vec1(a64 * b.raw); +} +HWY_API Vec1 MulEven(const Vec1 a, const Vec1 b) { + const uint64_t a64 = a.raw; + return Vec1(a64 * b.raw); +} + +// Approximate reciprocal +HWY_API Vec1 ApproximateReciprocal(const Vec1 v) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The return value is arbitrary. + if (v.raw == 0.0f) return Vec1(0.0f); + return Vec1(1.0f / v.raw); +} + +// Absolute value of difference. +HWY_API Vec1 AbsDiff(const Vec1 a, const Vec1 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec1 MulAdd(const Vec1 mul, const Vec1 x, const Vec1 add) { + return mul * x + add; +} + +template +HWY_API Vec1 NegMulAdd(const Vec1 mul, const Vec1 x, + const Vec1 add) { + return add - mul * x; +} + +template +HWY_API Vec1 MulSub(const Vec1 mul, const Vec1 x, const Vec1 sub) { + return mul * x - sub; +} + +template +HWY_API Vec1 NegMulSub(const Vec1 mul, const Vec1 x, + const Vec1 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Approximate reciprocal square root +HWY_API Vec1 ApproximateReciprocalSqrt(const Vec1 v) { + float f = v.raw; + const float half = f * 0.5f; + uint32_t bits; + CopySameSize(&f, &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &f); + // One Newton-Raphson iteration + return Vec1(f * (1.5f - (half * f * f))); +} + +// Square root +HWY_API Vec1 Sqrt(const Vec1 v) { +#if HWY_COMPILER_GCC && defined(HWY_NO_LIBCXX) + return Vec1(__builtin_sqrt(v.raw)); +#else + return Vec1(sqrtf(v.raw)); +#endif +} +HWY_API Vec1 Sqrt(const Vec1 v) { +#if HWY_COMPILER_GCC && defined(HWY_NO_LIBCXX) + return Vec1(__builtin_sqrt(v.raw)); +#else + return Vec1(sqrt(v.raw)); +#endif +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec1 Round(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw < MantissaEnd())) { // Huge or NaN + return v; + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw + bias); + if (rounded == 0) return CopySignToAbs(Vec1(0), v); + // Round to even + if ((rounded & 1) && detail::Abs(static_cast(rounded) - v.raw) == T(0.5)) { + return Vec1(static_cast(rounded - (v.raw < T(0) ? -1 : 1))); + } + return Vec1(static_cast(rounded)); +} + +// Round-to-nearest even. +HWY_API Vec1 NearestInt(const Vec1 v) { + using T = float; + using TI = int32_t; + + const T abs = Abs(v).raw; + const bool is_sign = detail::SignBit(v.raw); + + if (!(abs < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs <= static_cast(LimitsMax()))) { + return Vec1(is_sign ? LimitsMin() : LimitsMax()); + } + return Vec1(static_cast(v.raw)); + } + const T bias = v.raw < T(0.0) ? 
T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw + bias); + if (rounded == 0) return Vec1(0); + // Round to even + if ((rounded & 1) && detail::Abs(static_cast(rounded) - v.raw) == T(0.5)) { + return Vec1(rounded - (is_sign ? -1 : 1)); + } + return Vec1(rounded); +} + +template +HWY_API Vec1 Trunc(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw <= MantissaEnd())) { // Huge or NaN + return v; + } + const TI truncated = static_cast(v.raw); + if (truncated == 0) return CopySignToAbs(Vec1(0), v); + return Vec1(static_cast(truncated)); +} + +template +V Ceiling(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool positive = f > Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => 0 or 1. + if (exponent < 0) return positive ? V(1) : V(-0.0); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +template +V Floor(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool negative = f < Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => -1 or 0. + if (exponent < 0) return V(negative ? 
Float(-1.0) : Float(0.0)); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +// Toward +infinity, aka ceiling +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} + +// Toward -infinity, aka floor +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} + +// ================================================== COMPARE + +template +HWY_API Mask1 operator==(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw == b.raw); +} + +template +HWY_API Mask1 operator!=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw != b.raw); +} + +template +HWY_API Mask1 TestBit(const Vec1 v, const Vec1 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask1 operator<(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw < b.raw); +} +template +HWY_API Mask1 operator>(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw > b.raw); +} + +template +HWY_API Mask1 operator<=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw <= b.raw); +} +template +HWY_API Mask1 operator>=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw >= b.raw); +} + +// ------------------------------ Floating-point classification (==) + +template +HWY_API Mask1 IsNaN(const Vec1 v) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + MakeUnsigned bits; + CopySameSize(&v, &bits); + bits += bits; + bits >>= 1; // clear sign bit + // NaN if all exponent bits are set and the mantissa is not zero. + return Mask1::FromBool(bits > ExponentMask()); +} + +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFF000000u)); +} +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFFE0000000000000ull)); +} + +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1::FromBool((vu.raw << 1) < 0xFF000000u); +} +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1::FromBool((vu.raw << 1) < 0xFFE0000000000000ull); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec1 Load(Sisd /* tag */, const T* HWY_RESTRICT aligned) { + T t; + CopySameSize(aligned, &t); + return Vec1(t); +} + +template +HWY_API Vec1 MaskedLoad(Mask1 m, Sisd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template +HWY_API Vec1 LoadU(Sisd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. 
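+// For example, with Sisd<uint8_t> this is equivalent to Load(d, aligned);
+// there is no 128-bit block to broadcast from a single lane.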
+template +HWY_API Vec1 LoadDup128(Sisd d, const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +// ------------------------------ Store + +template +HWY_API void Store(const Vec1 v, Sisd /* tag */, + T* HWY_RESTRICT aligned) { + CopySameSize(&v.raw, aligned); +} + +template +HWY_API void StoreU(const Vec1 v, Sisd d, T* HWY_RESTRICT p) { + return Store(v, d, p); +} + +template +HWY_API void BlendedStore(const Vec1 v, Mask1 m, Sisd d, + T* HWY_RESTRICT p) { + if (!m.bits) return; + StoreU(v, d, p); +} + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining StoreInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template +HWY_API void LoadInterleaved2(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +template +HWY_API void LoadInterleaved3(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1, Vec1& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +template +HWY_API void LoadInterleaved4(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1, Vec1& v2, + Vec1& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template +HWY_API void StoreInterleaved2(const Vec1 v0, const Vec1 v1, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); +} + +template +HWY_API void StoreInterleaved3(const Vec1 v0, const Vec1 v1, + const Vec1 v2, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +template +HWY_API void StoreInterleaved4(const Vec1 v0, const Vec1 v1, + const Vec1 v2, const Vec1 v3, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); + StoreU(v3, d, unaligned + 3); +} + +// ------------------------------ Stream + +template +HWY_API void Stream(const Vec1 v, Sisd d, T* HWY_RESTRICT aligned) { + return Store(v, d, aligned); +} + +// ------------------------------ Scatter + +template +HWY_API void ScatterOffset(Vec1 v, Sisd d, T* base, + const Vec1 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + uint8_t* const base8 = reinterpret_cast(base) + offset.raw; + return Store(v, d, reinterpret_cast(base8)); +} + +template +HWY_API void ScatterIndex(Vec1 v, Sisd d, T* HWY_RESTRICT base, + const Vec1 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Store(v, d, base + index.raw); +} + +// ------------------------------ Gather + +template +HWY_API Vec1 GatherOffset(Sisd d, const T* base, + const Vec1 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + const intptr_t addr = + reinterpret_cast(base) + static_cast(offset.raw); + return Load(d, reinterpret_cast(addr)); +} + +template +HWY_API Vec1 GatherIndex(Sisd d, const T* HWY_RESTRICT base, + const Vec1 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Load(d, base + index.raw); +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input 
and integer output truncate +// (rounding toward zero). + +template +HWY_API Vec1 PromoteTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); + // For bits Y > X, floatX->floatY and intX->intY are always representable. + return Vec1(static_cast(from.raw)); +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, +// so we overload for FromT=double and ToT={float,int32_t}. +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + // Prevent ubsan errors when converting float to narrower integer/float + if (IsInf(from).bits || + Abs(from).raw > static_cast(HighestValue())) { + return Vec1(detail::SignBit(from.raw) ? LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + if (IsInf(from).bits || + Abs(from).raw > static_cast(HighestValue())) { + return Vec1(detail::SignBit(from.raw) ? LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} + +template +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + static_assert(!IsFloat(), "FromT=double are handled above"); + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw = HWY_MIN(HWY_MAX(LimitsMin(), from.raw), LimitsMax()); + return Vec1(static_cast(from.raw)); +} + +HWY_API Vec1 PromoteTo(Sisd /* tag */, const Vec1 v) { + uint16_t bits16; + CopySameSize(&v.raw, &bits16); + const uint32_t sign = static_cast(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + return Vec1(sign ? -subnormal : subnormal); + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + float out; + CopySameSize(&bits32, &out); + return Vec1(out); +} + +HWY_API Vec1 PromoteTo(Sisd d, const Vec1 v) { + return Set(d, F32FromBF16(v.raw)); +} + +HWY_API Vec1 DemoteTo(Sisd /* tag */, + const Vec1 v) { + uint32_t bits32; + CopySameSize(&v.raw, &bits32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = HWY_MIN(static_cast(biased_exp32) - 127, 15); + + // Tiny or zero => zero. 
+ Vec1 out; + if (exp < -24) { + const uint16_t zero = 0; + CopySameSize(&zero, &out.raw); + return out; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (exp < -14) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast(-14 - exp); + HWY_DASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = static_cast((1u << (10 - sub_exp)) + + (mantissa32 >> (13 + sub_exp))); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast(exp + 15); + HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + HWY_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast(bits16); // big-endian safe + CopySameSize(&narrowed, &out.raw); + return out; +} + +HWY_API Vec1 DemoteTo(Sisd d, const Vec1 v) { + return Set(d, BF16FromF32(v.raw)); +} + +template +HWY_API Vec1 ConvertTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // float## -> int##: return closest representable value. We cannot exactly + // represent LimitsMax in FromT, so use double. + const double f = static_cast(from.raw); + if (IsInf(from).bits || + Abs(Vec1(f)).raw > static_cast(LimitsMax())) { + return Vec1(detail::SignBit(from.raw) ? LimitsMin() + : LimitsMax()); + } + return Vec1(static_cast(from.raw)); +} + +template +HWY_API Vec1 ConvertTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // int## -> float##: no check needed + return Vec1(static_cast(from.raw)); +} + +HWY_API Vec1 U8FromU32(const Vec1 v) { + return DemoteTo(Sisd(), v); +} + +// ------------------------------ Truncations + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFFFFFFu)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +// ================================================== COMBINE +// UpperHalf, ZeroExtendVector, Combine, Concat* are unsupported. + +template +HWY_API Vec1 LowerHalf(Vec1 v) { + return v; +} + +template +HWY_API Vec1 LowerHalf(Sisd /* tag */, Vec1 v) { + return v; +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(const Vec1 v) { + return v.raw; +} + +template +HWY_API T ExtractLane(const Vec1 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return v.raw; +} + +template +HWY_API Vec1 InsertLane(Vec1 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + v.raw = t; + return v; +} + +template +HWY_API Vec1 DupEven(Vec1 v) { + return v; +} +// DupOdd is unsupported. + +template +HWY_API Vec1 OddEven(Vec1 /* odd */, Vec1 even) { + return even; +} + +template +HWY_API Vec1 OddEvenBlocks(Vec1 /* odd */, Vec1 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec1 SwapAdjacentBlocks(Vec1 v) { + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
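+// On this target the only valid index is 0, so the lookup is the identity;
+// informal sketch for Vec1<float>:
+//   const int32_t idx[1] = {0};
+//   v = TableLookupLanes(v, SetTableIndices(d, idx));  // v is unchanged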
+template +struct Indices1 { + MakeSigned raw; +}; + +template +HWY_API Indices1 IndicesFromVec(Sisd, Vec1 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + HWY_DASSERT(vec.raw == 0); + return Indices1{vec.raw}; +} + +template +HWY_API Indices1 SetTableIndices(Sisd d, const TI* idx) { + return IndicesFromVec(d, LoadU(Sisd(), idx)); +} + +template +HWY_API Vec1 TableLookupLanes(const Vec1 v, const Indices1 /* idx */) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec1 ReverseBlocks(Sisd /* tag */, const Vec1 v) { + return v; +} + +// ------------------------------ Reverse + +template +HWY_API Vec1 Reverse(Sisd /* tag */, const Vec1 v) { + return v; +} + +// Must not be called: +template +HWY_API Vec1 Reverse2(Sisd /* tag */, const Vec1 v) { + return v; +} + +template +HWY_API Vec1 Reverse4(Sisd /* tag */, const Vec1 v) { + return v; +} + +template +HWY_API Vec1 Reverse8(Sisd /* tag */, const Vec1 v) { + return v; +} + +// ================================================== BLOCKWISE +// Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle* are unsupported. + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec1 Broadcast(const Vec1 v) { + static_assert(kLane == 0, "Scalar only has one lane"); + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec1 TableLookupBytes(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +template +HWY_API Vec1 TableLookupBytesOr0(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = idx_bytes[i] & 0x80 ? 0 : in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +// ------------------------------ ZipLower + +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((uint32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((uint32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((uint64_t{b.raw} << 32) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((int32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1((int32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1((int64_t{b.raw} << 32) + a.raw); +} + +template , class VW = Vec1> +HWY_API VW ZipLower(Sisd /* tag */, Vec1 a, Vec1 b) { + return VW(static_cast((TW{b.raw} << (sizeof(T) * 8)) + a.raw)); +} + +// ================================================== MASK + +template +HWY_API bool AllFalse(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0; +} + +template +HWY_API bool AllTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits != 0; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
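+// Only bit 0 of bits[0] is consulted here; informal sketch:
+//   const uint8_t bits[8] = {1};  // lane 0 active
+//   const Mask1<float> m = LoadMaskBits(Sisd<float>(), bits);  // true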
+template +HWY_API Mask1 LoadMaskBits(Sisd /* tag */, + const uint8_t* HWY_RESTRICT bits) { + return Mask1::FromBool((bits[0] & 1) != 0); +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(Sisd d, const Mask1 mask, uint8_t* bits) { + *bits = AllTrue(d, mask); + return 1; +} + +template +HWY_API size_t CountTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0 ? 0 : 1; +} + +template +HWY_API intptr_t FindFirstTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0 ? -1 : 0; +} + +template +HWY_API size_t FindKnownFirstTrue(Sisd /* tag */, const Mask1 /* m */) { + return 0; // There is only one lane and we know it is true. +} + +// ------------------------------ Compress, CompressBits + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +template +HWY_API Vec1 Compress(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +template +HWY_API Vec1 CompressNot(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec1 v, const Mask1 mask, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec1 v, const Mask1 mask, Sisd d, + T* HWY_RESTRICT unaligned) { + if (!mask.bits) return 0; + StoreU(v, d, unaligned); + return 1; +} + +// ------------------------------ CompressBits +template +HWY_API Vec1 CompressBits(Vec1 v, const uint8_t* HWY_RESTRICT /*bits*/) { + return v; +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(Vec1 v, const uint8_t* HWY_RESTRICT bits, + Sisd d, T* HWY_RESTRICT unaligned) { + const Mask1 mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +HWY_API Vec1 ReorderWidenMulAccumulate(Sisd /* tag */, + Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return MulAdd(Vec1(F32FromBF16(a.raw)), + Vec1(F32FromBF16(b.raw)), sum0); +} + +HWY_API Vec1 ReorderWidenMulAccumulate(Sisd /* tag */, + Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return Vec1(a.raw * b.raw + sum0.raw); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec1 RearrangeToOddPlusEven(const Vec1 sum0, + Vec1 /* sum1 */) { + return sum0; // invariant already holds +} + +// ================================================== REDUCTIONS + +// Sum of all lanes, i.e. the only one. 
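+// Note the reductions below still return a (single-lane) vector; extract the
+// scalar with GetLane, e.g. const float total = GetLane(SumOfLanes(d, v));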
+template +HWY_API Vec1 SumOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} +template +HWY_API Vec1 MinOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} +template +HWY_API Vec1 MaxOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/set_macros-inl.h b/third_party/highway/hwy/ops/set_macros-inl.h new file mode 100644 index 0000000000..051dbb3348 --- /dev/null +++ b/third_party/highway/hwy/ops/set_macros-inl.h @@ -0,0 +1,444 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Sets macros based on HWY_TARGET. + +// This include guard is toggled by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. +#if defined(HWY_SET_MACROS_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif + +#endif // HWY_SET_MACROS_PER_TARGET + +#include "hwy/detect_targets.h" + +#undef HWY_NAMESPACE +#undef HWY_ALIGN +#undef HWY_MAX_BYTES +#undef HWY_LANES + +#undef HWY_HAVE_SCALABLE +#undef HWY_HAVE_INTEGER64 +#undef HWY_HAVE_FLOAT16 +#undef HWY_HAVE_FLOAT64 +#undef HWY_MEM_OPS_MIGHT_FAULT +#undef HWY_NATIVE_FMA +#undef HWY_CAP_GE256 +#undef HWY_CAP_GE512 + +#undef HWY_TARGET_STR + +#if defined(HWY_DISABLE_PCLMUL_AES) +#define HWY_TARGET_STR_PCLMUL_AES "" +#else +#define HWY_TARGET_STR_PCLMUL_AES ",pclmul,aes" +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) +#define HWY_TARGET_STR_BMI2_FMA "" +#else +#define HWY_TARGET_STR_BMI2_FMA ",bmi,bmi2,fma" +#endif + +#if defined(HWY_DISABLE_F16C) +#define HWY_TARGET_STR_F16C "" +#else +#define HWY_TARGET_STR_F16C ",f16c" +#endif + +#define HWY_TARGET_STR_SSSE3 "sse2,ssse3" + +#define HWY_TARGET_STR_SSE4 \ + HWY_TARGET_STR_SSSE3 ",sse4.1,sse4.2" HWY_TARGET_STR_PCLMUL_AES +// Include previous targets, which are the half-vectors of the next target. +#define HWY_TARGET_STR_AVX2 \ + HWY_TARGET_STR_SSE4 ",avx,avx2" HWY_TARGET_STR_BMI2_FMA HWY_TARGET_STR_F16C +#define HWY_TARGET_STR_AVX3 \ + HWY_TARGET_STR_AVX2 ",avx512f,avx512vl,avx512dq,avx512bw" + +// Before include guard so we redefine HWY_TARGET_STR on each include, +// governed by the current HWY_TARGET. 
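With none of the HWY_DISABLE_* opt-outs defined, these concatenations reduce to plain comma-separated attribute lists, which are later handed to HWY_PUSH_ATTRIBUTES and __attribute__((target(...))) at the end of this file. For example:

// HWY_TARGET_STR_SSE4 == "sse2,ssse3,sse4.1,sse4.2,pclmul,aes"
// HWY_TARGET_STR_AVX2 == "sse2,ssse3,sse4.1,sse4.2,pclmul,aes,avx,avx2,bmi,bmi2,fma,f16c"
// HWY_TARGET_STR_AVX3 == the AVX2 list plus ",avx512f,avx512vl,avx512dq,avx512bw"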
+ +//----------------------------------------------------------------------------- +// SSSE3 +#if HWY_TARGET == HWY_SSSE3 + +#define HWY_NAMESPACE N_SSSE3 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSSE3 + +//----------------------------------------------------------------------------- +// SSE4 +#elif HWY_TARGET == HWY_SSE4 + +#define HWY_NAMESPACE N_SSE4 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSE4 + +//----------------------------------------------------------------------------- +// AVX2 +#elif HWY_TARGET == HWY_AVX2 + +#define HWY_NAMESPACE N_AVX2 +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#ifdef HWY_DISABLE_BMI2_FMA +#define HWY_NATIVE_FMA 0 +#else +#define HWY_NATIVE_FMA 1 +#endif + +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_AVX2 + +//----------------------------------------------------------------------------- +// AVX3[_DL] +#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL + +#define HWY_ALIGN alignas(64) +#define HWY_MAX_BYTES 64 +#define HWY_LANES(T) (64 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 1 + +#if HWY_TARGET == HWY_AVX3 + +#define HWY_NAMESPACE N_AVX3 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3 + +#elif HWY_TARGET == HWY_AVX3_DL + +#define HWY_NAMESPACE N_AVX3_DL +#define HWY_TARGET_STR \ + HWY_TARGET_STR_AVX3 \ + ",vpclmulqdq,avx512vbmi,avx512vbmi2,vaes,avxvnni,avx512bitalg," \ + "avx512vpopcntdq" + +#else +#error "Logic error" +#endif // HWY_TARGET == HWY_AVX3_DL + +//----------------------------------------------------------------------------- +// PPC8 +#elif HWY_TARGET == HWY_PPC8 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_PPC8 + +#define HWY_TARGET_STR "altivec,vsx" + +//----------------------------------------------------------------------------- +// NEON +#elif HWY_TARGET == HWY_NEON + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 + +#if HWY_ARCH_ARM_A64 +#define HWY_HAVE_FLOAT64 1 +#else +#define HWY_HAVE_FLOAT64 0 +#endif + +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +#define HWY_NATIVE_FMA 1 
+#else +#define HWY_NATIVE_FMA 0 +#endif + +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_NEON + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_ARCH_ARM_V7 +#define HWY_TARGET_STR "+neon-vfpv4" +#else +#define HWY_TARGET_STR "+crypto" +#endif // HWY_ARCH_ARM_V7 +#else +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// SVE[2] +#elif HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE || \ + HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + +// SVE only requires lane alignment, not natural alignment of the entire vector. +#define HWY_ALIGN alignas(8) + +// Value ensures MaxLanes() is the tightest possible upper bound to reduce +// overallocation. +#define HWY_LANES(T) ((HWY_MAX_BYTES) / sizeof(T)) + +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_SVE2 +#define HWY_NAMESPACE N_SVE2 +#define HWY_MAX_BYTES 256 +#elif HWY_TARGET == HWY_SVE_256 +#define HWY_NAMESPACE N_SVE_256 +#define HWY_MAX_BYTES 32 +#elif HWY_TARGET == HWY_SVE2_128 +#define HWY_NAMESPACE N_SVE2_128 +#define HWY_MAX_BYTES 16 +#else +#define HWY_NAMESPACE N_SVE +#define HWY_MAX_BYTES 256 +#endif + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +#define HWY_TARGET_STR "+sve2-aes" +#else +#define HWY_TARGET_STR "+sve" +#endif +#else +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// WASM +#elif HWY_TARGET == HWY_WASM + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 0 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// WASM_EMU256 +#elif HWY_TARGET == HWY_WASM_EMU256 + +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 0 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM_EMU256 + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// RVV +#elif HWY_TARGET == HWY_RVV + +// RVV only requires lane alignment, not natural alignment of the entire vector, +// and the compiler already aligns builtin types, so nothing to do here. +#define HWY_ALIGN + +// The spec requires VLEN <= 2^16 bits, so the limit is 2^16 bytes (LMUL=8). +#define HWY_MAX_BYTES 65536 + +// = HWY_MAX_BYTES divided by max LMUL=8 because MaxLanes includes the actual +// LMUL. This is the tightest possible upper bound. 
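In numbers: 65536 / 8 = 8192 bytes per single register at the maximum VLEN, hence the definition that follows yields HWY_LANES(uint8_t) == 8192 and HWY_LANES(uint64_t) == 1024. A register group such as ScalableTag<uint8_t, 3> (LMUL = 2^3 = 8) scales this back up via ScaleByPower in shared-inl.h: 8192 << 3 == 65536 lanes, matching HWY_MAX_BYTES.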
+#define HWY_LANES(T) (8192 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if defined(__riscv_zvfh) +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif + +#define HWY_NAMESPACE N_RVV + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. +// (rv64gcv is not a valid target) + +//----------------------------------------------------------------------------- +// EMU128 +#elif HWY_TARGET == HWY_EMU128 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_EMU128 + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +//----------------------------------------------------------------------------- +// SCALAR +#elif HWY_TARGET == HWY_SCALAR + +#define HWY_ALIGN +#define HWY_MAX_BYTES 8 +#define HWY_LANES(T) 1 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_SCALAR + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +// Override this to 1 in asan/msan builds, which will still fault. +#if HWY_IS_ASAN || HWY_IS_MSAN +#undef HWY_MEM_OPS_MIGHT_FAULT +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#endif + +// Clang <9 requires this be invoked at file scope, before any namespace. +#undef HWY_BEFORE_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_BEFORE_NAMESPACE() \ + HWY_PUSH_ATTRIBUTES(HWY_TARGET_STR) \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_BEFORE_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +// Clang <9 requires any namespaces be closed before this macro. +#undef HWY_AFTER_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_AFTER_NAMESPACE() \ + HWY_POP_ATTRIBUTES \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_AFTER_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +#undef HWY_ATTR +#if defined(HWY_TARGET_STR) && HWY_HAS_ATTRIBUTE(target) +#define HWY_ATTR __attribute__((target(HWY_TARGET_STR))) +#else +#define HWY_ATTR +#endif diff --git a/third_party/highway/hwy/ops/shared-inl.h b/third_party/highway/hwy/ops/shared-inl.h new file mode 100644 index 0000000000..02246bfa4f --- /dev/null +++ b/third_party/highway/hwy/ops/shared-inl.h @@ -0,0 +1,332 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target definitions shared by ops/*.h and user code. + +// We are covered by the highway.h include guard, but generic_ops-inl.h +// includes this again #if HWY_IDE. +#if defined(HIGHWAY_HWY_OPS_SHARED_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_OPS_SHARED_TOGGLE +#undef HIGHWAY_HWY_OPS_SHARED_TOGGLE +#else +#define HIGHWAY_HWY_OPS_SHARED_TOGGLE +#endif + +#ifndef HWY_NO_LIBCXX +#include +#endif + +#include "hwy/base.h" + +// Separate header because foreach_target.h re-enables its include guard. +#include "hwy/ops/set_macros-inl.h" + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Highway operations are implemented as overloaded functions selected using an +// internal-only tag type D := Simd. T is the lane type. kPow2 is a +// shift count applied to scalable vectors. Instead of referring to Simd<> +// directly, users create D via aliases ScalableTag() (defaults to a +// full vector, or fractions/groups if the argument is negative/positive), +// CappedTag or FixedTag. The actual number of lanes is +// Lanes(D()), a power of two. For scalable vectors, N is either HWY_LANES or a +// cap. For constexpr-size vectors, N is the actual number of lanes. This +// ensures Half> is the same type as Full256, as required by x86. +template +struct Simd { + constexpr Simd() = default; + using T = Lane; + static_assert((N & (N - 1)) == 0 && N != 0, "N must be a power of two"); + + // Only for use by MaxLanes, required by MSVC. Cannot be enum because GCC + // warns when using enums and non-enums in the same expression. Cannot be + // static constexpr function (another MSVC limitation). + static constexpr size_t kPrivateN = N; + static constexpr int kPrivatePow2 = kPow2; + + template + static constexpr size_t NewN() { + // Round up to correctly handle scalars with N=1. + return (N * sizeof(T) + sizeof(NewT) - 1) / sizeof(NewT); + } + +#if HWY_HAVE_SCALABLE + template + static constexpr int Pow2Ratio() { + return (sizeof(NewT) > sizeof(T)) + ? static_cast(CeilLog2(sizeof(NewT) / sizeof(T))) + : -static_cast(CeilLog2(sizeof(T) / sizeof(NewT))); + } +#endif + + // Widening/narrowing ops change the number of lanes and/or their type. + // To initialize such vectors, we need the corresponding tag types: + +// PromoteTo/DemoteTo() with another lane type, but same number of lanes. +#if HWY_HAVE_SCALABLE + template + using Rebind = Simd()>; +#else + template + using Rebind = Simd; +#endif + + // Change lane type while keeping the same vector size, e.g. for MulEven. + template + using Repartition = Simd(), kPow2>; + +// Half the lanes while keeping the same lane type, e.g. for LowerHalf. +// Round up to correctly handle scalars with N=1. +#if HWY_HAVE_SCALABLE + // Reducing the cap (N) is required for SVE - if N is the limiter for f32xN, + // then we expect Half> to have N/2 lanes (rounded up). + using Half = Simd; +#else + using Half = Simd; +#endif + +// Twice the lanes while keeping the same lane type, e.g. for Combine. 
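These member aliases are usually consumed through the Half/Twice/Rebind/Repartition helpers defined after this struct. A hypothetical sketch (for targets with at least two lanes) that concatenates two half-width vectors, assuming the usual `hn` namespace alias:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

HWY_ATTR void ConcatHalves(const float* HWY_RESTRICT lo,
                           const float* HWY_RESTRICT hi,
                           float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> d;
  const hn::Half<decltype(d)> dh;  // half the lanes, same lane type
  hn::StoreU(hn::Combine(d, hn::LoadU(dh, hi), hn::LoadU(dh, lo)), d, out);
}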
+#if HWY_HAVE_SCALABLE + using Twice = Simd; +#else + using Twice = Simd; +#endif +}; + +namespace detail { + +template +constexpr bool IsFull(Simd /* d */) { + return N == HWY_LANES(T) && kPow2 == 0; +} + +// Returns the number of lanes (possibly zero) after applying a shift: +// - 0: no change; +// - [1,3]: a group of 2,4,8 [fractional] vectors; +// - [-3,-1]: a fraction of a vector from 1/8 to 1/2. +constexpr size_t ScaleByPower(size_t N, int pow2) { +#if HWY_TARGET == HWY_RVV + return pow2 >= 0 ? (N << pow2) : (N >> (-pow2)); +#else + return pow2 >= 0 ? N : (N >> (-pow2)); +#endif +} + +// Struct wrappers enable validation of arguments via static_assert. +template +struct ScalableTagChecker { + static_assert(-3 <= kPow2 && kPow2 <= 3, "Fraction must be 1/8 to 8"); +#if HWY_TARGET == HWY_RVV + // Only RVV supports register groups. + using type = Simd; +#elif HWY_HAVE_SCALABLE + // For SVE[2], only allow full or fractions. + using type = Simd; +#elif HWY_TARGET == HWY_SCALAR + using type = Simd; +#else + // Only allow full or fractions. + using type = Simd; +#endif +}; + +template +struct CappedTagChecker { + static_assert(kLimit != 0, "Does not make sense to have zero lanes"); + // Safely handle non-power-of-two inputs by rounding down, which is allowed by + // CappedTag. Otherwise, Simd would static_assert. + static constexpr size_t kLimitPow2 = size_t{1} << hwy::FloorLog2(kLimit); + using type = Simd; +}; + +template +struct FixedTagChecker { + static_assert(kNumLanes != 0, "Does not make sense to have zero lanes"); + static_assert(kNumLanes <= HWY_LANES(T), "Too many lanes"); + using type = Simd; +}; + +} // namespace detail + +// Alias for a tag describing a full vector (kPow2 == 0: the most common usage, +// e.g. 1D loops where the application does not care about the vector size) or a +// fraction/multiple of one. Multiples are the same as full vectors for all +// targets except RVV. Fractions (kPow2 < 0) are useful as the argument/return +// value of type promotion and demotion. +template +using ScalableTag = typename detail::ScalableTagChecker::type; + +// Alias for a tag describing a vector with *up to* kLimit active lanes, even on +// targets with scalable vectors and HWY_SCALAR. The runtime lane count +// `Lanes(tag)` may be less than kLimit, and is 1 on HWY_SCALAR. This alias is +// typically used for 1D loops with a relatively low application-defined upper +// bound, e.g. for 8x8 DCTs. However, it is better if data structures are +// designed to be vector-length-agnostic (e.g. a hybrid SoA where there are +// chunks of `M >= MaxLanes(d)` DC components followed by M AC1, .., and M AC63; +// this would enable vector-length-agnostic loops using ScalableTag). +template +using CappedTag = typename detail::CappedTagChecker::type; + +// Alias for a tag describing a vector with *exactly* kNumLanes active lanes, +// even on targets with scalable vectors. Requires `kNumLanes` to be a power of +// two not exceeding `HWY_LANES(T)`. +// +// NOTE: if the application does not need to support HWY_SCALAR (+), use this +// instead of CappedTag to emphasize that there will be exactly kNumLanes lanes. +// This is useful for data structures that rely on exactly 128-bit SIMD, but +// these are discouraged because they cannot benefit from wider vectors. +// Instead, applications would ideally define a larger problem size and loop +// over it with the (unknown size) vectors from ScalableTag. +// +// + e.g. 
if the baseline is known to support SIMD, or the application requires +// ops such as TableLookupBytes not supported by HWY_SCALAR. +template +using FixedTag = typename detail::FixedTagChecker::type; + +template +using TFromD = typename D::T; + +// Tag for the same number of lanes as D, but with the LaneType T. +template +using Rebind = typename D::template Rebind; + +template +using RebindToSigned = Rebind>, D>; +template +using RebindToUnsigned = Rebind>, D>; +template +using RebindToFloat = Rebind>, D>; + +// Tag for the same total size as D, but with the LaneType T. +template +using Repartition = typename D::template Repartition; + +template +using RepartitionToWide = Repartition>, D>; +template +using RepartitionToNarrow = Repartition>, D>; + +// Tag for the same lane type as D, but half the lanes. +template +using Half = typename D::Half; + +// Tag for the same lane type as D, but twice the lanes. +template +using Twice = typename D::Twice; + +template +using Full16 = Simd; + +template +using Full32 = Simd; + +template +using Full64 = Simd; + +template +using Full128 = Simd; + +// Same as base.h macros but with a Simd argument instead of T. +#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(TFromD) +#define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(TFromD) +#define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(TFromD) +#define HWY_IF_NOT_FLOAT_D(D) HWY_IF_NOT_FLOAT(TFromD) +#define HWY_IF_LANE_SIZE_D(D, bytes) HWY_IF_LANE_SIZE(TFromD, bytes) +#define HWY_IF_NOT_LANE_SIZE_D(D, bytes) HWY_IF_NOT_LANE_SIZE(TFromD, bytes) +#define HWY_IF_LANE_SIZE_ONE_OF_D(D, bit_array) \ + HWY_IF_LANE_SIZE_ONE_OF(TFromD, bit_array) + +// MSVC workaround: use PrivateN directly instead of MaxLanes. +#define HWY_IF_LT128_D(D) \ + hwy::EnableIf) < 16>* = nullptr +#define HWY_IF_GE128_D(D) \ + hwy::EnableIf) >= 16>* = nullptr + +// Same, but with a vector argument. ops/*-inl.h define their own TFromV. +#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(TFromV) +#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(TFromV) +#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(TFromV) +#define HWY_IF_LANE_SIZE_V(V, bytes) HWY_IF_LANE_SIZE(TFromV, bytes) +#define HWY_IF_NOT_LANE_SIZE_V(V, bytes) HWY_IF_NOT_LANE_SIZE(TFromV, bytes) +#define HWY_IF_LANE_SIZE_ONE_OF_V(V, bit_array) \ + HWY_IF_LANE_SIZE_ONE_OF(TFromV, bit_array) + +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr int Pow2(D /* d */) { + return D::kPrivatePow2; +} + +// MSVC requires the explicit . +#define HWY_IF_POW2_GE(D, MIN) hwy::EnableIf(D()) >= (MIN)>* = nullptr + +#if HWY_HAVE_SCALABLE + +// Upper bound on the number of lanes. Intended for template arguments and +// reducing code size (e.g. for SSE4, we know at compile-time that vectors will +// not exceed 16 bytes). WARNING: this may be a loose bound, use Lanes() as the +// actual size for allocating storage. WARNING: MSVC might not be able to deduce +// arguments if this is used in EnableIf. See HWY_IF_LT128_D above. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return detail::ScaleByPower(HWY_MIN(D::kPrivateN, HWY_LANES(TFromD)), + D::kPrivatePow2); +} + +#else +// Workaround for MSVC 2017: T,N,kPow2 argument deduction fails, so returning N +// is not an option, nor does a member function work. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return D::kPrivateN; +} + +// (Potentially) non-constant actual size of the vector at runtime, subject to +// the limit imposed by the Simd. Useful for advancing loop counters. +// Targets with scalable vectors define this themselves. 
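Loop counters are typically advanced by the runtime value of Lanes(d); MaxLanes(d) is only an upper bound suitable for sizing storage. A sketch with a hypothetical helper (remainder handling omitted):

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

HWY_ATTR void ScaleInPlace(float* HWY_RESTRICT p, size_t n, float factor) {
  const hn::ScalableTag<float> d;
  const auto vfactor = hn::Set(d, factor);
  size_t i = 0;
  for (; i + hn::Lanes(d) <= n; i += hn::Lanes(d)) {
    hn::StoreU(hn::Mul(hn::LoadU(d, p + i), vfactor), d, p + i);
  }
  // A complete implementation would also handle the i < n remainder.
}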
+template +HWY_INLINE HWY_MAYBE_UNUSED size_t Lanes(Simd) { + return N; +} + +#endif // !HWY_HAVE_SCALABLE + +// NOTE: GCC generates incorrect code for vector arguments to non-inlined +// functions in two situations: +// - on Windows and GCC 10.3, passing by value crashes due to unaligned loads: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412. +// - on ARM64 and GCC 9.3.0 or 11.2.1, passing by value causes many (but not +// all) tests to fail. +// +// We therefore pass by const& only on GCC and (Windows or ARM64). This alias +// must be used for all vector/mask parameters of functions marked HWY_NOINLINE, +// and possibly also other functions that are not inlined. +#if HWY_COMPILER_GCC_ACTUAL && (HWY_OS_WIN || HWY_ARCH_ARM_A64) +template +using VecArg = const V&; +#else +template +using VecArg = V; +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_OPS_SHARED_TOGGLE diff --git a/third_party/highway/hwy/ops/wasm_128-inl.h b/third_party/highway/hwy/ops/wasm_128-inl.h new file mode 100644 index 0000000000..095fd4f1f0 --- /dev/null +++ b/third_party/highway/hwy/ops/wasm_128-inl.h @@ -0,0 +1,4591 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit WASM vectors and operations. +// External include guard in highway.h - see comment there. 
+ +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +#ifdef HWY_WASM_OLD_NAMES +#define wasm_i8x16_shuffle wasm_v8x16_shuffle +#define wasm_i16x8_shuffle wasm_v16x8_shuffle +#define wasm_i32x4_shuffle wasm_v32x4_shuffle +#define wasm_i64x2_shuffle wasm_v64x2_shuffle +#define wasm_u16x8_extend_low_u8x16 wasm_i16x8_widen_low_u8x16 +#define wasm_u32x4_extend_low_u16x8 wasm_i32x4_widen_low_u16x8 +#define wasm_i32x4_extend_low_i16x8 wasm_i32x4_widen_low_i16x8 +#define wasm_i16x8_extend_low_i8x16 wasm_i16x8_widen_low_i8x16 +#define wasm_u32x4_extend_high_u16x8 wasm_i32x4_widen_high_u16x8 +#define wasm_i32x4_extend_high_i16x8 wasm_i32x4_widen_high_i16x8 +#define wasm_i32x4_trunc_sat_f32x4 wasm_i32x4_trunc_saturate_f32x4 +#define wasm_u8x16_add_sat wasm_u8x16_add_saturate +#define wasm_u8x16_sub_sat wasm_u8x16_sub_saturate +#define wasm_u16x8_add_sat wasm_u16x8_add_saturate +#define wasm_u16x8_sub_sat wasm_u16x8_sub_saturate +#define wasm_i8x16_add_sat wasm_i8x16_add_saturate +#define wasm_i8x16_sub_sat wasm_i8x16_sub_saturate +#define wasm_i16x8_add_sat wasm_i16x8_add_saturate +#define wasm_i16x8_sub_sat wasm_i16x8_sub_saturate +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#if HWY_TARGET == HWY_WASM_EMU256 +template +using Full256 = Simd; +#endif + +namespace detail { + +template +struct Raw128 { + using type = __v128_u; +}; +template <> +struct Raw128 { + using type = __f32x4; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +// FF..FF or 0. +template +struct Mask128 { + typename detail::Raw128::type raw; +}; + +template +using DFromV = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __v128_u BitCastToInteger(__v128_u v) { return v; } +HWY_INLINE __v128_u BitCastToInteger(__f32x4 v) { + return static_cast<__v128_u>(v); +} +HWY_INLINE __v128_u BitCastToInteger(__f64x2 v) { + return static_cast<__v128_u>(v); +} + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. 
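The workaround is a class template specialized on the destination type: overloads cannot differ by return type alone, but a specialization's operator() can. A condensed, purely illustrative version of the idiom (the names are hypothetical, not part of the Highway API):

#include <cstring>

template <typename To>
struct CastFromBits;  // primary template intentionally undefined

template <>
struct CastFromBits<float> {
  float operator()(unsigned bits) const {
    float f;
    std::memcpy(&f, &bits, sizeof(f));  // bit-pattern reinterpretation
    return f;
  }
};

template <>
struct CastFromBits<unsigned> {
  unsigned operator()(unsigned bits) const { return bits; }
};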
+template +struct BitCastFromInteger128 { + HWY_INLINE __v128_u operator()(__v128_u v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __f32x4 operator()(__v128_u v) { return static_cast<__f32x4>(v); } +}; + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128{BitCastFromInteger128()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{wasm_i32x4_splat(0)}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{wasm_f32x4_splat(0.0f)}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +template +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { + return Vec128{wasm_i8x16_splat(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint16_t t) { + return Vec128{wasm_i16x8_splat(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint32_t t) { + return Vec128{wasm_i32x4_splat(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint64_t t) { + return Vec128{wasm_i64x2_splat(static_cast(t))}; +} + +template +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { + return Vec128{wasm_i8x16_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { + return Vec128{wasm_i16x8_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { + return Vec128{wasm_i32x4_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { + return Vec128{wasm_i64x2_splat(t)}; +} + +template +HWY_API Vec128 Set(Simd /* tag */, const float t) { + return Vec128{wasm_f32x4_splat(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec128 Undefined(Simd d) { + return Zero(d); +} + +HWY_DIAGNOSTICS(pop) + +// Returns a vector with lane i=[0, N) set to "first" + i. 
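Iota is handy for materializing per-lane indices, e.g. to compare against a bound as FirstN() does further below. A sketch with a hypothetical helper; out must hold at least Lanes(d) elements:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

HWY_ATTR void WriteIndices(int32_t* HWY_RESTRICT out) {
  const hn::ScalableTag<int32_t> d;
  hn::StoreU(hn::Iota(d, 0), d, out);  // out[i] = i for i < Lanes(d)
}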
+template +Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + } + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_add(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_sub(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_add_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_add_sat(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add_sat(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
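Saturation clamps rather than wraps: for uint8_t, 10 - 20 yields 0, and for int8_t, -100 - 100 yields -128. A one-lane check with a hypothetical helper:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

HWY_ATTR uint8_t SaturatedSubDemo() {
  const hn::ScalableTag<uint8_t> d;
  return hn::GetLane(
      hn::SaturatedSub(hn::Set(d, uint8_t{10}), hn::Set(d, uint8_t{20})));  // 0
}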
+ +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_sub_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_sub_sat(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub_sat(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_avgr(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_avgr(a.raw, b.raw)}; +} + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i8x16_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i16x8_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i32x4_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i64x2_abs(v.raw)}; +} + +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_f32x4_abs(v.raw)}; +} + +// ------------------------------ Shift lanes by constant #bits + +// Unsigned +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u32x4_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u64x2_shr(v.raw, kBits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i32x4_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i64x2_shr(v.raw, kBits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
+ const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec128 RotateRight(const Vec128 v) { + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +// After https://reviews.llvm.org/D108415 shift argument became unsigned. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unsigned +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u32x4_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u64x2_shr(v.raw, bits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shr(v.raw, bits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
+ const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, 0xFF >> bits); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ignore Wsign-conversion +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Minimum + +// Unsigned +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. + const uint64_t a0 = static_cast(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t min[2] = {HWY_MIN(a0, b0), HWY_MIN(a1, b1)}; + return Vec128{wasm_v128_load(min)}; +} + +// Signed +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + alignas(16) int64_t min[4]; + min[0] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + min[1] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(min)}; +} + +// Float +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Equivalent to a < b ? a : b (taking into account our swapped arg order, + // so that Min(NaN, x) is x to match x86). + return Vec128{wasm_f32x4_pmin(b.raw, a.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. 
+ const uint64_t a0 = static_cast(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t max[2] = {HWY_MAX(a0, b0), HWY_MAX(a1, b1)}; + return Vec128{wasm_v128_load(max)}; +} + +// Signed +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + alignas(16) int64_t max[2]; + max[0] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + max[1] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(max)}; +} + +// Float +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Equivalent to b < a ? a : b (taking into account our swapped arg order, + // so that Max(NaN, x) is x to match x86). + return Vec128{wasm_f32x4_pmax(b.raw, a.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u32x4_extmul_low_u16x8(a.raw, b.raw); + const auto h = wasm_u32x4_extmul_high_u16x8(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i32x4_extmul_low_i16x8(a.raw, b.raw); + const auto h = wasm_i32x4_extmul_high_i16x8(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_q15mulr_sat(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and returns the double-width result. 
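Worked example of the even-lane products: with u32 lanes a = {2, 3, 4, 5, ...} and b = {10, 11, 12, 13, ...}, the double-width results are {2*10, 4*12, ...} = {20, 48, ...}. A sketch with a hypothetical helper returning the first product:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

HWY_ATTR uint64_t FirstEvenProduct() {
  const hn::ScalableTag<uint32_t> d32;
  const auto a = hn::Iota(d32, 2);   // 2, 3, 4, 5, ...
  const auto b = hn::Iota(d32, 10);  // 10, 11, 12, 13, ...
  return hn::GetLane(hn::MulEven(a, b));  // 2 * 10 == 20
}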
+template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128{wasm_i64x2_mul(ae, be)}; +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128{wasm_i64x2_mul(ae, be)}; +} + +// ------------------------------ Negate + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i8x16_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i16x8_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i32x4_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i64x2_neg(v.raw)}; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{wasm_f32x4_mul(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_div(a.raw, b.raw)}; +} + +// Approximate reciprocal +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + const Vec128 one = Vec128{wasm_f32x4_splat(1.0f)}; + return one / v; +} + +// Absolute value of difference. +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return mul * x + add; +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return add - mul * x; +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return mul * x - sub; +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{wasm_f32x4_sqrt(v.raw)}; +} + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + // TODO(eustas): find cheaper a way to calculate this. 
+ const Vec128 one = Vec128{wasm_f32x4_splat(1.0f)}; + return one / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{wasm_f32x4_nearest(v.raw)}; +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{wasm_f32x4_trunc(v.raw)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{wasm_f32x4_ceil(v.raw)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{wasm_f32x4_floor(v.raw)}; +} + +// ------------------------------ Floating-point classification +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return v != v; +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
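A comparison therefore yields a mask, which is normally consumed by IfThenElse or by mask queries such as CountTrue. A sketch with a hypothetical helper that reads Lanes(d) values:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

HWY_ATTR size_t CountNegative(const float* HWY_RESTRICT p) {
  const hn::ScalableTag<float> d;
  const auto v = hn::LoadU(d, p);
  return hn::CountTrue(d, hn::Lt(v, hn::Zero(d)));
}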
+ +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128{m.raw}; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_eq(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// Unsigned +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ne(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 
operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d32; + const auto a32 = BitCast(d32, a); + const auto b32 = BitCast(d32, b); + // If the upper halves are not equal, this is the answer. + const auto m_gt = a32 > b32; + + // Otherwise, the lower half decides. + const auto m_eq = a32 == b32; + const auto lo_in_hi = wasm_i32x4_shuffle(m_gt.raw, m_gt.raw, 0, 0, 2, 2); + const auto lo_gt = And(m_eq, MaskFromVec(VFromD{lo_in_hi})); + + const auto gt = Or(lo_gt, m_gt); + // Copy result in upper 32 bits to lower 32 bits. + return Mask128{wasm_i32x4_shuffle(gt.raw, gt.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { + return operator>(b, a); +} + +// ------------------------------ Weak inequality + +// Float <= >= +template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_le(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ge(a.raw, b.raw)}; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask128 FirstN(const Simd d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec128 Not(Vec128 v) { + return Vec128{wasm_v128_not(v.raw)}; +} + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_and(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
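Note the argument order: the first operand is the one that is complemented. The wasm intrinsic used below, wasm_v128_andnot(a, b), computes a & ~b, hence the swapped operands in the implementation. A small usage sketch (hypothetical helper) that clears the low nibble of every byte:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

HWY_ATTR void ClearLowNibble(uint8_t* HWY_RESTRICT p) {
  const hn::ScalableTag<uint8_t> d;
  hn::StoreU(hn::AndNot(hn::Set(d, uint8_t{0x0F}), hn::LoadU(d, p)), d, p);
}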
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{wasm_v128_andnot(mask.raw, not_mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_or(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_xor(a.raw, b.raw)}; +} + +// ------------------------------ Xor3 + +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(DFromV()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(DFromV()), sign)); +} + +// ------------------------------ BroadcastSignBit (compare) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; + return VecFromMask(d, v < Zero(d)); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(Simd /* tag */, Mask128 v) { + return Vec128{v.raw}; +} + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + const DFromV d; + const auto zero = Zero(d); + return IfThenElse(Mask128{(v > zero).raw}, v, zero); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) + +// The x86 multiply-by-Pow2() trick will not work because WASM saturates +// float->int correctly to 2^31-1 (not 2^31). Because WASM's shifts take a +// scalar count operand, per-lane shift instructions would require extract_lane +// for each lane, and hoping that shuffle is correctly mapped to a native +// instruction. Using non-vector shifts would incur a store-load forwarding +// stall when loading the result vector. We instead test bits of the shift +// count to "predicate" a shift of the entire vector by a constant. + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. 
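+  // (16-bit lanes above used ShiftLeft<12> because their count has 4 valid
+  // bits; 32-bit counts have 5 valid bits, so bit 4 reaches the sign bit via
+  // ShiftLeft<27> below.)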
+ test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + alignas(16) T lanes[2]; + alignas(16) T bits_lanes[2]; + Store(v, d, lanes); + Store(bits, d, bits_lanes); + lanes[0] <<= bits_lanes[0]; + lanes[1] <<= bits_lanes[1]; + return Load(d, lanes); +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. 
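+  // (Same bit-testing scheme as operator<< above, but predicating ShiftRight
+  // by 16/8/4/2/1.)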
+ test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Full128 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{wasm_v128_load(aligned)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// Partial load. +template +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { + Vec128 v; + CopyBytes(p, &v); + return v; +} + +// LoadU == Load. +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// Partial store. +template +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { + CopyBytes(&v, p); +} + +HWY_API void Store(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT p) { + *p = wasm_f32x4_extract_lane(v.raw, 0); +} + +// StoreU == Store. +template +HWY_API void StoreU(Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. 
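+// (WASM SIMD has no non-temporal hint, so Stream below lowers to a regular
+// wasm_v128_store.)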
+ +template +HWY_API void Stream(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i8x16_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i16x8_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i32x4_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i64x2_extract_lane(v.raw, kLane)); +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + return wasm_f32x4_extract_lane(v.raw, kLane); +} + +} // namespace detail + +// One overload per vector length just in case *_extract_lane raise compile +// errors if their argument is out of bounds (even if that would never be +// reached at runtime). 
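+// When the index is a compile-time constant, __builtin_constant_p lets the
+// switch statements below pick the dedicated lane extractor; otherwise the
+// vector is stored to a stack array and indexed from memory.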
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ GetLane +template +HWY_API T GetLane(const Vec128 v) { + return detail::ExtractLane<0>(v); +} + +// ------------------------------ InsertLane + +namespace detail { + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i8x16_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i16x8_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i32x4_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i64x2_replace_lane(v.raw, 
kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{wasm_f32x4_replace_lane(v.raw, kLane, t)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + return Vec128{wasm_f64x2_replace_lane(v.raw, kLane, t)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls wasm_f64x2_replace_lane. + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ LowerHalf + +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 
LowerHalf(Vec128 v) { + return LowerHalf(Simd(), v); +} + +// ------------------------------ ShiftLeftBytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + switch (kBytes) { + case 0: + return v; + + case 1: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14)}; + + case 2: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13)}; + + case 3: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)}; + + case 4: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)}; + + case 5: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)}; + + case 6: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; + + case 7: + return Vec128{wasm_i8x16_shuffle( + v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8)}; + + case 8: + return Vec128{wasm_i8x16_shuffle( + v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7)}; + + case 9: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, + 6)}; + + case 10: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, + 5)}; + + case 11: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, + 4)}; + + case 12: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 0, 1, + 2, 3)}; + + case 13: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 0, + 1, 2)}; + + case 14: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0, 1)}; + + case 15: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 0)}; + } + return Vec128{zero}; +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(Simd(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +namespace detail { + +// Helper function allows zeroing invalid lanes in caller. 
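+// Shuffle index 16 refers to the all-zero second operand, so the bytes
+// shifted in from the top are zero. Callers with partial vectors must first
+// zero their invalid upper lanes (see ShiftRightBytes below).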
+template +HWY_API __i8x16 ShrBytes(const Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + + switch (kBytes) { + case 0: + return v.raw; + + case 1: + return wasm_i8x16_shuffle(v.raw, zero, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16); + + case 2: + return wasm_i8x16_shuffle(v.raw, zero, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16); + + case 3: + return wasm_i8x16_shuffle(v.raw, zero, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16, 16); + + case 4: + return wasm_i8x16_shuffle(v.raw, zero, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 16, 16, 16); + + case 5: + return wasm_i8x16_shuffle(v.raw, zero, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 16, 16, 16, 16); + + case 6: + return wasm_i8x16_shuffle(v.raw, zero, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16); + + case 7: + return wasm_i8x16_shuffle(v.raw, zero, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16, 16); + + case 8: + return wasm_i8x16_shuffle(v.raw, zero, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 9: + return wasm_i8x16_shuffle(v.raw, zero, 9, 10, 11, 12, 13, 14, 15, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 10: + return wasm_i8x16_shuffle(v.raw, zero, 10, 11, 12, 13, 14, 15, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 11: + return wasm_i8x16_shuffle(v.raw, zero, 11, 12, 13, 14, 15, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 12: + return wasm_i8x16_shuffle(v.raw, zero, 12, 13, 14, 15, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 13: + return wasm_i8x16_shuffle(v.raw, zero, 13, 14, 15, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 14: + return wasm_i8x16_shuffle(v.raw, zero, 14, 15, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 15: + return wasm_i8x16_shuffle(v.raw, zero, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + case 16: + return zero; + } +} + +} // namespace detail + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + // For partial vectors, clear upper lanes so we shift in zeros. + if (N != 16 / sizeof(T)) { + const Vec128 vfull{v.raw}; + v = Vec128{IfThenElseZero(FirstN(Full128(), N), vfull).raw}; + } + return Vec128{detail::ShrBytes(v)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). 
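+// Example: for 32-bit lanes {3,2,1,0} (lane 0 = least-significant),
+// UpperHalf returns the two-lane vector {3,2}.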
+template +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} + +// Partial +template +HWY_API Vec128 UpperHalf(Half> /* tag */, + Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); + return Vec128{upper.raw}; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full128 /* tag */, V hi, V lo) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + switch (kBytes) { + case 0: + return lo; + + case 1: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16)}; + + case 2: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17)}; + + case 3: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18)}; + + case 4: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19)}; + + case 5: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20)}; + + case 6: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21)}; + + case 7: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22)}; + + case 8: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23)}; + + case 9: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24)}; + + case 10: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25)}; + + case 11: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26)}; + + case 12: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27)}; + + case 13: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28)}; + + case 14: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29)}; + + case 15: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30)}; + } + return hi; +} + +template > +HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full128 d_full8; + using V8 = VFromD; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(d_full8, hi8, lo8); + return V{BitCast(Full128(), r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + 
return Vec128{ + wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, kLane, kLane)}; +} + +// ------------------------------ TableLookupBytes + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { +// Not yet available in all engines, see +// https://github.com/WebAssembly/simd/blob/bdcc304b2d379f4601c2c44ea9b44ed9484fde7e/proposals/simd/ImplementationStatus.md +// V8 implementation of this had a bug, fixed on 2021-04-03: +// https://chromium-review.googlesource.com/c/v8/v8/+/2822951 +#if 0 + return Vec128{wasm_i8x16_swizzle(bytes.raw, from.raw)}; +#else + alignas(16) uint8_t control[16]; + alignas(16) uint8_t input[16]; + alignas(16) uint8_t output[16]; + wasm_v128_store(control, from.raw); + wasm_v128_store(input, bytes.raw); + for (size_t i = 0; i < 16; ++i) { + output[i] = control[i] < 16 ? input[control[i]] : 0; + } + return Vec128{wasm_v128_load(output)}; +#endif +} + +template +HWY_API Vec128 TableLookupBytesOr0(const Vec128 bytes, + const Vec128 from) { + const Simd d; + // Mask size must match vector type, so cast everything to this type. + Repartition di8; + Repartition> d_bytes8; + const auto msb = BitCast(di8, from) < Zero(di8); + const auto lookup = + TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. 
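+// Example: Shuffle2301 of 32-bit lanes {3,2,1,0} (lane 0 = least-significant)
+// yields {2,3,0,1}. The two-vector variants below produce the swapped pairs
+// from lanes of both a and b.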
+namespace detail { + +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 1, 0, 3 + 16, 2 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 1, 0, 3 + 8, 2 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 1, 0, 3 + 4, 2 + 4)}; +} + +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 0, 3, 2 + 16, 1 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 0, 3, 2 + 8, 1 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 3, 2 + 4, 1 + 4)}; +} + +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 2, 1, 0 + 16, 3 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 2, 1, 0 + 8, 3 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 1, 0 + 4, 3 + 4)}; +} + +} // namespace detail + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} + +// Reverse +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
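+// Usage sketch (hypothetical values): for 32-bit lanes, the indices
+// {3, 2, 1, 0} reverse the vector, i.e.
+// TableLookupLanes(v, SetTableIndices(d, idx)) equals Reverse(d, v).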
+template +struct Indices128 { + __v128_u raw; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(N))))); +#endif + + const Repartition d8; + using V8 = VFromD; + const Repartition d16; + + // Broadcast each lane index to all bytes of T and shift to bytes + static_assert(sizeof(T) == 4 || sizeof(T) == 8, ""); + if (sizeof(T) == 4) { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; + } else { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<3>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7}; + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; + } +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + using TI = MakeSigned; + const DFromV d; + const Rebind di; + return BitCast(d, TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +// Single lane: no change +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return v; +} + +// Two lanes: shuffle +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return Vec128{Shuffle2301(Vec128{v.raw}).raw}; +} + +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + return BitCast(d, Vec128{wasm_i16x8_shuffle(v.raw, v.raw, 3, 2, + 1, 0, 7, 6, 5, 4)}); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { + return Reverse(d, v); +} + 
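+// Eight 32-bit or 64-bit lanes would require more than 128 bits, hence the
+// assertion in the overload below.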
+template +HWY_API Vec128 Reverse8(Simd, const Vec128) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ InterleaveLower + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +// Additional overload for the optional tag. +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. 
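+// Example: for 32-bit lanes, InterleaveLower(a, b) = {a[0], b[0], a[1], b[1]}
+// and InterleaveUpper(a, b) = {a[2], b[2], a[3], b[3]} (in lane order).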
+namespace detail { + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +} // namespace detail + +// Full +template > +HWY_API V InterleaveUpper(Full128 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template > +HWY_API V InterleaveUpper(Simd d, V a, V b) { + const Half d2; + return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template +HWY_API Vec128 Combine(Simd d, Vec128 hi_half, + Vec128 lo_half) { + const Half d2; + const RebindToUnsigned du2; + // Treat half-width input as one lane, and expand to two lanes. 
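+  // (e.g. two 64-bit halves: each half is viewed as a single u64 lane and
+  // InterleaveLower places lo_half in lane 0, hi_half in lane 1.)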
+ using VU = Vec128, 2>; + const VU lo{BitCast(du2, lo_half).raw}; + const VU hi{BitCast(du2, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +template +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { + return IfThenElseZero(FirstN(d, N / 2), Vec128{lo.raw}); +} + +// ------------------------------ ConcatLowerLower + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec128 ConcatLowerLower(Full128 /* tag */, const Vec128 hi, + const Vec128 lo) { + return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; +} +template +HWY_API Vec128 ConcatLowerLower(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatUpperUpper + +template +HWY_API Vec128 ConcatUpperUpper(Full128 /* tag */, const Vec128 hi, + const Vec128 lo) { + return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; +} +template +HWY_API Vec128 ConcatUpperUpper(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +// ------------------------------ ConcatLowerUpper + +template +HWY_API Vec128 ConcatLowerUpper(Full128 d, const Vec128 hi, + const Vec128 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} +template +HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +// ------------------------------ ConcatUpperLower +template +HWY_API Vec128 ConcatUpperLower(Simd d, const Vec128 hi, + const Vec128 lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} + +// 8-bit x8 +template +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper half. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 17, 19, 21, + 23, 1, 3, 5, 7, 17, 19, 21, 23)}; +} + +// 8-bit x4 +template +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper 3/4. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 17, 19, 1, 3, 17, + 19, 1, 3, 17, 19, 1, 3, 17, 19)}; +} + +// 16-bit full +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +// 16-bit x4 +template +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper half. 
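+  // (the shuffle below gathers odd lanes 1,3 of lo and hi into the lower four
+  // output lanes; the repeated upper four lanes are unused here)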
+ return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 9, 11, 1, 3, 9, 11)}; +} + +// 32-bit full +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; +} + +// Any T x2 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30)}; +} + +// 8-bit x8 +template +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper half. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 16, 18, 20, + 22, 0, 2, 4, 6, 16, 18, 20, 22)}; +} + +// 8-bit x4 +template +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper 3/4. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 16, 18, 0, 2, 16, + 18, 0, 2, 16, 18, 0, 2, 16, 18)}; +} + +// 16-bit full +template +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14)}; +} + +// 16-bit x4 +template +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper half. + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 8, 10, 0, 2, 8, 10)}; +} + +// 32-bit full +template +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; +} + +// Any T x2 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 0, 0, 2, 2)}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven + +namespace detail { + +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<1> /* tag */, const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<2> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 8, 1, 10, 3, 12, 5, 14, 7)}; +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<4> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<8> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 2, 1)}; +} + +} // namespace detail + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +template +HWY_API Vec128 OddEven(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 
4, 1, 6, 3)}; +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u32x4_extend_low_u16x8(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u64x2_extend_low_u32x4(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u32x4_extend_low_u16x8(v.raw)}; +} + +// Signed: replicate sign bit. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_extend_low_i8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_extend_low_i16x8(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i64x2_extend_low_i32x4(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_f64x2_convert_low_i32x4(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. 
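+  // (binary16 has 1 sign, 5 exponent and 10 mantissa bits with bias 15;
+  // binary32 uses 8/23 with bias 127, hence the 127 - 15 exponent rebias and
+  // the 23 - 10 mantissa shift below.)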
+ const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* di */, + const Vec128 v) { + return Vec128{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd df16, + const Vec128 v) { + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128{DemoteTo(du16, bits16).raw}; +} + +template +HWY_API Vec128 DemoteTo(Simd dbf16, + const Vec128 v) { + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + const 
RebindToUnsigned du16; + const Repartition du32; + const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); + const auto u16 = OddEven(BitCast(du16, a), BitCast(du16, b_in_even)); + return BitCast(dbf16, u16); +} + +// Specializations for partial vectors because i16x8_narrow_i32x4 sets lanes +// above 2*N. +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Full128 /*d16*/, + Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_narrow_i32x4(a.raw, b.raw)}; +} + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +// ------------------------------ Truncations + +template * = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return Vec128{v1.raw}; +} + +HWY_API Vec16 TruncateTo(Full16 /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + const auto v4 = ConcatEven(d, v2, v2); + return LowerHalf(LowerHalf(LowerHalf(ConcatEven(d, v4, v4)))); +} + +HWY_API Vec32 TruncateTo(Full32 /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + return LowerHalf(LowerHalf(ConcatEven(d, v2, v2))); +} + +HWY_API Vec64 TruncateTo(Full64 /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + const auto v3 = ConcatEven(d, v2, v2); + return Vec128{v3.raw}; +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return Vec128{v2.raw}; +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return Vec128{v2.raw}; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_f32x4_convert_i32x4(v.raw)}; +} +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_f32x4_convert_u32x4(v.raw)}; +} +// Truncates (rounds toward zero). 
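+// Example: -1.7f -> -1 and 2.9f -> 2; out-of-range values saturate
+// (wasm_i32x4_trunc_sat_f32x4).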
+template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_trunc_sat_f32x4(v.raw)}; +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return ConvertTo(Simd(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = And(BitCast(du16, v), Set(du16, 0xFF)); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return And(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), Set(du64, 0xFFFF)); +} + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const Vec128 vbits{wasm_i32x4_splat(static_cast(bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
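+// Example: bits[0] = 0b00000101 with N >= 3 sets mask lanes 0 and 2 and
+// clears lane 1; bit i of the input controls lane i.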
+template +HWY_API Mask128 LoadMaskBits(Simd d, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(N + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Full +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, mask.raw); + + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t lo = ((lanes[0] * kMagic) >> 56); + const uint64_t hi = ((lanes[1] * kMagic) >> 48) & 0xFF00; + return (hi + lo); +} + +// 64-bit +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + return (static_cast(wasm_i64x2_extract_lane(mask.raw, 0)) * + kMagic) >> + 56; +} + +// 32-bit or less: need masking +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + uint64_t bytes = static_cast(wasm_i64x2_extract_lane(mask.raw, 0)); + // Clear potentially undefined bytes. + bytes &= (1ULL << (N * 8)) - 1; + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + return (bytes * kMagic) >> 56; +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const __i16x8 zero = wasm_i16x8_splat(0); + const Mask128 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; + return BitsFromMask(hwy::SizeTag<1>(), mask8); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); + const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint32_t lanes[4]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1] | lanes[2] | lanes[3]; +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 mask) { + const __i64x2 mask_i = static_cast<__i64x2>(mask.raw); + const __i64x2 slice = wasm_i64x2_make(1, 2); + const __i64x2 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1]; +} + +// Returns the lowest N bits for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) == 16) ? bits : bits & ((1ull << N) - 1); +} + +// Returns 0xFF for bytes with index >= N, otherwise 0. +template +constexpr __i8x16 BytesAbove() { + return /**/ + (N == 0) ? wasm_i32x4_make(-1, -1, -1, -1) + : (N == 4) ? wasm_i32x4_make(0, -1, -1, -1) + : (N == 8) ? wasm_i32x4_make(0, 0, -1, -1) + : (N == 12) ? wasm_i32x4_make(0, 0, 0, -1) + : (N == 16) ? wasm_i32x4_make(0, 0, 0, 0) + : (N == 2) ? wasm_i16x8_make(0, -1, -1, -1, -1, -1, -1, -1) + : (N == 6) ? wasm_i16x8_make(0, 0, 0, -1, -1, -1, -1, -1) + : (N == 10) ? wasm_i16x8_make(0, 0, 0, 0, 0, -1, -1, -1) + : (N == 14) ? wasm_i16x8_make(0, 0, 0, 0, 0, 0, 0, -1) + : (N == 1) ? wasm_i8x16_make(0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1) + : (N == 3) ? wasm_i8x16_make(0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 5) ? wasm_i8x16_make(0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 7) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, + -1, -1, -1) + : (N == 9) ? 
wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, + -1, -1, -1) + : (N == 11) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1) + : (N == 13) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1) + : wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> tag, const Mask128 m) { + return PopCount(BitsFromMask(tag, m)); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> tag, const Mask128 m) { + return PopCount(BitsFromMask(tag, m)); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, shifted_bits); + return PopCount(lanes[0] | lanes[1]); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + alignas(16) int64_t lanes[2]; + wasm_v128_store(lanes, m.raw); + return static_cast(-(lanes[0] + lanes[1])); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Simd /* tag */, + const Mask128 mask, uint8_t* bits) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API size_t CountTrue(const Simd /* tag */, const Mask128 m) { + return detail::CountTrue(hwy::SizeTag(), m); +} + +// Partial vector +template +HWY_API size_t CountTrue(const Simd d, const Mask128 m) { + // Ensure all undefined bytes are 0. + const Mask128 mask{detail::BytesAbove()}; + return CountTrue(d, Mask128{AndNot(mask, m).raw}); +} + +// Full vector +template +HWY_API bool AllFalse(const Full128 d, const Mask128 m) { +#if 0 + // Casting followed by wasm_i8x16_any_true results in wasm error: + // i32.eqz[0] expected type i32, found i8x16.popcnt of type s128 + const auto v8 = BitCast(Full128(), VecFromMask(d, m)); + return !wasm_i8x16_any_true(v8.raw); +#else + (void)d; + return (wasm_i64x2_extract_lane(m.raw, 0) | + wasm_i64x2_extract_lane(m.raw, 1)) == 0; +#endif +} + +// Full vector +namespace detail { +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + return wasm_i8x16_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + return wasm_i16x8_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + return wasm_i32x4_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + return wasm_i64x2_all_true(m.raw); +} + +} // namespace detail + +template +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 m) { + return detail::AllTrue(hwy::SizeTag(), m); +} + +// Partial vectors + +template +HWY_API bool AllFalse(Simd /* tag */, const Mask128 m) { + // Ensure all undefined bytes are 0. + const Mask128 mask{detail::BytesAbove()}; + return AllFalse(Full128(), Mask128{AndNot(mask, m).raw}); +} + +template +HWY_API bool AllTrue(const Simd /* d */, const Mask128 m) { + // Ensure all undefined bytes are FF. 
+ const Mask128 mask{detail::BytesAbove()}; + return AllTrue(Full128(), Mask128{Or(mask, m).raw}); +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t bits = detail::BitsFromMask(mask); + return bits ? static_cast(Num0BitsBelowLS1Bit_Nonzero64(bits)) : -1; +} + +// ------------------------------ Compress + +namespace detail { + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 
8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 
2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. 
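+  // Worked example of the conversion applied after the table below: a table
+  // entry of 6 (the already-doubled index of u16 lane 3) is duplicated by
+  // ZipLower into the u16 0x0606; adding 0x0100 yields 0x0706, i.e. the
+  // little-endian byte pair {6, 7}, which selects both bytes of lane 3.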
+ alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 
4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 
8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
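+  // With only 4 lanes there are just 16 mask combinations, so this table is
+  // 16 * 16 = 256 bytes of ready-made byte indices, much smaller than the
+  // 2 KiB lane-index tables required for the 16-bit lane case above.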
+ alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +// Helper functions called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. + +template +HWY_INLINE Vec128 Compress(Vec128 v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec128 CompressNot(Vec128 v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromNotBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +template +struct CompressIsPartition { +#if HWY_TARGET == HWY_WASM_EMU256 + enum { value = 0 }; +#else + enum { value = (sizeof(T) != 1) }; +#endif +}; + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
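+  // Rationale: with two lanes, compaction only changes the layout when lane 0
+  // is dropped but lane 1 is kept; lane 1 must then move down into slot 0,
+  // which is exactly the Shuffle01 selected via `swap` below.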
+ const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return detail::Compress(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::Compress(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNot(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, + Simd d, T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + using TU = TFromD; + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Vec128 compressed = detail::Compress(BitCast(du, v), mask_bits); + const Mask128 store_mask = RebindMask(d, FirstN(du, count)); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. 
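+// Usage sketch for the Compress*Store ops above (illustrative only; `in`,
+// `out` and `num` are assumed caller-provided, with `num` a multiple of
+// Lanes(d)):
+//   const Full128<int32_t> d;
+//   for (size_t i = 0; i < num; i += Lanes(d)) {
+//     const auto v = LoadU(d, in + i);
+//     out += CompressStore(v, Lt(v, Zero(d)), d, out);  // keep negative lanes
+//   }
+// CompressStore may write a full vector's worth of bytes; prefer
+// CompressBlendedStore when lanes past the returned count must stay intact.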
+ +// ------------------------------ MulEven/Odd (Load) + +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = + Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), + static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); + return Load(Full128(), mul); +} + +HWY_INLINE Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = + Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), + static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); + return Load(Full128(), mul); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { + const Rebind du32; + using VU32 = VFromD; + const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 + // Using shift/and instead of Zip leads to the odd/even order that + // RearrangeToOddPlusEven prefers. + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +// Even if N=1, the input is always at least 2 lanes, hence i32x4_dot_i16x8 is +// safe. +template +HWY_API Vec128 ReorderWidenMulAccumulate( + Simd /*d32*/, Vec128 a, + Vec128 b, const Vec128 sum0, + Vec128& /*sum1*/) { + return sum0 + Vec128{wasm_i32x4_dot_i16x8(a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec128 RearrangeToOddPlusEven( + const Vec128 sum0, const Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + const Vec128 sum1) { + return Add(sum0, sum1); +} + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} + +// u32/i32/f32: + +// N=2 +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return v10 + Vec128{Shuffle2301(Vec128{v10.raw}).raw}; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Min(v10, Vec128{Shuffle2301(Vec128{v10.raw}).raw}); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Max(v10, Vec128{Shuffle2301(Vec128{v10.raw}).raw}); +} + +// N=4 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = v3210 + v1032; + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 
v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// u64/i64/f64: + +// N=2 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the same value in each lane. 
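+// To obtain a scalar, extract any lane of the broadcast result, e.g.
+// (illustrative sketch): const float total = GetLane(SumOfLanes(d, v));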
+template +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ Lt128 + +template +HWY_INLINE Mask128 Lt128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const Mask128 eqHL = Eq(a, b); + const Vec128 ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). + const Vec128 ltLx = DupEven(ltHL); + const Vec128 outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template +HWY_INLINE Mask128 Lt128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template +HWY_INLINE Mask128 Eq128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + const Vec128 eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template +HWY_INLINE Mask128 Eq128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template +HWY_INLINE Mask128 Ne128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + const Vec128 neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template +HWY_INLINE Mask128 Ne128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. 
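+// Usage sketch (illustrative): each 128-bit block of u64 lanes is treated as
+// one unsigned 128-bit number (lane 0 = low half, lane 1 = high half), and
+// the resulting mask is all-true in both lanes of a block iff the comparison
+// holds:
+//   const Full128<uint64_t> d;
+//   const auto lt = Lt128(d, a, b);   // both lanes true iff a < b as u128
+//   const auto mn = Min128(d, a, b);  // the smaller of the two 128-bit values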
+template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_INLINE VFromD Min128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/wasm_256-inl.h b/third_party/highway/hwy/ops/wasm_256-inl.h new file mode 100644 index 0000000000..aa62f05e00 --- /dev/null +++ b/third_party/highway/hwy/ops/wasm_256-inl.h @@ -0,0 +1,2003 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit WASM vectors and operations. Experimental. +// External include guard in highway.h - see comment there. + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/wasm_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +class Vec256 { + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Vec128 v0; + Vec128 v1; +}; + +template +struct Mask256 { + Mask128 m0; + Mask128 m1; +}; + +// ------------------------------ BitCast + +template +HWY_API Vec256 BitCast(Full256 d, Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = BitCast(dh, v.v0); + ret.v1 = BitCast(dh, v.v1); + return ret; +} + +// ------------------------------ Zero + +template +HWY_API Vec256 Zero(Full256 d) { + const Half dh; + Vec256 ret; + ret.v0 = ret.v1 = Zero(dh); + return ret; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". 
+template +HWY_API Vec256 Set(Full256 d, const T2 t) { + const Half dh; + Vec256 ret; + ret.v0 = ret.v1 = Set(dh, static_cast(t)); + return ret; +} + +template +HWY_API Vec256 Undefined(Full256 d) { + const Half dh; + Vec256 ret; + ret.v0 = ret.v1 = Undefined(dh); + return ret; +} + +template +Vec256 Iota(const Full256 d, const T2 first) { + const Half dh; + Vec256 ret; + ret.v0 = Iota(dh, first); + // NB: for floating types the gap between parts might be a bit uneven. + ret.v1 = Iota(dh, AddWithWraparound(hwy::IsFloatTag(), + static_cast(first), Lanes(dh))); + return ret; +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec256 operator+(Vec256 a, const Vec256 b) { + a.v0 += b.v0; + a.v1 += b.v1; + return a; +} + +template +HWY_API Vec256 operator-(Vec256 a, const Vec256 b) { + a.v0 -= b.v0; + a.v1 -= b.v1; + return a; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + Vec256 ret; + ret.v0 = SumsOf8(v.v0); + ret.v1 = SumsOf8(v.v1); + return ret; +} + +template +HWY_API Vec256 SaturatedAdd(Vec256 a, const Vec256 b) { + a.v0 = SaturatedAdd(a.v0, b.v0); + a.v1 = SaturatedAdd(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 SaturatedSub(Vec256 a, const Vec256 b) { + a.v0 = SaturatedSub(a.v0, b.v0); + a.v1 = SaturatedSub(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 AverageRound(Vec256 a, const Vec256 b) { + a.v0 = AverageRound(a.v0, b.v0); + a.v1 = AverageRound(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Abs(Vec256 v) { + v.v0 = Abs(v.v0); + v.v1 = Abs(v.v1); + return v; +} + +// ------------------------------ Shift lanes by constant #bits + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + v.v0 = ShiftLeft(v.v0); + v.v1 = ShiftLeft(v.v1); + return v; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + v.v0 = ShiftRight(v.v0); + v.v1 = ShiftRight(v.v1); + return v; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec256 RotateRight(const Vec256 v) { + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +template +HWY_API Vec256 ShiftLeftSame(Vec256 v, const int bits) { + v.v0 = ShiftLeftSame(v.v0, bits); + v.v1 = ShiftLeftSame(v.v1, bits); + return v; +} + +template +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + v.v0 = ShiftRightSame(v.v0, bits); + v.v1 = ShiftRightSame(v.v1, bits); + return v; +} + +// ------------------------------ Min, Max +template +HWY_API Vec256 Min(Vec256 a, const Vec256 b) { + a.v0 = Min(a.v0, b.v0); + a.v1 = Min(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Max(Vec256 a, const Vec256 b) { + a.v0 = Max(a.v0, b.v0); + a.v1 = Max(a.v1, b.v1); + return a; +} +// ------------------------------ Integer multiplication + +template +HWY_API Vec256 operator*(Vec256 a, const Vec256 b) { + a.v0 *= b.v0; + a.v1 *= b.v1; + return a; +} + +template +HWY_API Vec256 MulHigh(Vec256 a, const Vec256 b) { + a.v0 = MulHigh(a.v0, b.v0); + a.v1 = MulHigh(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 MulFixedPoint15(Vec256 a, const Vec256 b) { + a.v0 = MulFixedPoint15(a.v0, b.v0); + a.v1 = MulFixedPoint15(a.v1, b.v1); + return a; +} + +// Cannot use MakeWide because that returns uint128_t for uint64_t, but we want +// uint64_t. 
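+// As elsewhere in this file, the MulEven/MulOdd overloads below simply
+// forward to the 128-bit implementations on each half (v0, v1); for 64-bit
+// lanes, that underlying MulEven returns the full 128-bit product as two
+// adjacent uint64_t lanes, low half first (see the Mul128-based MulEven in
+// wasm_128-inl.h above).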
+HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} +HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} + +HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} +HWY_API Vec256 MulOdd(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulOdd(a.v0, b.v0); + ret.v1 = MulOdd(a.v1, b.v1); + return ret; +} + +// ------------------------------ Negate +template +HWY_API Vec256 Neg(Vec256 v) { + v.v0 = Neg(v.v0); + v.v1 = Neg(v.v1); + return v; +} + +// ------------------------------ Floating-point division +template +HWY_API Vec256 operator/(Vec256 a, const Vec256 b) { + a.v0 /= b.v0; + a.v1 /= b.v1; + return a; +} + +// Approximate reciprocal +HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { + const Vec256 one = Set(Full256(), 1.0f); + return one / v; +} + +// Absolute value of difference. +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfma? + return mul * x + add; +} + +// Returns add - mul * x +HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { + // TODO(eustas): replace, when implemented in WASM. + return add - mul * x; +} + +// Returns mul * x - sub +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfms? + return mul * x - sub; +} + +// Returns -mul * x - sub +HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { + // TODO(eustas): replace, when implemented in WASM. + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec256 Sqrt(Vec256 v) { + v.v0 = Sqrt(v.v0); + v.v1 = Sqrt(v.v1); + return v; +} + +// Approximate reciprocal square root +HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { + // TODO(eustas): find cheaper a way to calculate this. + const Vec256 one = Set(Full256(), 1.0f); + return one / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +HWY_API Vec256 Round(Vec256 v) { + v.v0 = Round(v.v0); + v.v1 = Round(v.v1); + return v; +} + +// Toward zero, aka truncate +HWY_API Vec256 Trunc(Vec256 v) { + v.v0 = Trunc(v.v0); + v.v1 = Trunc(v.v1); + return v; +} + +// Toward +infinity, aka ceiling +HWY_API Vec256 Ceil(Vec256 v) { + v.v0 = Ceil(v.v0); + v.v1 = Ceil(v.v1); + return v; +} + +// Toward -infinity, aka floor +HWY_API Vec256 Floor(Vec256 v) { + v.v0 = Floor(v.v0); + v.v1 = Floor(v.v1); + return v; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask256 IsNaN(const Vec256 v) { + return v != v; +} + +template +HWY_API Mask256 IsInf(const Vec256 v) { + const Full256 d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. 
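+// As with IsInf above, Add(v, v) doubles the bit pattern, which discards the
+// sign bit; for float, Inf/NaN then have all eight exponent bits set
+// (0xFF000000 after doubling), so a value is finite iff the shifted-down
+// exponent field is still below MaxExponentField() (255 for float).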
+template +HWY_API Mask256 IsFinite(const Vec256 v) { + const Full256 d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask256{Mask128{m.m0.raw}, Mask128{m.m1.raw}}; +} + +template +HWY_API Mask256 TestBit(Vec256 v, Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask256 operator==(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator==(a.v0, b.v0); + m.m1 = operator==(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator!=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator!=(a.v0, b.v0); + m.m1 = operator!=(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator<(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator<(a.v0, b.v0); + m.m1 = operator<(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator>(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator>(a.v0, b.v0); + m.m1 = operator>(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator<=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator<=(a.v0, b.v0); + m.m1 = operator<=(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator>=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator>=(a.v0, b.v0); + m.m1 = operator>=(a.v1, b.v1); + return m; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask256 FirstN(const Full256 d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. 
+ return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +} + +// ================================================== LOGICAL + +template +HWY_API Vec256 Not(Vec256 v) { + v.v0 = Not(v.v0); + v.v1 = Not(v.v1); + return v; +} + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + a.v0 = And(a.v0, b.v0); + a.v1 = And(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + not_mask.v0 = AndNot(not_mask.v0, mask.v0); + not_mask.v1 = AndNot(not_mask.v1, mask.v1); + return not_mask; +} + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + a.v0 = Or(a.v0, b.v0); + a.v1 = Or(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + a.v0 = Xor(a.v0, b.v0); + a.v1 = Xor(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Xor3(Vec256 x1, Vec256 x2, Vec256 x3) { + return Xor(x1, Xor(x2, x3)); +} + +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { + return Or(o1, Or(o2, o3)); +} + +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { + return Or(o, And(a1, a2)); +} + +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Full256()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Full256()), sign)); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + Mask256 m; + m.m0 = MaskFromVec(v.v0); + m.m1 = MaskFromVec(v.v1); + return m; +} + +template +HWY_API Vec256 VecFromMask(Full256 d, Mask256 m) { + const Half dh; + Vec256 v; + v.v0 = VecFromMask(dh, m.m0); + v.v1 = VecFromMask(dh, m.m1); + return v; +} + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + yes.v0 = IfThenElse(mask.m0, yes.v0, no.v0); + yes.v1 = IfThenElse(mask.m1, yes.v1, no.v1); + return yes; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(Full256(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(Full256(), mask), no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + v.v0 = IfNegativeThenElse(v.v0, yes.v0, no.v0); + v.v1 = IfNegativeThenElse(v.v1, yes.v1, no.v1); + return v; +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + return IfThenZeroElse(v < Zero(Full256()), v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + return MaskFromVec(Not(VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) +template +HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { + v.v0 = operator<<(v.v0, bits.v0); + v.v1 = operator<<(v.v1, bits.v1); + return v; +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) +template +HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { + v.v0 = operator>>(v.v0, bits.v0); + v.v1 = operator>>(v.v1, bits.v1); + return v; +} + +// ------------------------------ BroadcastSignBit (compare, VecFromMask) + +template +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight(v); +} +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + const Full256 d; + return VecFromMask(d, v < Zero(d)); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec256 Load(Full256 d, const T* HWY_RESTRICT aligned) { + const Half dh; + Vec256 ret; + ret.v0 = Load(dh, aligned); + ret.v1 = Load(dh, aligned + Lanes(dh)); + return ret; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// LoadU == Load. +template +HWY_API Vec256 LoadU(Full256 d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +template +HWY_API Vec256 LoadDup128(Full256 d, const T* HWY_RESTRICT p) { + const Half dh; + Vec256 ret; + ret.v0 = ret.v1 = Load(dh, p); + return ret; +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec256 v, Full256 d, T* HWY_RESTRICT aligned) { + const Half dh; + Store(v.v0, dh, aligned); + Store(v.v1, dh, aligned + Lanes(dh)); +} + +// StoreU == Store. +template +HWY_API void StoreU(Vec256 v, Full256 d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Stream +template +HWY_API void Stream(Vec256 v, Full256 d, T* HWY_RESTRICT aligned) { + // Same as aligned stores. 
+ Store(v, d, aligned); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Index index_lanes[N]; + Store(index, Full256(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec256 GatherOffset(const Full256 d, const T* HWY_RESTRICT base, + const Vec256 offset) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256(), offset_lanes); + + alignas(32) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec256 GatherIndex(const Full256 d, const T* HWY_RESTRICT base, + const Vec256 index) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(32) Index index_lanes[N]; + Store(index, Full256(), index_lanes); + + alignas(32) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, Full256(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane +template +HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { + Full256 d; + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ LowerHalf + +template +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return v.v0; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + return v.v0; +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec256 ShiftLeftBytes(Full256 d, Vec256 v) { + const Half dh; + v.v0 = ShiftLeftBytes(dh, v.v0); + v.v1 = ShiftLeftBytes(dh, v.v1); + return v; +} + +template +HWY_API Vec256 ShiftLeftBytes(Vec256 v) { + return ShiftLeftBytes(Full256(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { + return ShiftLeftLanes(Full256(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec256 
ShiftRightBytes(Full256 d, Vec256 v) { + const Half dh; + v.v0 = ShiftRightBytes(dh, v.v0); + v.v1 = ShiftRightBytes(dh, v.v1); + return v; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +template +HWY_API Vec128 UpperHalf(Full128 /* tag */, const Vec256 v) { + return v.v1; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full256 d, V hi, V lo) { + const Half dh; + hi.v0 = CombineShiftRightBytes(dh, hi.v0, lo.v0); + hi.v1 = CombineShiftRightBytes(dh, hi.v1, lo.v1); + return hi; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + Vec256 ret; + ret.v0 = Broadcast(v.v0); + ret.v1 = Broadcast(v.v1); + return ret; +} + +// ------------------------------ TableLookupBytes + +// Both full +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, Vec256 from) { + from.v0 = TableLookupBytes(bytes.v0, from.v0); + from.v1 = TableLookupBytes(bytes.v1, from.v1); + return from; +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(const Vec256 bytes, + const Vec128 from) { + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(Full256(), Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(Full128(), tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(const Vec128 bytes, + const Vec256 from) { + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(Full256(), Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by wasm_128. + +template +HWY_API VI TableLookupBytesOr0(const V bytes, VI from) { + // wasm out-of-bounds policy already zeros, so TableLookupBytes is fine. 
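// Rationale: TableLookupBytesOr0 must return 0 for any index byte whose MSB
// is set. wasm's i8x16.swizzle already returns 0 for indices >= 16, and a
// byte with the MSB set is >= 0x80 when treated as unsigned, so no extra
// masking of the index vector is required on this target.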
+ return TableLookupBytes(bytes, from); +} + +// ------------------------------ Hard-coded shuffles + +template +HWY_API Vec256 Shuffle01(Vec256 v) { + v.v0 = Shuffle01(v.v0); + v.v1 = Shuffle01(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle2301(Vec256 v) { + v.v0 = Shuffle2301(v.v0); + v.v1 = Shuffle2301(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle1032(Vec256 v) { + v.v0 = Shuffle1032(v.v0); + v.v1 = Shuffle1032(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle0321(Vec256 v) { + v.v0 = Shuffle0321(v.v0); + v.v1 = Shuffle0321(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle2103(Vec256 v) { + v.v0 = Shuffle2103(v.v0); + v.v1 = Shuffle2103(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle0123(Vec256 v) { + v.v0 = Shuffle0123(v.v0); + v.v1 = Shuffle0123(v.v1); + return v; +} + +// Used by generic_ops-inl.h +namespace detail { + +template +HWY_API Vec256 Shuffle2301(Vec256 a, const Vec256 b) { + a.v0 = Shuffle2301(a.v0, b.v0); + a.v1 = Shuffle2301(a.v1, b.v1); + return a; +} +template +HWY_API Vec256 Shuffle1230(Vec256 a, const Vec256 b) { + a.v0 = Shuffle1230(a.v0, b.v0); + a.v1 = Shuffle1230(a.v1, b.v1); + return a; +} +template +HWY_API Vec256 Shuffle3012(Vec256 a, const Vec256 b) { + a.v0 = Shuffle3012(a.v0, b.v0); + a.v1 = Shuffle3012(a.v1, b.v1); + return a; +} + +} // namespace detail + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices256 { + __v128_u i0; + __v128_u i1; +}; + +template +HWY_API Indices256 IndicesFromVec(Full256 /* tag */, Vec256 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + Indices256 ret; + ret.i0 = vec.v0.raw; + ret.i1 = vec.v1.raw; + return ret; +} + +template +HWY_API Indices256 SetTableIndices(Full256 d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(const Vec256 v, Indices256 idx) { + using TU = MakeUnsigned; + const Full128 dh; + const Full128 duh; + constexpr size_t kLanesPerHalf = 16 / sizeof(TU); + + const Vec128 vi0{idx.i0}; + const Vec128 vi1{idx.i1}; + const Vec128 mask = Set(duh, static_cast(kLanesPerHalf - 1)); + const Vec128 vmod0 = vi0 & mask; + const Vec128 vmod1 = vi1 & mask; + // If ANDing did not change the index, it is for the lower half. + const Mask128 is_lo0 = RebindMask(dh, vi0 == vmod0); + const Mask128 is_lo1 = RebindMask(dh, vi1 == vmod1); + const Indices128 mod0 = IndicesFromVec(dh, vmod0); + const Indices128 mod1 = IndicesFromVec(dh, vmod1); + + Vec256 ret; + ret.v0 = IfThenElse(is_lo0, TableLookupLanes(v.v0, mod0), + TableLookupLanes(v.v1, mod0)); + ret.v1 = IfThenElse(is_lo1, TableLookupLanes(v.v0, mod1), + TableLookupLanes(v.v1, mod1)); + return ret; +} + +template +HWY_API Vec256 TableLookupLanesOr0(Vec256 v, Indices256 idx) { + // The out of bounds behavior will already zero lanes. + return TableLookupLanesOr0(v, idx); +} + +// ------------------------------ Reverse +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v1 = Reverse(dh, v.v0); // note reversed v1 member order + ret.v0 = Reverse(dh, v.v1); + return ret; +} + +// ------------------------------ Reverse2 +template +HWY_API Vec256 Reverse2(Full256 d, Vec256 v) { + const Half dh; + v.v0 = Reverse2(dh, v.v0); + v.v1 = Reverse2(dh, v.v1); + return v; +} + +// ------------------------------ Reverse4 + +// Each block has only 2 lanes, so swap blocks and their lanes. 
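// Example for the 64-bit Reverse4 below: input lanes {0,1 | 2,3} (v0 holds
// lanes 0..1, v1 holds lanes 2..3) must become {3,2 | 1,0}, i.e. the new
// lower half is Reverse2 of the old upper half and vice versa.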
+template +HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = Reverse2(dh, v.v1); // swapped + ret.v1 = Reverse2(dh, v.v0); + return ret; +} + +template +HWY_API Vec256 Reverse4(Full256 d, Vec256 v) { + const Half dh; + v.v0 = Reverse4(dh, v.v0); + v.v1 = Reverse4(dh, v.v1); + return v; +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec256 Reverse8(Full256 /* tag */, Vec256 /* v */) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// Each block has only 4 lanes, so swap blocks and their lanes. +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = Reverse4(dh, v.v1); // swapped + ret.v1 = Reverse4(dh, v.v0); + return ret; +} + +template // 1 or 2 bytes +HWY_API Vec256 Reverse8(Full256 d, Vec256 v) { + const Half dh; + v.v0 = Reverse8(dh, v.v0); + v.v1 = Reverse8(dh, v.v1); + return v; +} + +// ------------------------------ InterleaveLower + +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + a.v0 = InterleaveLower(a.v0, b.v0); + a.v1 = InterleaveLower(a.v1, b.v1); + return a; +} + +// wasm_128 already defines a template with D, V, V args. + +// ------------------------------ InterleaveUpper (UpperHalf) + +template > +HWY_API V InterleaveUpper(Full256 d, V a, V b) { + const Half dh; + a.v0 = InterleaveUpper(dh, a.v0, b.v0); + a.v1 = InterleaveUpper(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(Vec256 a, Vec256 b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) +template +HWY_API Vec256 Combine(Full256 /* d */, Vec128 hi, Vec128 lo) { + Vec256 ret; + ret.v1 = hi; + ret.v0 = lo; + return ret; +} + +// ------------------------------ ZeroExtendVector (Combine) +template +HWY_API Vec256 ZeroExtendVector(Full256 d, Vec128 lo) { + const Half dh; + return Combine(d, Zero(dh), lo); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API Vec256 ConcatLowerLower(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v0; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v1; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v0; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatUpperLower +template +HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v1; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatOdd +template +HWY_API Vec256 ConcatOdd(Full256 d, const Vec256 hi, + const Vec256 lo) { + const 
Half dh; + Vec256 ret; + ret.v0 = ConcatOdd(dh, lo.v1, lo.v0); + ret.v1 = ConcatOdd(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ ConcatEven +template +HWY_API Vec256 ConcatEven(Full256 d, const Vec256 hi, + const Vec256 lo) { + const Half dh; + Vec256 ret; + ret.v0 = ConcatEven(dh, lo.v1, lo.v0); + ret.v1 = ConcatEven(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ DupEven +template +HWY_API Vec256 DupEven(Vec256 v) { + v.v0 = DupEven(v.v0); + v.v1 = DupEven(v.v1); + return v; +} + +// ------------------------------ DupOdd +template +HWY_API Vec256 DupOdd(Vec256 v) { + v.v0 = DupOdd(v.v0); + v.v1 = DupOdd(v.v1); + return v; +} + +// ------------------------------ OddEven +template +HWY_API Vec256 OddEven(Vec256 a, const Vec256 b) { + a.v0 = OddEven(a.v0, b.v0); + a.v1 = OddEven(a.v1, b.v1); + return a; +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + odd.v0 = even.v0; + return odd; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + Vec256 ret; + ret.v0 = v.v1; // swapped order + ret.v1 = v.v0; + return ret; +} + +// ------------------------------ ReverseBlocks +template +HWY_API Vec256 ReverseBlocks(Full256 /* tag */, const Vec256 v) { + return SwapAdjacentBlocks(v); // 2 blocks, so Swap = Reverse +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +namespace detail { + +// Unsigned: zero-extend. +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{ + wasm_u32x4_extend_high_u16x8(wasm_u16x8_extend_high_u8x16(v.raw))}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{ + wasm_u32x4_extend_high_u16x8(wasm_u16x8_extend_high_u8x16(v.raw))}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u32x4_extend_high_u16x8(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u64x2_extend_high_u32x4(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_u32x4_extend_high_u16x8(v.raw)}; +} + +// Signed: replicate sign bit. +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_extend_high_i8x16(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{ + wasm_i32x4_extend_high_i16x8(wasm_i16x8_extend_high_i8x16(v.raw))}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_extend_high_i16x8(v.raw)}; +} +HWY_API Vec128 PromoteUpperTo(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_i64x2_extend_high_i32x4(v.raw)}; +} + +HWY_API Vec128 PromoteUpperTo(Full128 dd, + const Vec128 v) { + // There is no wasm_f64x2_convert_high_i32x4. + const Full64 di32h; + return PromoteTo(dd, UpperHalf(di32h, v)); +} + +HWY_API Vec128 PromoteUpperTo(Full128 df32, + const Vec128 v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. 
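// binary16 layout used below: 1 sign bit, 5 exponent bits (bias 15) and
// 10 mantissa bits; binary32 has 8 exponent bits (bias 127) and 23 mantissa
// bits. Normals are converted by re-biasing the exponent (+ 127 - 15) and
// shifting the mantissa left by 23 - 10 = 13 bits. Subnormal f16 inputs
// encode mantissa * 2^-24, which is reproduced by the int->float conversion
// scaled by 1 / 16384 / 1024 = 2^-24.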
+ const auto bits16 = PromoteUpperTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +HWY_API Vec128 PromoteUpperTo(Full128 df32, + const Vec128 v) { + const Full128 du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +} // namespace detail + +template +HWY_API Vec256 PromoteTo(Full256 d, const Vec128 v) { + const Half dh; + Vec256 ret; + ret.v0 = PromoteTo(dh, LowerHalf(v)); + ret.v1 = detail::PromoteUpperTo(dh, v); + return ret; +} + +// This is the only 4x promotion from 8 to 32-bit. +template +HWY_API Vec256 PromoteTo(Full256 d, const Vec64 v) { + const Half dh; + const Rebind, decltype(d)> d2; // 16-bit lanes + const auto v16 = PromoteTo(d2, v); + Vec256 ret; + ret.v0 = PromoteTo(dh, LowerHalf(v16)); + ret.v1 = detail::PromoteUpperTo(dh, v16); + return ret; +} + +// ------------------------------ DemoteTo + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 di, const Vec256 v) { + const Vec64 lo{wasm_i32x4_trunc_sat_f64x2_zero(v.v0.raw)}; + const Vec64 hi{wasm_i32x4_trunc_sat_f64x2_zero(v.v1.raw)}; + return Combine(di, hi, lo); +} + +HWY_API Vec128 DemoteTo(Full128 d16, + const Vec256 v) { + const Half d16h; + const Vec64 lo = DemoteTo(d16h, v.v0); + const Vec64 hi = DemoteTo(d16h, v.v1); + return Combine(d16, hi, lo); +} + +HWY_API Vec128 DemoteTo(Full128 dbf16, + const Vec256 v) { + const Half dbf16h; + const Vec64 lo = DemoteTo(dbf16h, v.v0); + const Vec64 hi = DemoteTo(dbf16h, v.v1); + return Combine(dbf16, hi, lo); +} + +// For already range-limited input [0, 255]. 
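// Usage note (an assumption, not a documented guarantee): callers are
// expected to clamp inputs to [0, 255] first, e.g. via Min with Set(d, 255);
// the signed narrowing sequence used below would saturate rather than wrap,
// but only in-range inputs are portable across targets.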
+HWY_API Vec64 U8FromU32(const Vec256 v) { + const Full64 du8; + const Full256 di32; // no unsigned DemoteTo + return DemoteTo(du8, BitCast(di32, v)); +} + +// ------------------------------ Truncations + +HWY_API Vec32 TruncateTo(Full32 /* tag */, + const Vec256 v) { + return Vec32{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 8, 16, 24, 0, + 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, + 24)}; +} + +HWY_API Vec64 TruncateTo(Full64 /* tag */, + const Vec256 v) { + return Vec64{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 8, 9, 16, + 17, 24, 25, 0, 1, 8, 9, 16, 17, 24, + 25)}; +} + +HWY_API Vec128 TruncateTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 2, 3, 8, + 9, 10, 11, 16, 17, 18, 19, 24, 25, + 26, 27)}; +} + +HWY_API Vec64 TruncateTo(Full64 /* tag */, + const Vec256 v) { + return Vec64{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 4, 8, 12, 16, + 20, 24, 28, 0, 4, 8, 12, 16, 20, 24, + 28)}; +} + +HWY_API Vec128 TruncateTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 4, 5, 8, + 9, 12, 13, 16, 17, 20, 21, 24, 25, + 28, 29)}; +} + +HWY_API Vec128 TruncateTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 2, 4, 6, 8, + 10, 12, 14, 16, 18, 20, 22, 24, 26, + 28, 30)}; +} + +// ------------------------------ ReorderDemote2To +HWY_API Vec256 ReorderDemote2To(Full256 dbf16, + Vec256 a, Vec256 b) { + const RebindToUnsigned du16; + return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a))); +} + +HWY_API Vec256 ReorderDemote2To(Full256 d16, + Vec256 a, Vec256 b) { + const Half d16h; + Vec256 demoted; + demoted.v0 = DemoteTo(d16h, a); + demoted.v1 = DemoteTo(d16h, b); + return demoted; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template +HWY_API Vec256 ConvertTo(Full256 d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = ConvertTo(dh, v.v0); + ret.v1 = ConvertTo(dh, v.v1); + return ret; +} + +HWY_API Vec256 NearestInt(const Vec256 v) { + return ConvertTo(Full256(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template // 4 or 8 bytes +HWY_API Mask256 LoadMaskBits(Full256 d, + const uint8_t* HWY_RESTRICT bits) { + const Half dh; + Mask256 ret; + ret.m0 = LoadMaskBits(dh, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(T); + const uint8_t bits_upper[8] = {static_cast(bits[0] >> kBitsPerHalf)}; + ret.m1 = LoadMaskBits(dh, bits_upper); + return ret; +} + +template // 1 or 2 bytes +HWY_API Mask256 LoadMaskBits(Full256 d, + const uint8_t* HWY_RESTRICT bits) { + const Half dh; + Mask256 ret; + ret.m0 = LoadMaskBits(dh, bits); + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + ret.m1 = LoadMaskBits(dh, bits + kBytesPerHalf); + return ret; +} + +// ------------------------------ Mask + +// `p` points to at least 8 writable bytes. 
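// Bit layout produced by StoreMaskBits below (LSB-first, bit i = lane i): for
// 4-byte lanes there are 8 lanes; the lower 128-bit half supplies bits 0..3
// and the upper half bits 4..7 of byte 0. Example: active lanes
// {1,0,1,1, 0,0,1,0} -> 0x0D | (0x04 << 4) = 0x4D.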
+template // 4 or 8 bytes +HWY_API size_t StoreMaskBits(const Full256 d, const Mask256 mask, + uint8_t* bits) { + const Half dh; + StoreMaskBits(dh, mask.m0, bits); + const uint8_t lo = bits[0]; + StoreMaskBits(dh, mask.m1, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(T); + bits[0] = static_cast(lo | (bits[0] << kBitsPerHalf)); + return (kBitsPerHalf * 2 + 7) / 8; +} + +template // 1 or 2 bytes +HWY_API size_t StoreMaskBits(const Full256 d, const Mask256 mask, + uint8_t* bits) { + const Half dh; + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + StoreMaskBits(dh, mask.m0, bits); + StoreMaskBits(dh, mask.m1, bits + kBytesPerHalf); + return kBytesPerHalf * 2; +} + +template +HWY_API size_t CountTrue(const Full256 d, const Mask256 m) { + const Half dh; + return CountTrue(dh, m.m0) + CountTrue(dh, m.m1); +} + +template +HWY_API bool AllFalse(const Full256 d, const Mask256 m) { + const Half dh; + return AllFalse(dh, m.m0) && AllFalse(dh, m.m1); +} + +template +HWY_API bool AllTrue(const Full256 d, const Mask256 m) { + const Half dh; + return AllTrue(dh, m.m0) && AllTrue(dh, m.m1); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full256 d, const Mask256 mask) { + const Half dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); // not known + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + return lo >= 0 ? static_cast(lo) + : kLanesPerHalf + FindKnownFirstTrue(dh, mask.m1); +} + +template +HWY_API intptr_t FindFirstTrue(const Full256 d, const Mask256 mask) { + const Half dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); + const intptr_t hi = FindFirstTrue(dh, mask.m1); + if (lo < 0 && hi < 0) return lo; + constexpr int kLanesPerHalf = 16 / sizeof(T); + return lo >= 0 ? 
lo : hi + kLanesPerHalf; +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const Vec256 v, const Mask256 mask, + Full256 d, T* HWY_RESTRICT unaligned) { + const Half dh; + const size_t count = CompressStore(v.v0, mask.m0, dh, unaligned); + const size_t count2 = CompressStore(v.v1, mask.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const Vec256 v, const Mask256 m, + Full256 d, T* HWY_RESTRICT unaligned) { + const Half dh; + const size_t count = CompressBlendedStore(v.v0, m.m0, dh, unaligned); + const size_t count2 = CompressBlendedStore(v.v1, m.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(const Vec256 v, + const uint8_t* HWY_RESTRICT bits, Full256 d, + T* HWY_RESTRICT unaligned) { + const Mask256 m = LoadMaskBits(d, bits); + return CompressStore(v, m, d, unaligned); +} + +// ------------------------------ Compress + +template +HWY_API Vec256 Compress(const Vec256 v, const Mask256 mask) { + const Full256 d; + alignas(32) T lanes[32 / sizeof(T)] = {}; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +// ------------------------------ CompressNot +template +HWY_API Vec256 CompressNot(Vec256 v, const Mask256 mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + const Full128 dh; + // Because the non-selected (mask=1) blocks are undefined, we can return the + // input unless mask = 01, in which case we must bring down the upper block. + return AllTrue(dh, AndNot(mask.m1, mask.m0)) ? SwapAdjacentBlocks(v) : v; +} + +// ------------------------------ CompressBits + +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + const Mask256 m = LoadMaskBits(Full256(), bits); + return Compress(v, m); +} + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. 
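// Scalar sketch of what LoadInterleaved3 ultimately computes from the blocks
// gathered below (illustrative helper, not part of the library; the real
// implementation lives in generic_ops-inl.h on top of these transposes):
static inline void LoadInterleaved3Sketch(const uint8_t* HWY_RESTRICT in,
                                          size_t n, uint8_t* HWY_RESTRICT r,
                                          uint8_t* HWY_RESTRICT g,
                                          uint8_t* HWY_RESTRICT b) {
  for (size_t i = 0; i < n; ++i) {
    r[i] = in[3 * i + 0];  // lane i of the first output vector
    g[i] = in[3 * i + 1];  // lane i of the second
    b[i] = in[3 * i + 2];  // lane i of the third
  }
}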
+ +namespace detail { + +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template +HWY_API void LoadTransposedBlocks3(Full256 d, + const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C) { + constexpr size_t N = 32 / sizeof(T); + const Vec256 v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const Vec256 v32 = LoadU(d, unaligned + 1 * N); + const Vec256 v54 = LoadU(d, unaligned + 2 * N); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of A) +// 5 1 +// 6 2 +// 7 3 +template +HWY_API void LoadTransposedBlocks4(Full256 d, + const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C, + Vec256& D) { + constexpr size_t N = 32 / sizeof(T); + const Vec256 v10 = LoadU(d, unaligned + 0 * N); + const Vec256 v32 = LoadU(d, unaligned + 1 * N); + const Vec256 v54 = LoadU(d, unaligned + 2 * N); + const Vec256 v76 = LoadU(d, unaligned + 3 * N); + + A = ConcatLowerLower(d, v54, v10); + B = ConcatUpperUpper(d, v54, v10); + C = ConcatLowerLower(d, v76, v32); + D = ConcatUpperUpper(d, v76, v32); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template +HWY_API void StoreTransposedBlocks2(const Vec256 i, const Vec256 j, + const Full256 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template +HWY_API void StoreTransposedBlocks3(const Vec256 i, const Vec256 j, + const Vec256 k, Full256 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template +HWY_API void StoreTransposedBlocks4(const Vec256 i, const Vec256 j, + const Vec256 k, const Vec256 l, + Full256 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + // Write lower halves, then upper. 
+ const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ ReorderWidenMulAccumulate +template +HWY_API Vec256 ReorderWidenMulAccumulate(Full256 d, Vec256 a, + Vec256 b, Vec256 sum0, + Vec256& sum1) { + const Half dh; + sum0.v0 = ReorderWidenMulAccumulate(dh, a.v0, b.v0, sum0.v0, sum1.v0); + sum0.v1 = ReorderWidenMulAccumulate(dh, a.v1, b.v1, sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec256 RearrangeToOddPlusEven(Vec256 sum0, Vec256 sum1) { + sum0.v0 = RearrangeToOddPlusEven(sum0.v0, sum1.v0); + sum0.v1 = RearrangeToOddPlusEven(sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ Reductions + +template +HWY_API Vec256 SumOfLanes(Full256 d, const Vec256 v) { + const Half dh; + const Vec128 lo = SumOfLanes(dh, Add(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +template +HWY_API Vec256 MinOfLanes(Full256 d, const Vec256 v) { + const Half dh; + const Vec128 lo = MinOfLanes(dh, Min(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +template +HWY_API Vec256 MaxOfLanes(Full256 d, const Vec256 v) { + const Half dh; + const Vec128 lo = MaxOfLanes(dh, Max(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +// ------------------------------ Lt128 + +template +HWY_INLINE Mask256 Lt128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Lt128(dh, a.v0, b.v0); + ret.m1 = Lt128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Lt128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Lt128Upper(dh, a.v0, b.v0); + ret.m1 = Lt128Upper(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Eq128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Eq128(dh, a.v0, b.v0); + ret.m1 = Eq128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Eq128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Eq128Upper(dh, a.v0, b.v0); + ret.m1 = Eq128Upper(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Ne128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Ne128(dh, a.v0, b.v0); + ret.m1 = Ne128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Mask256 Ne128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Ne128Upper(dh, a.v0, b.v0); + ret.m1 = Ne128Upper(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Vec256 Min128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Min128(dh, a.v0, b.v0); + ret.v1 = Min128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Vec256 Max128(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Max128(dh, a.v0, b.v0); + ret.v1 = Max128(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Vec256 Min128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Min128Upper(dh, a.v0, b.v0); + ret.v1 = Min128Upper(dh, a.v1, b.v1); + return ret; +} + +template +HWY_INLINE Vec256 Max128Upper(Full256 d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Max128Upper(dh, a.v0, b.v0); + ret.v1 = Max128Upper(dh, a.v1, b.v1); + return ret; +} + 
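// Scalar sketch of the semantics of Lt128 (not the SIMD implementation): each
// 128-bit block holds one unsigned 128-bit key, with the low 64 bits in the
// lower u64 lane; the resulting mask sets both lanes of a block iff a < b for
// that block's key.
static inline bool Lt128Sketch(uint64_t a_lo, uint64_t a_hi, uint64_t b_lo,
                               uint64_t b_hi) {
  return (a_hi == b_hi) ? (a_lo < b_lo) : (a_hi < b_hi);
}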
+// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/x86_128-inl.h b/third_party/highway/hwy/ops/x86_128-inl.h new file mode 100644 index 0000000000..ba8d581984 --- /dev/null +++ b/third_party/highway/hwy/ops/x86_128-inl.h @@ -0,0 +1,7432 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit vectors and SSE4 instructions, plus some AVX2 and AVX512-VL +// operations when compiling for those targets. +// External include guard in highway.h - see comment there. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_GCC_ACTUAL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's emmintrin.h - see +// https://github.com/google/highway/issues/710 and pull/902 +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +#include +#include +#if HWY_TARGET == HWY_SSSE3 +#include // SSSE3 +#else +#include // SSE4 +#include // CLMUL +#endif +#include +#include +#include // memcpy + +#include "hwy/ops/shared-inl.h" + +#if HWY_IS_MSAN +#include +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw128 { + using type = __m128i; +}; +template <> +struct Raw128 { + using type = __m128; +}; +template <> +struct Raw128 { + using type = __m128d; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
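// Usage sketch: with const Full128<float> d; auto v = Set(d, 2.0f);
//   v += v;             // every lane is now 4.0f
//   v /= Set(d, 8.0f);  // allowed because operator/ exists for f32/f64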
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +#if HWY_TARGET <= HWY_AVX3 + +namespace detail { + +// Template arg: sizeof(lane type) +template +struct RawMask128 {}; +template <> +struct RawMask128<1> { + using type = __mmask16; +}; +template <> +struct RawMask128<2> { + using type = __mmask8; +}; +template <> +struct RawMask128<4> { + using type = __mmask8; +}; +template <> +struct RawMask128<8> { + using type = __mmask8; +}; + +} // namespace detail + +template +struct Mask128 { + using Raw = typename detail::RawMask128::type; + + static Mask128 FromBits(uint64_t mask_bits) { + return Mask128{static_cast(mask_bits)}; + } + + Raw raw; +}; + +#else // AVX2 or below + +// FF..FF or 0. +template +struct Mask128 { + typename detail::Raw128::type raw; +}; + +#endif // HWY_TARGET <= HWY_AVX3 + +template +using DFromV = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m128i BitCastToInteger(__m128i v) { return v; } +HWY_INLINE __m128i BitCastToInteger(__m128 v) { return _mm_castps_si128(v); } +HWY_INLINE __m128i BitCastToInteger(__m128d v) { return _mm_castpd_si128(v); } + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __m128i operator()(__m128i v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128 operator()(__m128i v) { return _mm_castsi128_ps(v); } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128d operator()(__m128i v) { return _mm_castsi128_pd(v); } +}; + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128{BitCastFromInteger128()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_si128()}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_ps()}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_pd()}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". 
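// Example: Set(Full128<int32_t>(), 7) yields {7, 7, 7, 7}; for partial
// vectors (N < 16 / sizeof(T)) only the lower N lanes are meaningful.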
+template +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { + return Vec128{_mm_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint16_t t) { + return Vec128{_mm_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint32_t t) { + return Vec128{_mm_set1_epi32(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint64_t t) { + return Vec128{ + _mm_set1_epi64x(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { + return Vec128{_mm_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { + return Vec128{_mm_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { + return Vec128{_mm_set1_epi32(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { + return Vec128{ + _mm_set1_epi64x(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const float t) { + return Vec128{_mm_set1_ps(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const double t) { + return Vec128{_mm_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec128 Undefined(Simd /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec128{_mm_undefined_si128()}; +} +template +HWY_API Vec128 Undefined(Simd /* tag */) { + return Vec128{_mm_undefined_ps()}; +} +template +HWY_API Vec128 Undefined(Simd /* tag */) { + return Vec128{_mm_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ GetLane + +// Gets the single value stored in a vector/part. +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFF); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFFFF); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw)); +} +template +HWY_API float GetLane(const Vec128 v) { + return _mm_cvtss_f32(v.raw); +} +template +HWY_API uint64_t GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + alignas(16) uint64_t lanes[2]; + Store(v, Simd(), lanes); + return lanes[0]; +#else + return static_cast(_mm_cvtsi128_si64(v.raw)); +#endif +} +template +HWY_API int64_t GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + alignas(16) int64_t lanes[2]; + Store(v, Simd(), lanes); + return lanes[0]; +#else + return _mm_cvtsi128_si64(v.raw); +#endif +} +template +HWY_API double GetLane(const Vec128 v) { + return _mm_cvtsd_f64(v.raw); +} + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{_mm_and_si128(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_and_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
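// Note the argument order: the *first* operand is the one complemented, e.g.
// AndNot(SignBit(d), v) clears the sign bits of v.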
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{_mm_andnot_si128(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return Vec128{_mm_andnot_ps(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return Vec128{_mm_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{_mm_or_si128(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Or(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_or_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Or(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{_mm_xor_si128(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Xor(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_xor_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Xor(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not +template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_TARGET <= HWY_AVX3 + const __m128i vu = BitCast(du, v).raw; + return BitCast(d, VU{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(d, VU{_mm_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Xor3 +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast( + d, VU{_mm_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require 
BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Neg(hwy::FloatTag /*tag*/, const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_INLINE Vec128 Neg(hwy::NonFloatTag /*tag*/, const Vec128 v) { + return Zero(DFromV()) - v; +} + +} // namespace detail + +template +HWY_INLINE Vec128 Neg(const Vec128 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (reaches breakpoint) + const auto zero = Zero(DFromV()); + return Vec128{_mm_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec128{_mm_abs_epi8(v.raw)}; +#endif +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi16(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi32(v.raw)}; +} +// i64 is implemented after BroadcastSignBit. +template +HWY_API Vec128 Abs(const Vec128 v) { + const Vec128 mask{_mm_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(DFromV(), mask); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + const Vec128 mask{_mm_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(DFromV(), mask); +} + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m128i out = _mm_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, VFromD{out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. 
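// How the ternary-logic immediates above are derived: _mm_ternarylogic_epi32
// evaluates an arbitrary 3-input boolean function per bit; bit
// ((a << 2) | (b << 1) | c) of the 8-bit immediate is the output for inputs
// (a, b, c). Reading the CopySign truth table above from the 111 row down to
// 000 gives 0b10101100 = 0xAC. Because every row with msb = 1 takes its
// output solely from `sign`, a stray sign bit in `abs` cannot leak through,
// which is why no extra clearing is needed here.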
+ return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(DFromV()), sign)); +#endif +} + +// ================================================== MASK + +namespace detail { + +template +HWY_INLINE void MaybeUnpoison(T* HWY_RESTRICT unaligned, size_t count) { + // Workaround for MSAN not marking compressstore as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#else + (void)unaligned; + (void)count; +#endif +} + +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_ps(mask.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
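// Why masked subtraction works here: in lanes selected by the mask,
// no - no == 0; lanes with a clear mask bit keep `no` unchanged, which is
// exactly the IfThenZeroElse contract.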
+ return Vec128{_mm_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +// ------------------------------ Mask logical + +// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently. +#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) +#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC_ACTUAL >= 700 || \ + HWY_COMPILER_CLANG >= 800 +#define HWY_COMPILER_HAS_MASK_INTRINSICS 1 +#else +#define HWY_COMPILER_HAS_MASK_INTRINSICS 0 +#endif +#endif // HWY_COMPILER_HAS_MASK_INTRINSICS + +namespace detail { + +template +HWY_INLINE Mask128 And(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template 
+HWY_INLINE Mask128 Or(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0x3)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0x3)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + return detail::Or(hwy::SizeTag(), a, 
b); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Not(const Mask128 m) { + // Flip only the valid bits. + // TODO(janwas): use _knot intrinsics if N >= 8. + return Xor(m, Mask128::FromBits((1ull << N) - 1)); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +#else // AVX2 or below + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(const Simd /* tag */, + const Mask128 v) { + return Vec128{v.raw}; +} + +#if HWY_TARGET == HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const auto vmask = VecFromMask(DFromV(), mask); + return Or(And(vmask, yes), AndNot(vmask, no)); +} + +#else // HWY_TARGET == HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, + const Vec128 no) { + return Vec128{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, + const Vec128 no) { + return Vec128{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. 
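+// Usage sketch (illustrative only; lane values chosen for clarity): the
+// signed-8-bit ShiftRight above emulates the missing _mm_srai_epi8 by
+// shifting as unsigned and then restoring the sign, e.g.
+//   const Full128<int8_t> d8;
+//   const auto r = ShiftRight<2>(Set(d8, int8_t{-64}));  // every lane is -16
+// because (0xC0 >> 2) == 0x30, 0x30 ^ 0x20 == 0x10, and 0x10 - 0x20 == -16.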
+ +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + return Vec128{_mm_shuffle_epi8(bytes.raw, from.raw)}; +} + +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; x86 anyway zeroes if >= 0x80. +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. +namespace detail { + +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {1, 0, 7, 6}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0302, 0x0100, 0x0f0e, 0x0d0c}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0, 3, 6, 5}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0100, 0x0706, 0x0d0c, 0x0b0a}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {2, 1, 4, 7}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0504, 0x0302, 0x0908, 0x0f0e}; + return Vec128{TableLookupBytes(ba, 
Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 1)}; +} + +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. 
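+// For example (a sketch; Iota/Set/Eq/IfThenElseZero are defined elsewhere in
+// these headers): on this target the mask is a compact bitfield, one bit per
+// lane:
+//   const Full128<int32_t> d;
+//   const auto m = Eq(Iota(d, 0), Set(d, 2));    // mask bits 0b0100
+//   IfThenElseZero(m, Set(d, 7));                // lanes {0, 0, 7, 0}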
+ +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<1> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<2> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<4> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<8> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Signed/float < +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpgt_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return 
Mask128{_mm_cmpgt_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<1> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<2> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<4> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<8> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi8(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi16(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi32(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi64(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ps(_mm_movm_epi32(v.raw))}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_pd(_mm_movm_epi64(v.raw))}; +} + +template +HWY_API Vec128 VecFromMask(Simd /* tag */, + const Mask128 v) { + return VecFromMask(v); +} + +#else // AVX2 or below + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
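+// For example (sketch): with int32 lanes {-2, -1, 0, 1},
+//   VecFromMask(d, v < Zero(d))
+// is {0xFFFFFFFF, 0xFFFFFFFF, 0, 0}; BroadcastSignBit and the blend-based
+// IfThenElse rely on exactly this all-ones / all-zeros lane representation.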
+ +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + const Simd d; + return MaskFromVec(BitCast(Simd(), VecFromMask(d, m))); +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +#else + return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; +#endif +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + // Same as signed ==; avoid duplicating the SSSE3 version. + const DFromV d; + RebindToUnsigned du; + return RebindMask(d, BitCast(du, a) == BitCast(du, b)); +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. 
+template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpneq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpneq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +namespace detail { + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi8(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // See https://stackoverflow.com/questions/65166174/: + const Simd d; + const RepartitionToNarrow d32; + const Vec128 m_eq32{Eq(BitCast(d32, a), BitCast(d32, b)).raw}; + const Vec128 m_gt32{Gt(BitCast(d32, a), BitCast(d32, b)).raw}; + // If a.upper is greater, upper := true. Otherwise, if a.upper == b.upper: + // upper := b-a (unsigned comparison result of lower). Otherwise: upper := 0. + const __m128i upper = OrAnd(m_gt32, m_eq32, Sub(b, a)).raw; + // Duplicate upper to lower half. 
+ return Mask128{_mm_shuffle_epi32(upper, _MM_SHUFFLE(3, 3, 1, 1))}; +#else + return Mask128{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 +#endif +} + +template +HWY_INLINE Mask128 Gt(hwy::UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); + const auto sa = BitCast(di, Xor(a, msb)); + const auto sb = BitCast(di, Xor(b, msb)); + return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +} + +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpge_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpge_pd(a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask128 FirstN(const Simd d, size_t num) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + const uint64_t all = (1ull << N) - 1; + // BZHI only looks at the lower 8 bits of num! + const uint64_t bits = (num > 255) ? all : _bzhi_u64(all, num); + return Mask128::FromBits(bits); +#else + const RebindToSigned di; // Signed comparisons are cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +#endif +} + +template +using MFromD = decltype(FirstN(D(), 0)); + +// ================================================== MEMORY (1) + +// Clang static analysis claims the memory immediately after a partial vector +// store is uninitialized, and also flags the input to partial loads (at least +// for loadl_pd) as "garbage". This is a false alarm because msan does not +// raise errors. We work around this by using CopyBytes instead of intrinsics, +// but only for the analyzer to avoid potentially bad code generation. +// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. 
+#ifndef HWY_SAFE_PARTIAL_LOAD_STORE +#if defined(__clang_analyzer__) || \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +#else +#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#endif +#endif // HWY_SAFE_PARTIAL_LOAD_STORE + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Full128 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{_mm_load_si128(reinterpret_cast(aligned))}; +} +HWY_API Vec128 Load(Full128 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ps(aligned)}; +} +HWY_API Vec128 Load(Full128 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec128{_mm_load_pd(aligned)}; +} + +template +HWY_API Vec128 LoadU(Full128 /* tag */, const T* HWY_RESTRICT p) { + return Vec128{_mm_loadu_si128(reinterpret_cast(p))}; +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const float* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ps(p)}; +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const double* HWY_RESTRICT p) { + return Vec128{_mm_loadu_pd(p)}; +} + +template +HWY_API Vec64 Load(Full64 /* tag */, const T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = _mm_setzero_si128(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_loadl_epi64(reinterpret_cast(p))}; +#endif +} + +HWY_API Vec128 Load(Full64 /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<8>(p, &v); // not same size + return Vec128{v}; +#else + const __m128 hi = _mm_setzero_ps(); + return Vec128{_mm_loadl_pi(hi, reinterpret_cast(p))}; +#endif +} + +HWY_API Vec64 Load(Full64 /* tag */, + const double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128d v = _mm_setzero_pd(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_load_sd(p)}; +#endif +} + +HWY_API Vec128 Load(Full32 /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<4>(p, &v); // not same size + return Vec128{v}; +#else + return Vec128{_mm_load_ss(p)}; +#endif +} + +// Any <= 32 bit except +template +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { + constexpr size_t kSize = sizeof(T) * N; +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes(p, &v); // not same size + return Vec128{v}; +#else + int32_t bits = 0; + CopyBytes(p, &bits); // not same size + return Vec128{_mm_cvtsi32_si128(bits)}; +#endif +} + +// For < 128 bit, LoadU == Load. +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// Returns a vector with lane i=[0, N) set to "first" + i. 
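+// For example (sketch): Iota(Full128<int32_t>(), 5) is {5, 6, 7, 8}, and
+// Iota(Full128<float>(), 0.5f) is {0.5f, 1.5f, 2.5f, 3.5f}.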
+template +HWY_API Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi16(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, + Simd /* tag */, + const float* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, + Simd /* tag */, + const double* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_pd(m.raw, p)}; +} + +#elif HWY_TARGET == HWY_AVX2 + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return Vec128{_mm_maskload_epi32(p_p, m.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return Vec128{_mm_maskload_epi64(p_p, m.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const float* HWY_RESTRICT p) { + const Vec128 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec128{_mm_maskload_ps(p, mi.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const double* HWY_RESTRICT p) { + const Vec128 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec128{_mm_maskload_pd(p, mi.raw)}; +} + +// There is no maskload_epi8/16, so blend instead. +template // 1 or 2 bytes +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, Load(d, p)); +} + +#else // <= SSE4 + +// Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). 
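+// Whichever path is taken, the returned vector matches the AVX3 behavior
+// above: selected lanes are loaded, the rest are zero. Usage sketch (`count`
+// and `unaligned` are hypothetical caller-side names):
+//   const Full128<float> d;
+//   const auto v = MaskedLoad(FirstN(d, count), d, unaligned);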
+template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, Load(d, p)); +} + +#endif + +// ------------------------------ Store + +template +HWY_API void Store(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT aligned) { + _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +HWY_API void Store(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT aligned) { + _mm_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT aligned) { + _mm_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT p) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT p) { + _mm_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT p) { + _mm_storeu_pd(p, v.raw); +} + +template +HWY_API void Store(Vec64 v, Full64 /* tag */, T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_epi64(reinterpret_cast<__m128i*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec128 v, Full64 /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec64 v, Full64 /* tag */, + double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pd(p, v.raw); +#endif +} + +// Any <= 32 bit except +template +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { + CopyBytes(&v, p); // not same size +} +HWY_API void Store(const Vec128 v, Full32 /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<4>(&v, p); // not same size +#else + _mm_store_ss(p, v.raw); +#endif +} + +// For < 128 bit, StoreU == Store. +template +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ BlendedStore + +namespace detail { + +// There is no maskload_epi8/16 with which we could safely implement +// BlendedStore. Manual blending is also unsafe because loading a full vector +// that crosses the array end causes asan faults. Resort to scalar code; the +// caller should instead use memcpy, assuming m is FirstN(d, n). +template +HWY_API void ScalarMaskedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + const RebindToSigned di; // for testing mask if T=bfloat16_t. 
+ using TI = TFromD; + alignas(16) TI buf[N]; + alignas(16) TI mask[N]; + Store(BitCast(di, v), di, buf); + Store(BitCast(di, VecFromMask(d, m)), di, mask); + for (size_t i = 0; i < N; ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + _mm_mask_storeu_epi8(p, m.raw, v.raw); +} +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + _mm_mask_storeu_epi16(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd, float* HWY_RESTRICT p) { + _mm_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd, double* HWY_RESTRICT p) { + _mm_mask_storeu_pd(p, m.raw, v.raw); +} + +#elif HWY_TARGET == HWY_AVX2 + +template // 1 or 2 bytes +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + detail::ScalarMaskedStore(v, m, d, p); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 4) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 2) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd d, float* HWY_RESTRICT p) { + using T = float; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 4) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + const Vec128, N> mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm_maskstore_ps(p, mi.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd d, double* HWY_RESTRICT p) { + using T = double; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 2) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + const Vec128, N> mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm_maskstore_pd(p, mi.raw, v.raw); +} + +#else // <= SSE4 + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + // Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). 
+ detail::ScalarMaskedStore(v, m, d, p); +} + +#endif // SSE4 + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128{_mm_sad_epu8(v.raw, _mm_setzero_si128())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. 
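+// For example (sketch): with uint8_t lanes,
+//   SaturatedAdd(Set(d, uint8_t{200}), Set(d, uint8_t{100}))
+// yields 255 in every lane instead of wrapping around to 44.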
+ +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epu16(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epi16(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epu32(a.raw, b.raw)}; +} + +#if HWY_TARGET == HWY_SSSE3 + +template // N=1 or 2 +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Set(Simd(), + static_cast(GetLane(a)) * GetLane(b)); +} +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) int32_t a_lanes[4]; + alignas(16) int32_t b_lanes[4]; + const Full128 di32; + Store(a, di32, a_lanes); + Store(b, di32, b_lanes); + alignas(16) int64_t mul[2]; + mul[0] = static_cast(a_lanes[0]) * b_lanes[0]; + mul[1] = static_cast(a_lanes[2]) * b_lanes[2]; + return Load(Full128(), mul); +} + +#else // HWY_TARGET == HWY_SSSE3 + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epi32(a.raw, b.raw)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // Not as inefficient as it looks: _mm_mullo_epi32 has 10 cycle latency. + // 64-bit right shift would also work but also needs port 5, so no benefit. + // Notation: x=don't care, z=0. 
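+  // Data-flow sketch (lane 0 rightmost): MulEven(a, b) yields the 64-bit
+  // products a2*b2 | a0*b0. Broadcasting the odd lanes via shuffle(3,3,1,1)
+  // and repeating MulEven yields a3*b3 | a1*b1. Keeping only the low 32 bits
+  // of each product and interleaving them restores lane order
+  // d3 d2 d1 d0 = (a*b) per lane.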
+ const __m128i a_x3x1 = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x2x0 = MulEven(a, b); + const __m128i b_x3x1 = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x3x1 = + MulEven(Vec128{a_x3x1}, Vec128{b_x3x1}); + // We could _mm_slli_epi64 by 32 to get 3z1z and OR with z2z0, but generating + // the latter requires one more instruction or a constant. + const __m128i mul_20 = + _mm_shuffle_epi32(mullo_x2x0.raw, _MM_SHUFFLE(2, 0, 2, 0)); + const __m128i mul_31 = + _mm_shuffle_epi32(mullo_x3x1.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(mul_20, mul_31)}; +#else + return Vec128{_mm_mullo_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + // Same as unsigned; avoid duplicating the SSSE3 code. + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +// ------------------------------ RotateRight (ShiftRight, Or) + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<15>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<31>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Vec128{_mm_srai_epi64(v.raw, 63)}; +#elif HWY_TARGET == HWY_AVX2 || HWY_TARGET == HWY_SSE4 + return VecFromMask(v < Zero(d)); +#else + // Efficient Lt() requires SSE4.2 and BLENDVPD requires SSE4.1. 32-bit shift + // avoids generating a zero. 
+ const RepartitionToNarrow d32; + const auto sign = ShiftRight<31>(BitCast(d32, v)); + return Vec128{ + _mm_shuffle_epi32(sign.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +#endif +} + +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_abs_epi64(v.raw)}; +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srai_epi64(v.raw, kBits)}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +// ------------------------------ ZeroIfNegative (BroadcastSignBit) +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + static_assert(IsFloat(), "Only works for float"); + const DFromV d; +#if HWY_TARGET == HWY_SSSE3 + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#else + const auto mask = MaskFromVec(v); // MSB is sufficient for BLENDVPS +#endif + return IfThenElse(mask, Zero(d), v); +} + +// ------------------------------ IfNegativeThenElse +template +HWY_API Vec128 IfNegativeThenElse(const Vec128 v, + const Vec128 yes, + const Vec128 no) { + // int8: IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToFloat df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. + return BitCast(d, IfThenElse(MaskFromVec(BitCast(df, v)), BitCast(df, yes), + BitCast(df, no))); +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
+ const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator*(const Vec64 a, const Vec64 b) { + return Vec64{_mm_mul_sd(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator/(const Vec64 a, const Vec64 b) { + return Vec64{_mm_div_sd(a.raw, b.raw)}; +} + +// Approximate reciprocal +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ss(v.raw)}; +} + +// Absolute value of difference. 
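+// For example (sketch): AbsDiff(Set(d, 3.0f), Set(d, 5.0f)) equals
+// Set(d, 2.0f) in every lane.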
+template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128{_mm_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128{_mm_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128{_mm_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128{_mm_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128{_mm_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128{_mm_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_ps(v.raw)}; +} +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_ss(v.raw)}; +} +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_pd(v.raw)}; +} +HWY_API Vec64 Sqrt(const Vec64 v) { + return Vec64{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; +} + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128{_mm_rsqrt_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128{_mm_rsqrt_ss(v.raw)}; +} + +// ------------------------------ Min (Gt, IfThenElse) + +namespace detail { + +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MinU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epu8(a.raw, b.raw)}; +} +template 
+HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128{_mm_min_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128{_mm_min_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epu64(a.raw, b.raw)}; +#else + return detail::MinU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +namespace detail { +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MaxU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128{_mm_max_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128{_mm_max_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epu64(a.raw, b.raw)}; +#else + return detail::MaxU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return 
Vec128{_mm_max_pd(a.raw, b.raw)}; +} + +// ================================================== MEMORY (2) + +// ------------------------------ Non-temporal stores + +// On clang6, we see incorrect code generated for _mm_stream_pi, so +// round even partial vectors up to 16 bytes. +template +HWY_API void Stream(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + _mm_stream_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +template +HWY_API void Stream(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT aligned) { + _mm_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(const Vec128 v, Simd /* tag */, + double* HWY_RESTRICT aligned) { + _mm_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unfortunately the GCC/Clang intrinsics do not accept int64_t*. +using GatherIndex64 = long long int; // NOLINT(runtime/int) +static_assert(sizeof(GatherIndex64) == 8, "Must be 64-bit type"); + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 4) { + _mm_i32scatter_epi32(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 index) { + if (N == 4) { + _mm_i32scatter_epi32(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, index.raw, v.raw, 4); + } +} + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 2) { + _mm_i64scatter_epi64(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 index) { + if (N == 2) { + _mm_i64scatter_epi64(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, index.raw, v.raw, 8); + } +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +template +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, + float* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 4) { + _mm_i32scatter_ps(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_ps(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, + float* HWY_RESTRICT base, + const Vec128 index) { + if (N == 4) { + _mm_i32scatter_ps(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + 
_mm_mask_i32scatter_ps(base, mask, index.raw, v.raw, 4); + } +} + +template +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, + double* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 2) { + _mm_i64scatter_pd(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, + double* HWY_RESTRICT base, + const Vec128 index) { + if (N == 2) { + _mm_i64scatter_pd(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, index.raw, v.raw, 8); + } +} +#else // HWY_TARGET <= HWY_AVX3 + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather (Load/Store) + +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +#else + +namespace detail { + +template +HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<4> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<4> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), index.raw, 4)}; +} + +template +HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<8> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<8> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), index.raw, 8)}; +} + +} // namespace detail + +template +HWY_API Vec128 GatherOffset(Simd d, const T* HWY_RESTRICT base, + const 
Vec128 offset) { + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec128 GatherIndex(Simd d, const T* HWY_RESTRICT base, + const Vec128 index) { + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, + const float* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i32gather_ps(base, offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const float* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i32gather_ps(base, index.raw, 4)}; +} + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, + const double* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i64gather_pd(base, offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const double* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i64gather_pd(base, index.raw, 8)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +// Returns upper/lower half of a vector. +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return LowerHalf(Simd(), v); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec128{_mm_slli_si128(v.raw, kBytes)}; +} + +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // For partial vectors, clear upper lanes so we shift in zeros. + if (N != 16 / sizeof(T)) { + const Vec128 vfull{v.raw}; + v = Vec128{IfThenElseZero(FirstN(Full128(), N), vfull).raw}; + } + return Vec128{_mm_srli_si128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). 
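+// For example, _mm_unpackhi_epi64(v, v) repeats the upper 64 bits of v in
+// both halves of the result; callers only use the lower 64 bits of the
+// returned half-vector, so with uint32_t lanes {0, 1, 2, 3} (lane 0 first),
+// UpperHalf returns the two lanes {2, 3}.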
+template +HWY_API Vec64 UpperHalf(Half> /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_epi64(v.raw, v.raw)}; +} +HWY_API Vec128 UpperHalf(Full64 /* tag */, Vec128 v) { + return Vec128{_mm_movehl_ps(v.raw, v.raw)}; +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_pd(v.raw, v.raw)}; +} + +// Partial +template +HWY_API Vec128 UpperHalf(Half> /* tag */, + Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); + return Vec128{upper.raw}; +} + +// ------------------------------ ExtractLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const int pair = _mm_extract_epi16(v.raw, kLane / 2); + constexpr int kShift = kLane & 1 ? 8 : 0; + return static_cast((pair >> kShift) & 0xFF); +#else + return static_cast(_mm_extract_epi8(v.raw, kLane) & 0xFF); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + return static_cast(_mm_extract_epi16(v.raw, kLane) & 0xFFFF); +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + return static_cast(_mm_extract_epi32(v.raw, kLane)); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 || HWY_ARCH_X86_32 + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + return static_cast(_mm_extract_epi64(v.raw, kLane)); +#endif +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) float lanes[4]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + // Bug in the intrinsic, returns int but should be float. + const int32_t bits = _mm_extract_ps(v.raw, kLane); + float ret; + CopySameSize(&bits, &ret); + return ret; +#endif +} + +// There is no extract_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE double ExtractLane(const Vec128 v) { + static_assert(kLane == 0, "Lane index out of bounds"); + return GetLane(v); +} + +template +HWY_INLINE double ExtractLane(const Vec128 v) { + static_assert(kLane < 2, "Lane index out of bounds"); + const Half> dh; + return kLane == 0 ? GetLane(v) : GetLane(UpperHalf(dh, v)); +} + +} // namespace detail + +// Requires one overload per vector length because ExtractLane<3> may be a +// compile error if it calls _mm_extract_epi64. 
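+// In non-debug builds on GCC/Clang, the overloads below use
+// __builtin_constant_p so that a call such as ExtractLane(v, 2) with a
+// literal index still compiles down to a single extract instruction; a truly
+// variable index falls back to storing the vector and reading that lane from
+// memory.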
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + return Vec128{_mm_insert_epi8(v.raw, t, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{_mm_insert_epi16(v.raw, t, kLane)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) T lanes[4]; + const DFromV d; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + MakeSigned ti; + CopySameSize(&t, &ti); // don't just cast because T might be float. 
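+  // (A plain cast would convert the value: inserting 1.0f must insert the
+  // bit pattern 0x3F800000, not the integer 1. CopySameSize is a
+  // memcpy-style bit copy that compilers typically fold into a register
+  // move.)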
+ return Vec128{_mm_insert_epi32(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 || HWY_ARCH_X86_32 + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + MakeSigned ti; + CopySameSize(&t, &ti); // don't just cast because T might be float. + return Vec128{_mm_insert_epi64(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + alignas(16) float lanes[4]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + return Vec128{_mm_insert_ps(v.raw, _mm_set_ss(t), kLane << 4)}; +#endif +} + +// There is no insert_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane == 0, "Lane index out of bounds"); + return Set(DFromV(), t); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + const DFromV d; + const Vec128 vt = Set(d, t); + if (kLane == 0) { + return Vec128{_mm_shuffle_pd(vt.raw, v.raw, 2)}; + } + return Vec128{_mm_shuffle_pd(v.raw, vt.raw, 0)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls _mm_insert_epi64. + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return 
detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full128 d, V hi, V lo) { + const Repartition d8; + return BitCast(d, Vec128{_mm_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +template > +HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full128 d_full8; + using V8 = VFromD; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(d_full8, hi8, lo8); + return V{BitCast(Full128(), r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128{_mm_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128{_mm_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 
0xEE : 0x44)}; +} + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 3 * kLane)}; +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template +struct Indices128 { + __m128i raw; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, N)))); +#endif + +#if HWY_TARGET <= HWY_AVX2 + (void)d; + return Indices128{vec.raw}; +#else + const Repartition d8; + using V8 = VFromD; + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + + // Broadcast each lane index to all 4 bytes of T + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif +} + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(N))))); +#else + (void)d; +#endif + + // No change - even without AVX3, we can shuffle+blend. + return Indices128{vec.raw}; +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + const DFromV d; + const RebindToFloat df; + const Vec128 perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)}; + return BitCast(d, perm); +#else + return TableLookupBytes(v, Vec128{idx.raw}); +#endif +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + return Vec128{_mm_permutevar_ps(v.raw, idx.raw)}; +#else + const DFromV df; + const RebindToSigned di; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +#endif +} + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 /* idx */) { + return v; +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const Full128 d; + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + // There is no _mm_permute[x]var_epi64. + vidx += vidx; // bit1 is the decider (unusual) + const Full128 df; + return BitCast( + d, Vec128{_mm_permutevar_pd(BitCast(df, v).raw, vidx.raw)}); +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. 
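+  // E.g. vidx = {0, 1} (identity): (vidx ^ iota) - 1 is all-ones in both
+  // lanes, so the mask is true and v is returned unchanged; vidx = {1, 0}
+  // gives {0, 0}, so both lanes are taken from Shuffle01(v).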
+ const Full128 di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + vidx += vidx; // bit1 is the decider (unusual) + return Vec128{_mm_permutevar_pd(v.raw, vidx.raw)}; +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const Full128 d; + const Full128 di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return v; +} + +// Two lanes: shuffle +template +HWY_API Vec128 Reverse(Full64 /* tag */, const Vec128 v) { + return Vec128{Shuffle2301(Vec128{v.raw}).raw}; +} + +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + if (N == 1) return v; + if (N == 2) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); + } + const RebindToSigned di; + alignas(16) constexpr int16_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + const Vec128 idx = Load(di, kReverse + (N == 8 ? 0 : 4)); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +#endif +} + +// ------------------------------ Reverse2 + +// Single lane: no change +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return v; +} + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + alignas(16) const T kShuffle[16] = {1, 0, 3, 2, 5, 4, 7, 6, + 9, 8, 11, 10, 13, 12, 15, 14}; + return TableLookupBytes(v, Load(d, kShuffle)); +} + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + const RebindToSigned di; + // 4x 16-bit: a single shufflelo suffices. 
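+  // _MM_SHUFFLE(0, 1, 2, 3) writes source lane 3 - i to output lane i,
+  // reversing the four low 16-bit lanes; for N == 4 the upper 64 bits are
+  // don't-care.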
+ if (N == 4) { + return BitCast(d, Vec128{_mm_shufflelo_epi16( + BitCast(di, v).raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET <= HWY_AVX3 + alignas(16) constexpr int16_t kReverse4[8] = {3, 2, 1, 0, 7, 6, 5, 4}; + const Vec128 idx = Load(di, kReverse4); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +// 4x 32-bit: use Shuffle0123 +template +HWY_API Vec128 Reverse4(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, Vec128 /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec128 idx = Load(di, kReverse8); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec128 Reverse8(Simd /* tag */, Vec128 /* v */) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; +} + +// Additional overload for the optional tag (also for 256/512). +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. 
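+// For reference, with uint32_t lanes a = {0, 1, 2, 3} and b = {4, 5, 6, 7}
+// (lane 0 first), InterleaveLower above yields {0, 4, 1, 5} and the
+// InterleaveUpper helpers below yield {2, 6, 3, 7}.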
+namespace detail { + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +// Full +template > +HWY_API V InterleaveUpper(Full128 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template > +HWY_API V InterleaveUpper(Simd d, V a, V b) { + const Half d2; + return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template +HWY_API Vec128 Combine(Simd d, Vec128 hi_half, + Vec128 lo_half) { + const Half d2; + const RebindToUnsigned du2; + // Treat half-width input as one lane, and expand to two lanes. 
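+  // E.g. for a 64-bit vector of two uint32_t lanes, each 32-bit half is
+  // viewed as a single unsigned lane; InterleaveLower then places lo_half in
+  // lane 0 and hi_half in lane 1, i.e. in the lower and upper halves of the
+  // result.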
+ using VU = Vec128, 2>; + const VU lo{BitCast(du2, lo_half).raw}; + const VU hi{BitCast(du2, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 ZeroExtendVector(hwy::NonFloatTag /*tag*/, + Full128 /* d */, Vec64 lo) { + return Vec128{_mm_move_epi64(lo.raw)}; +} + +template +HWY_INLINE Vec128 ZeroExtendVector(hwy::FloatTag /*tag*/, Full128 d, + Vec64 lo) { + const RebindToUnsigned du; + return BitCast(d, ZeroExtendVector(du, BitCast(Half(), lo))); +} + +} // namespace detail + +template +HWY_API Vec128 ZeroExtendVector(Full128 d, Vec64 lo) { + return detail::ZeroExtendVector(hwy::IsFloatTag(), d, lo); +} + +template +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { + return IfThenElseZero(FirstN(d, N / 2), Vec128{lo.raw}); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec128 ConcatLowerLower(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec128 ConcatUpperUpper(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_API Vec128 ConcatLowerUpper(Full128 d, const Vec128 hi, + const Vec128 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec128 ConcatUpperLower(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dd; +#if HWY_TARGET == HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd(BitCast(dd, lo).raw, BitCast(dd, hi).raw, + _MM_SHUFFLE2(1, 0))}); +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _pd can do 3/cycle. + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +HWY_API Vec128 ConcatUpperLower(Full128 d, Vec128 hi, + Vec128 lo) { +#if HWY_TARGET == HWY_SSSE3 + (void)d; + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 2, 1, 0))}; +#else + // _mm_shuffle_ps has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + const RepartitionToWide dd; + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +HWY_API Vec128 ConcatUpperLower(Full128 /* tag */, + Vec128 hi, Vec128 lo) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_shuffle_pd(lo.raw, hi.raw, _MM_SHUFFLE2(1, 0))}; +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. 
+ return Vec128{_mm_blend_pd(hi.raw, lo.raw, 1)}; +#endif +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API Vec128 ConcatLowerLower(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatUpperUpper(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); + return Vec128{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API Vec64 ConcatOdd(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[8] = {1, 3, 5, 7}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactOddU8)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 8-bit x4 +template +HWY_API Vec32 ConcatOdd(Simd d, Vec32 hi, Vec32 lo) { + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[4] = {1, 3}; + const Vec32 shuf = BitCast(d, Load(Full32(), kCompactOddU8)); + const Vec32 L = TableLookupBytes(lo, shuf); + const Vec32 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +} + +// 16-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns + // 0xFFFF8000, which correctly saturates to 0x8000. + const Repartition dw; + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); + return Vec128{_mm_packs_epi32(uL.raw, uH.raw)}; +} + +// 16-bit x4 +template +HWY_API Vec64 ConcatOdd(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. 
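+  // Byte indices 2,3 and 6,7 pick the odd u16 lanes (1 and 3) of each 64-bit
+  // input; the remaining table entries are zero-initialized, and only the
+  // lower 32 bits of each TableLookupBytes result are used.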
+ alignas(16) const uint8_t kCompactOddU16[8] = {2, 3, 6, 7}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactOddU16)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 32-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(3, 1, 3, 1))}); +} +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, + Vec128 lo) { + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; +} + +// Any type x2 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec128 mask = Set(dw, 0x00FF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return Vec128{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API Vec64 ConcatEven(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[8] = {0, 2, 4, 6}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactEvenU8)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 8-bit x4 +template +HWY_API Vec32 ConcatEven(Simd d, Vec32 hi, Vec32 lo) { + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[4] = {0, 2}; + const Vec32 shuf = BitCast(d, Load(Full32(), kCompactEvenU8)); + const Vec32 L = TableLookupBytes(lo, shuf); + const Vec32 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +} + +// 16-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { +#if HWY_TARGET <= HWY_SSE4 + // Isolate lower 16 bits per u32 so we can pack. + const Repartition dw; + const Vec128 mask = Set(dw, 0x0000FFFF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return Vec128{_mm_packus_epi32(uL.raw, uH.raw)}; +#else + // packs_epi32 saturates 0x8000 to 0x7FFF. Instead ConcatEven within the two + // inputs, then concatenate them. + alignas(16) const T kCompactEvenU16[8] = {0x0100, 0x0504, 0x0908, 0x0D0C}; + const Vec128 shuf = BitCast(d, Load(d, kCompactEvenU16)); + const Vec128 L = TableLookupBytes(lo, shuf); + const Vec128 H = TableLookupBytes(hi, shuf); + return ConcatLowerLower(d, H, L); +#endif +} + +// 16-bit x4 +template +HWY_API Vec64 ConcatEven(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. 
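+  // As in ConcatOdd above, but byte indices 0,1 and 4,5 pick the even u16
+  // lanes (0 and 2) of each 64-bit input.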
+ alignas(16) const uint8_t kCompactEvenU16[8] = {0, 1, 4, 5}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactEvenU16)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 32-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, + Vec128 lo) { + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; +} + +// Any T x2 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + const Repartition d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, + 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +#else + return Vec128{_mm_blend_epi16(a.raw, b.raw, 0x55)}; +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i odd = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128i even = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(even, odd)}; +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _ps can do 3/cycle. + const DFromV d; + const RebindToFloat df; + return BitCast(d, Vec128{_mm_blend_ps(BitCast(df, a).raw, + BitCast(df, b).raw, 5)}); +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + // Same as ConcatUpperLower for full vectors; do not call that because this + // is more efficient for 64x1 vectors. + const DFromV d; + const RebindToFloat dd; +#if HWY_TARGET == HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd( + BitCast(dd, b).raw, BitCast(dd, a).raw, _MM_SHUFFLE2(1, 0))}); +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. 
+ return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, a).raw, + BitCast(dd, b).raw, 1)}); +#endif +} + +template +HWY_API Vec128 OddEven(Vec128 a, Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // SHUFPS must fill the lower half of the output from one input, so we + // need another shuffle. Unpack avoids another immediate byte. + const __m128 odd = _mm_shuffle_ps(a.raw, a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128 even = _mm_shuffle_ps(b.raw, b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_ps(even, odd)}; +#else + return Vec128{_mm_blend_ps(a.raw, b.raw, 5)}; +#endif +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ Shl (ZipLower, Mul) + +// Use AVX2/3 variable shifts where available, otherwise multiply by powers of +// two from loading float exponents, which is considerably faster (according +// to LLVM-MCA) than scalar or testing bits: https://gcc.godbolt.org/z/9G7Y9v. + +namespace detail { +#if HWY_TARGET > HWY_AVX3 // AVX2 or older + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const RepartitionToWide dw; + const Rebind df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // See comment below. + const Vec128 bits0{_mm_cvtps_epi32(BitCast(df, f0).raw)}; + const Vec128 bits1{_mm_cvtps_epi32(BitCast(df, f1).raw)}; + return Vec128, N>{_mm_packus_epi32(bits0.raw, bits1.raw)}; +} + +// Same, for 32-bit shifts. +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const auto exp = ShiftLeft<23>(v); + const auto f = exp + Set(d, 0x3F800000); // 1.0f + // Do not use ConvertTo because we rely on the native 0x80..00 overflow + // behavior. cvt instead of cvtt should be equivalent, but avoids test + // failure under GCC 10.2.1. 
+ return Vec128, N>{_mm_cvtps_epi32(_mm_castsi128_ps(f.raw))}; +} + +#endif // HWY_TARGET > HWY_AVX3 + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_sllv_epi16(v.raw, bits.raw)}; +#else + return v * Pow2(bits); +#endif +} +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { + return Vec128{_mm_sll_epi16(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return v * Pow2(bits); +#else + return Vec128{_mm_sllv_epi32(v.raw, bits.raw)}; +#endif +} +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sll_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const Vec128 out0{_mm_sll_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_sll_epi64(v.raw, bits1)}; + return ConcatUpperLower(Full128(), out1, out0); +#else + return Vec128{_mm_sllv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 Shl(hwy::UnsignedTag /*tag*/, Vec64 v, + Vec64 bits) { + return Vec64{_mm_sll_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec128 Shl(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (mul, mask, BroadcastSignBit) + +// Use AVX2+ variable shifts except for SSSE3/SSE4 or 16-bit. There, we use +// widening multiplication by powers of two obtained by loading float exponents, +// followed by a constant right-shift. This is still faster than a scalar or +// bit-test approach: https://gcc.godbolt.org/z/9G7Y9v. + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srlv_epi16(in.raw, bits.raw)}; +#else + const Simd d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + const auto out = MulHigh(in, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), in, out); +#endif +} +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { + return Vec128{_mm_srl_epi16(in.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // 32x32 -> 64 bit mul, then shift right by 32. + const Simd d32; + // Move odd lanes into position for the second mul. Shuffle more gracefully + // handles N=1 than repartitioning to u64 and shifting 32 bits right. + const Vec128 in31{_mm_shuffle_epi32(in.raw, 0x31)}; + // For bits=0, we cannot mul by 2^32, so fix the result later. + const auto mul = detail::Pow2(Set(d32, 32) - bits); + const auto out20 = ShiftRight<32>(MulEven(in, mul)); // z 2 z 0 + const Vec128 mul31{_mm_shuffle_epi32(mul.raw, 0x31)}; + // No need to shift right, already in the correct position. + const auto out31 = BitCast(d32, MulEven(in31, mul31)); // 3 ? 1 ? 
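+  // The lane diagrams list lanes from most- to least-significant: "z 2 z 0"
+  // has the even results in lanes 0 and 2 (z = don't-care), "3 ? 1 ?" has the
+  // odd results in lanes 1 and 3; OddEven below merges them into "3 2 1 0".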
+ const Vec128 out = OddEven(out31, BitCast(d32, out20)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d32), in, out); +#else + return Vec128{_mm_srlv_epi32(in.raw, bits.raw)}; +#endif +} +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { + return Vec128{_mm_srl_epi32(in.raw, bits.raw)}; +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const Vec128 out0{_mm_srl_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_srl_epi64(v.raw, bits1)}; + return ConcatUpperLower(Full128(), out1, out0); +#else + return Vec128{_mm_srlv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64{_mm_srl_epi64(v.raw, bits.raw)}; +} + +#if HWY_TARGET > HWY_AVX3 // AVX2 or older +namespace detail { + +// Also used in x86_256-inl.h. +template +HWY_INLINE V SignedShr(const DI di, const V v, const V count_i) { + const RebindToUnsigned du; + const auto count = BitCast(du, count_i); // same type as value to shift + // Clear sign and restore afterwards. This is preferable to shifting the MSB + // downwards because Shr is somewhat more expensive than Shl. + const auto sign = BroadcastSignBit(v); + const auto abs = BitCast(du, v ^ sign); // off by one, but fixed below + return BitCast(di, abs >> count) ^ sign; +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sra_epi16(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi32(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sra_epi32(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +} + +HWY_INLINE Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + const Half> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template > +HWY_API V ReorderWidenMulAccumulate(Simd df32, VFromD a, + VFromD b, const V sum0, V& sum1) { + // TODO(janwas): _mm_dpbf16_ps when available + const RebindToUnsigned du32; + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo. Using shift/and instead of Zip + // leads to the odd/even order that RearrangeToOddPlusEven prefers. 
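+  // A bfloat16 is the upper half of a binary32, so ShiftLeft<16> of the even
+  // (lower) 16-bit halves and And with 0xFFFF0000 of the odd (upper) halves
+  // each produce valid float bit patterns directly.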
+ using VU32 = VFromD; + const VU32 odd = Set(du32, 0xFFFF0000u); + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. +template +HWY_API Vec128 ReorderWidenMulAccumulate( + Simd /*d32*/, Vec128 a, + Vec128 b, const Vec128 sum0, + Vec128& /*sum1*/) { + return sum0 + Vec128{_mm_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + return Vec128{_mm_unpacklo_epi8(v.raw, zero)}; +#else + return Vec128{_mm_cvtepu8_epi16(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; +#else + return Vec128{_mm_cvtepu16_epi32(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; +#else + return Vec128{_mm_cvtepu32_epi64(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + const __m128i u16 = _mm_unpacklo_epi8(v.raw, zero); + return Vec128{_mm_unpacklo_epi16(u16, zero)}; +#else + return Vec128{_mm_cvtepu8_epi32(v.raw)}; +#endif +} + +// Unsigned to signed: same plus cast. +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} + +// Signed: replicate sign bit. 
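+// A minimal scalar sketch of the SSSE3 sign-extension idiom used below
+// (hypothetical helper, not part of the API). It assumes the usual
+// two's-complement / arithmetic-shift behavior, which is also what the
+// vector ShiftRight provides for signed lanes.
+static inline int16_t SketchSignExtend8To16(uint8_t byte) {
+  // unpacklo(v, v): the byte is copied into both halves of the 16-bit lane,
+  // so bit 15 now holds a copy of the sign bit.
+  const uint16_t doubled = static_cast<uint16_t>((uint16_t{byte} << 8) | byte);
+  // An arithmetic ShiftRight<8> replicates that bit into the high byte,
+  // completing the sign extension; this mirrors ShiftRight<8>(unpacklo(v, v)).
+  return static_cast<int16_t>(static_cast<int16_t>(doubled) >> 8);
+}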
+template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<8>(Vec128{_mm_unpacklo_epi8(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi8_epi16(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<16>(Vec128{_mm_unpacklo_epi16(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi16_epi32(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<32>(Vec128{_mm_unpacklo_epi32(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi32_epi64(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); + const __m128i x4 = _mm_unpacklo_epi16(x2, x2); + return ShiftRight<24>(Vec128{x4}); +#else + return Vec128{_mm_cvtepi8_epi32(v.raw)}; +#endif +} + +// Workaround for origin tracking bug in Clang msan prior to 11.0 +// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") +#if HWY_IS_MSAN && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) +#define HWY_INLINE_F16 HWY_NOINLINE +#else +#define HWY_INLINE_F16 HWY_INLINE +#endif +template +HWY_INLINE_F16 Vec128 PromoteTo(Simd df32, + const Vec128 v) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +#else + (void)df32; + return Vec128{_mm_cvtph_ps(v.raw)}; +#endif +} + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtps_pd(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_pd(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const Simd di32; + const Simd du16; + const auto zero_if_neg = AndNot(ShiftRight<31>(v), v); + const auto too_big = VecFromMask(di32, Gt(v, Set(di32, 0xFFFF))); + const auto clamped = Or(zero_if_neg, too_big); + // Lower 2 bytes from each 32-bit lane; same as return type for fewer casts. 
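+  // Note on the table below, for reference: each u16 entry packs two PSHUFB
+  // byte indices, low byte first. 0x0100 selects bytes 0..1 (the low half of
+  // lane 0), 0x0504 bytes 4..5, and so on; 0x80 has the high bit set, which
+  // makes PSHUFB write a zero byte, clearing the unused upper half.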
+ alignas(16) constexpr uint16_t kLower2Bytes[16] = { + 0x0100, 0x0504, 0x0908, 0x0D0C, 0x8080, 0x8080, 0x8080, 0x8080}; + const auto lo2 = Load(du16, kLower2Bytes); + return Vec128{TableLookupBytes(BitCast(du16, clamped), lo2).raw}; +#else + return Vec128{_mm_packus_epi32(v.raw, v.raw)}; +#endif +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packs_epi32(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128{_mm_packus_epi16(i16, i16)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packus_epi16(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128{_mm_packs_epi16(i16, i16)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packs_epi16(v.raw, v.raw)}; +} + +// Work around MSVC warning for _mm_cvtps_ph (8 is actually a valid immediate). +// clang-cl requires a non-empty string, so we 'ignore' the irrelevant -Wmain. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wmain") + +template +HWY_API Vec128 DemoteTo(Simd df16, + const Vec128 v) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return BitCast(df16, DemoteTo(du16, bits16)); +#else + (void)df16; + return Vec128{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +#endif +} + +HWY_DIAGNOSTICS(pop) + +template +HWY_API Vec128 DemoteTo(Simd dbf16, + const Vec128 v) { + // TODO(janwas): _mm_cvtneps_pbh once we have avx512bf16. + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + // TODO(janwas): _mm_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned du16; + const Repartition du32; + const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +// Specializations for partial vectors because packs_epi32 sets lanes above 2*N. +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. 
+ const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Full128 /*d16*/, + Vec128 a, Vec128 b) { + return Vec128{_mm_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtpd_ps(v.raw)}; +} + +namespace detail { + +// For well-defined float->int demotion in all x86_*-inl.h. + +template +HWY_INLINE auto ClampF64ToI32Max(Simd d, decltype(Zero(d)) v) + -> decltype(Zero(d)) { + // The max can be exactly represented in binary64, so clamping beforehand + // prevents x86 conversion from raising an exception and returning 80..00. + return Min(v, Set(d, 2147483647.0)); +} + +// For ConvertTo float->int of same size, clamping before conversion would +// change the result because the max integer value is not exactly representable. +// Instead detect the overflow result after conversion and fix it. +template > +HWY_INLINE auto FixConversionOverflow(DI di, VFromD original, + decltype(Zero(di).raw) converted_raw) + -> VFromD { + // Combinations of original and output sign: + // --: normal <0 or -huge_val to 80..00: OK + // -+: -0 to 0 : OK + // +-: +huge_val to 80..00 : xor with FF..FF to get 7F..FF + // ++: normal >0 : OK + const auto converted = VFromD{converted_raw}; + const auto sign_wrong = AndNot(BitCast(di, original), converted); +#if HWY_COMPILER_GCC_ACTUAL + // Critical GCC 11 compiler bug (possibly also GCC 10): omits the Xor; also + // Add() if using that instead. Work around with one more instruction. + const RebindToUnsigned du; + const VFromD mask = BroadcastSignBit(sign_wrong); + const VFromD max = BitCast(di, ShiftRight<1>(BitCast(du, mask))); + return IfVecThenElse(mask, max, converted); +#else + return Xor(converted, BroadcastSignBit(sign_wrong)); +#endif +} + +} // namespace detail + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto clamped = detail::ClampF64ToI32Max(Simd(), v); + return Vec128{_mm_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const Simd d32; + const Simd d8; + alignas(16) static constexpr uint32_t k8From32[4] = { + 0x0C080400u, 0x0C080400u, 0x0C080400u, 0x0C080400u}; + // Also replicate bytes into all 32 bit lanes for safety. 
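+  // For reference: 0x0C080400 is the byte-index quadruple {0x00, 0x04, 0x08,
+  // 0x0C} in little-endian order, i.e. the low byte of each 32-bit lane.
+  // Repeating it in every lane means each 32-bit chunk of `quad` holds the
+  // same packed result, so the double LowerHalf below is safe for any N.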
+ const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + return LowerHalf(LowerHalf(BitCast(d8, quad))); +} + +// ------------------------------ Truncations + +template * = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + static_assert(!IsSigned() && !IsSigned(), "Unsigned only"); + const Repartition> d; + const auto v1 = BitCast(d, v); + return Vec128{v1.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d8; + alignas(16) static constexpr uint8_t kMap[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + return LowerHalf(LowerHalf(LowerHalf(TableLookupBytes(v, Load(d8, kMap))))); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d16; + alignas(16) static constexpr uint16_t kMap[8] = { + 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d16, kMap)))); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x88)}; +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + alignas(16) static constexpr uint8_t kMap[16] = { + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu, + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d, kMap)))); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_ps(v.raw)}; +} + +template +HWY_API Vec128 ConvertTo(HWY_MAYBE_UNUSED Simd df, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_cvtepu32_ps(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned du32; + const RebindToSigned d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +#endif +} + +template +HWY_API Vec128 ConvertTo(Simd dd, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dd; + return Vec128{_mm_cvtepi64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! 
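+  // Decoding the magic constants above, for reference: 0x43300000 is the
+  // high word of the double 2^52, so OddEven(k52, v) builds 2^52 + lower32.
+  // 0x4530000080000000 is the double 2^84 + 2^63; XOR-ing the shifted upper
+  // 32 bits into its low word yields 2^84 + 2^63 + upper32 * 2^32 once
+  // reinterpreted, because toggling the 0x80000000 bit turns the signed
+  // upper word into a biased (offset) value. Subtracting
+  // k84_63_52 = 2^84 + 2^63 + 2^52 and adding v_lower cancels the biases,
+  // leaving upper32 * 2^32 + lower32, i.e. the original integer.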
+#endif +} + +template +HWY_API Vec128 ConvertTo(HWY_MAYBE_UNUSED Simd dd, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_cvtepu64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFF); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest/highest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + auto uint64_to_double128_fast = [&dd](VU w) HWY_ATTR { + w = Or(w, VU{detail::BitCastToInteger(Set(dd, 0x0010000000000000).raw)}); + return BitCast(dd, w) - Set(dd, 0x0010000000000000); + }; + + const auto v_lo_dbl = uint64_to_double128_fast(v_lo); + return MulAdd(cnst2_32_dbl, uint64_to_double128_fast(v_hi), v_lo_dbl); +#endif +} + +// Truncates (rounds toward zero). +template +HWY_API Vec128 ConvertTo(const Simd di, + const Vec128 v) { + return detail::FixConversionOverflow(di, v, _mm_cvttps_epi32(v.raw)); +} + +// Full (partial handled below) +HWY_API Vec128 ConvertTo(Full128 di, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 && HWY_ARCH_X86_64 + return detail::FixConversionOverflow(di, v, _mm_cvttpd_epi64(v.raw)); +#elif HWY_ARCH_X86_64 + const __m128i i0 = _mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw)); + const Half> dd2; + const __m128i i1 = _mm_cvtsi64_si128(_mm_cvttsd_si64(UpperHalf(dd2, v).raw)); + return detail::FixConversionOverflow(di, v, _mm_unpacklo_epi64(i0, i1)); +#else + using VI = VFromD; + const VI k0 = Zero(di); + const VI k1 = Set(di, 1); + const VI k51 = Set(di, 51); + + // Exponent indicates whether the number can be represented as int64_t. + const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); + const VI exp = biased_exp - Set(di, 0x3FF); + const auto in_range = exp < Set(di, 63); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + const VI shift_mnt = Max(k51 - exp, k0); + const VI shift_int = Max(exp - k51, k0); + const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); + // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. + const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); + // For inputs larger than 2^52, insert zeros at the bottom. + const VI shifted = int52 << shift_int; + // Restore the one bit lost when shifting in the implicit 1-bit. + const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, restored, limit); + + // If the input was negative, negate the integer (two's complement). 
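+  // For reference: with s = sign_mask (all-ones for negative inputs, else 0),
+  // (x ^ s) - s is the branch-free two's complement negation: x ^ -1 = ~x and
+  // ~x - (-1) = ~x + 1 = -x, while s = 0 leaves x unchanged.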
+ return (magnitude ^ sign_mask) - sign_mask; +#endif +} +HWY_API Vec64 ConvertTo(Full64 di, const Vec64 v) { + // Only need to specialize for non-AVX3, 64-bit (single scalar op) +#if HWY_TARGET > HWY_AVX3 && HWY_ARCH_X86_64 + const Vec64 i0{_mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw))}; + return detail::FixConversionOverflow(di, v, i0.raw); +#else + (void)di; + const auto full = ConvertTo(Full128(), Vec128{v.raw}); + return Vec64{full.raw}; +#endif +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + const Simd di; + return detail::FixConversionOverflow(di, v, _mm_cvtps_epi32(v.raw)); +} + +// ------------------------------ Floating-point rounding (ConvertTo) + +#if HWY_TARGET == HWY_SSSE3 + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + // Rely on rounding after addition with a large value such that no mantissa + // bits remain (assuming the current mode is nearest-even). We may need a + // compiler flag for precise floating-point to prevent "optimizing" this out. + const Simd df; + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +namespace detail { + +// Truncating to integer and converting back to float is correct except when the +// input magnitude is large, in which case the input was already an integer +// (because mantissa >> exponent is zero). +template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. 
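+  // For reference: VecFromMask is -1 in true lanes and 0 elsewhere, so adding
+  // `neg1` below subtracts exactly 1 in the lanes where truncation overshot
+  // (negative non-integers) and leaves the remaining lanes unchanged.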
+ const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#else + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +#endif // !HWY_SSSE3 + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_fpclass_ps_mask(v.raw, 0x81)}; +#else + return Mask128{_mm_cmpunord_ps(v.raw, v.raw)}; +#endif +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_fpclass_pd_mask(v.raw, 0x81)}; +#else + return Mask128{_mm_cmpunord_pd(v.raw, v.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_ps_mask(v.raw, 0x18)}; +} +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_pd_mask(v.raw, 0x18)}; +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask128{_mm_fpclass_ps_mask(v.raw, 0x99)}); +} +template +HWY_API Mask128 IsFinite(const Vec128 v) { + return Not(Mask128{_mm_fpclass_pd_mask(v.raw, 0x99)}); +} + +#else + +template +HWY_API Mask128 IsInf(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // Shift left to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). MSVC seems to generate + // incorrect code if we instead add vu + vu. 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(ShiftLeft<1>(vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET != HWY_SSSE3 + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenc_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenclast_si128(state.raw, round_key.raw)}; +} + +template +HWY_API Vec128 CLMulLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x00)}; +} + +template +HWY_API Vec128 CLMulUpper(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x11)}; +} + +#endif // !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET != HWY_SSSE3 + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const Vec128 vbits{_mm_cvtsi32_si128(static_cast(mask_bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
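+// A scalar model of that contract (hypothetical helper, not part of the API):
+// bit i of the byte stream governs lane i, which is what the AVX-512 path
+// (kmask registers) and the detail::LoadMaskBits overloads above implement.
+static inline bool SketchLoadMaskBit(const uint8_t* HWY_RESTRICT bits,
+                                     size_t lane) {
+  return ((bits[lane / 8] >> (lane % 8)) & 1) != 0;
+}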
+template +HWY_API Mask128 LoadMaskBits(Simd d, + const uint8_t* HWY_RESTRICT bits) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return Mask128::FromBits(mask_bits); +#else + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::LoadMaskBits(d, mask_bits); +#endif +} + +template +struct CompressIsPartition { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 supports native compress, but a table-based approach allows + // 'partitioning' (also moving mask=false lanes to the top), which helps + // vqsort. This is only feasible for eight or less lanes, i.e. sizeof(T) == 8 + // on AVX3. For simplicity, we only use tables for 64-bit lanes (not AVX3 + // u32x8 etc.). + enum { value = (sizeof(T) == 8) }; +#else + // generic_ops-inl does not guarantee IsPartition for 8-bit. + enum { value = (sizeof(T) != 1) }; +#endif +}; + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Simd /* tag */, + const Mask128 mask, uint8_t* bits) { + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (N < 8) { + const int mask_bits = (1 << N) - 1; + bits[0] = static_cast(bits[0] & mask_bits); + } + + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +template +HWY_API size_t CountTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return PopCount(mask_bits); +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint32_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint32_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return mask_bits == 0; +} + +template +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + // Cannot use _kortestc because we may have less than 8 mask bits. + return mask_bits == (1u << N) - 1; +} + +// ------------------------------ Compress + +// 8-16 bit Compress, CompressStore defined in x86_512 because they use Vec512. + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return Vec128{_mm_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + HWY_DASSERT(mask.raw < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. 
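+  // Table layout, for reference: one 16-byte row per 2-bit mask value. Rows
+  // 0, 1 and 3 are the identity permutation (nothing kept, lane 0 kept, or
+  // both kept are already valid partitions); row 2 (only lane 1 kept) swaps
+  // the two 64-bit halves so the kept lane comes first.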
+ alignas(16) constexpr uint8_t u8_indices[64] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Full128 d; + const Repartition d8; + const auto index = Load(d8, u8_indices + 16 * mask.raw); + return BitCast(d, TableLookupBytes(BitCast(d8, v), index)); +} + +// ------------------------------ CompressNot (Compress) + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // See CompressIsPartition, PrintCompressNot64x2NibbleTables + alignas(16) constexpr uint64_t packed_array[16] = {0x00000010, 0x00000001, + 0x00000010, 0x00000010}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm_permutexvar_epi64 will ignore the upper bits. + const Full128 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(16) constexpr uint64_t shifts[2] = {0, 4}; + const auto indices = Indices128{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + float* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + double* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore (CompressStore) +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + // AVX-512 already does the blending at no extra cost (latency 11, + // rthroughput 2 - same as compress plus store). + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) != 2) { + // We're relying on the mask to blend. Clear the undefined upper bits. + if (N != 16 / sizeof(T)) { + m = And(m, FirstN(d, N)); + } + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + const Vec128 compressed = Compress(v, m); +#if HWY_MEM_OPS_MIGHT_FAULT + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. 
+ alignas(16) T buf[N]; + Store(compressed, d, buf); + memcpy(unaligned, buf, count * sizeof(T)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#else // AVX2 or below + +// ------------------------------ StoreMaskBits + +namespace detail { + +constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) { + return static_cast(static_cast(mask_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + const Simd d; + const auto sign_bits = BitCast(d, VecFromMask(d, mask)).raw; + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const auto sign_bits = _mm_packs_epi16(mask.raw, _mm_setzero_si128()); + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + const Simd d; + const Simd df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_ps(sign_bits.raw)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 mask) { + const Simd d; + const Simd df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_pd(sign_bits.raw)); +} + +// Returns the lowest N of the _mm_movemask* bits. +template +constexpr uint64_t OnlyActive(uint64_t mask_bits) { + return ((N * sizeof(T)) == 16) ? mask_bits : mask_bits & ((1ull << N) - 1); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Simd /* tag */, + const Mask128 mask, uint8_t* bits) { + constexpr size_t kNumBytes = (N + 7) / 8; + const uint64_t mask_bits = detail::BitsFromMask(mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return detail::BitsFromMask(mask) == 0; +} + +template +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { + constexpr uint64_t kAllBits = + detail::OnlyActive((1ull << (16 / sizeof(T))) - 1); + return detail::BitsFromMask(mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(const Simd /* tag */, + const Mask128 mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. 
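+// A scalar model of the order the byte-index tables below encode
+// (hypothetical helper, not part of the API): lanes whose mask bit is set
+// come first, in their original order, followed by the remaining lanes, also
+// in order. This is the "partition" property documented by
+// CompressIsPartition.
+static inline void SketchCompressOrder(uint64_t mask_bits, size_t num_lanes,
+                                       size_t* HWY_RESTRICT lane_order) {
+  size_t pos = 0;
+  for (size_t i = 0; i < num_lanes; ++i) {
+    if ((mask_bits >> i) & 1u) lane_order[pos++] = i;  // kept lanes first
+  }
+  for (size_t i = 0; i < num_lanes; ++i) {
+    if (!((mask_bits >> i) & 1u)) lane_order[pos++] = i;  // then the rest
+  }
+}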
+template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Simd du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 
2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 
4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Simd du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 
6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 
12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_API Vec128 CompressBits(Vec128 v, uint64_t mask_bits) { + const Simd d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template +HWY_API Vec128 CompressNotBits(Vec128 v, uint64_t mask_bits) { + const Simd d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
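// maskL/maskH below broadcast lane 0 / lane 1 of the mask. swap is all-ones
// exactly when the upper lane is selected but the lower one is not; in that
// case Shuffle01 moves the single valid (upper) lane into lane 0.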
+ const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return detail::CompressBits(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNotBits(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = detail::BitsFromMask(m); + HWY_DASSERT(mask_bits < (1ull << N)); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = detail::BitsFromMask(m); + HWY_DASSERT(mask_bits < (1ull << N)); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + BlendedStore(compressed, FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). 
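// Instead, build byte-shuffle indices from mask_bits, compress with
// TableLookupBytes plus an ordinary StoreU, and return count =
// PopCount(mask_bits); only the first `count` stored lanes are meaningful.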
+ const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} + +// u32/i32/f32: + +// N=2 +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return v10 + Shuffle2301(v10); +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Min(v10, Shuffle2301(v10)); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Max(v10, Shuffle2301(v10)); +} + +// N=4 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = v3210 + v1032; + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// u64/i64/f64: + +// N=2 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. 
+ return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +// u8, N=8, N=16: +HWY_API Vec64 SumOfLanes(hwy::SizeTag<1> /* tag */, Vec64 v) { + const Full64 d; + return Set(d, static_cast(GetLane(SumsOf8(v)) & 0xFF)); +} +HWY_API Vec128 SumOfLanes(hwy::SizeTag<1> /* tag */, + Vec128 v) { + const Full128 d; + Vec128 sums = SumOfLanes(hwy::SizeTag<8>(), SumsOf8(v)); + return Set(d, static_cast(GetLane(sums) & 0xFF)); +} + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto is_neg = v < Zero(d); + + // Sum positive and negative lanes separately, then combine to get the result. + const auto positive = SumsOf8(BitCast(du, IfThenZeroElse(is_neg, v))); + const auto negative = SumsOf8(BitCast(du, IfThenElseZero(is_neg, Abs(v)))); + return Set(d, static_cast(GetLane( + SumOfLanes(hwy::SizeTag<8>(), positive - negative)) & + 0xFF)); +} + +#if HWY_TARGET <= HWY_SSE4 +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + using V = decltype(v); + return Broadcast<0>(V{_mm_minpos_epu16(v.raw)}); +} +HWY_API Vec64 MinOfLanes(hwy::SizeTag<1> /* tag */, Vec64 v) { + const Full64 d; + const Full128 d16; + return TruncateTo(d, MinOfLanes(hwy::SizeTag<2>(), PromoteTo(d16, v))); +} +HWY_API Vec128 MinOfLanes(hwy::SizeTag<1> tag, + Vec128 v) { + const Half> d; + Vec64 result = + Min(MinOfLanes(tag, UpperHalf(d, v)), MinOfLanes(tag, LowerHalf(d, v))); + return Combine(DFromV(), result, result); +} + +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> tag, Vec128 v) { + const Vec128 m(Set(DFromV(), LimitsMax())); + return m - MinOfLanes(tag, m - v); +} +HWY_API Vec64 MaxOfLanes(hwy::SizeTag<1> tag, Vec64 v) { + const Vec64 m(Set(DFromV(), LimitsMax())); + return m - MinOfLanes(tag, m - v); +} +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<1> tag, Vec128 v) { + const Vec128 m(Set(DFromV(), LimitsMax())); + return m - MinOfLanes(tag, m - v); +} +#elif HWY_TARGET == HWY_SSSE3 +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128 v) { + const DFromV d; + const RepartitionToWide d16; + const RepartitionToWide d32; + Vec128 vm = Max(v, Reverse2(d, v)); + vm = Max(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); + vm = Max(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); + if (N > 8) { + const RepartitionToWide d64; + vm = Max(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); + } + return vm; +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128 v) { + const DFromV d; + const RepartitionToWide d16; + const RepartitionToWide d32; + Vec128 vm = Min(v, Reverse2(d, v)); + vm = Min(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); + vm = Min(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); + if (N > 8) { + const RepartitionToWide d64; + vm = Min(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); + } + return vm; +} +#endif + +// Implement min/max of i8 in terms of u8 by toggling the sign bit. 
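// XOR with 0x80 maps int8 [-128, 127] monotonically onto uint8 [0, 255], so
// the min/max can be taken in the unsigned domain and mapped back with the
// same XOR. Example: Min(-1, 1): 0xFF^0x80=0x7F and 0x01^0x80=0x81; the
// unsigned min is 0x7F, and 0x7F^0x80 = 0xFF = -1, as expected.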
+template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<1> tag, + const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto mask = SignBit(du); + const auto vu = Xor(BitCast(du, v), mask); + return BitCast(d, Xor(MinOfLanes(tag, vu), mask)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<1> tag, + const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto mask = SignBit(du); + const auto vu = Xor(BitCast(du, v), mask); + return BitCast(d, Xor(MaxOfLanes(tag, vu), mask)); +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the same value in each lane. +template +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ Lt128 + +namespace detail { + +// Returns vector-mask for Lt128. Also used by x86_256/x86_512. +template > +HWY_INLINE V Lt128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const V ltHL = VecFromMask(d, Lt(a, b)); + const V ltLX = ShiftLeftLanes<1>(ltHL); + const V vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +} + +// Returns vector-mask for Eq128. Also used by x86_256/x86_512. 
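// Both 64-bit halves must compare equal, so the per-lane equality mask is
// ANDed with its Reverse2 (half-swapped) copy; the result is therefore
// all-ones or all-zero in both lanes of each 128-bit key.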
+template > +HWY_INLINE V Eq128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +} + +template > +HWY_INLINE V Ne128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +} + +template > +HWY_INLINE V Lt128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template > +HWY_INLINE V Eq128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template > +HWY_INLINE V Ne128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template > +HWY_API MFromD Lt128(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template > +HWY_API MFromD Eq128(D d, const V a, const V b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template > +HWY_API MFromD Ne128(D d, const V a, const V b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template > +HWY_API MFromD Lt128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD Eq128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD Ne128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. +template > +HWY_API V Min128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template > +HWY_API V Max128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template > +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template > +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/third_party/highway/hwy/ops/x86_256-inl.h b/third_party/highway/hwy/ops/x86_256-inl.h new file mode 100644 index 0000000000..3539520adf --- /dev/null +++ b/third_party/highway/hwy/ops/x86_256-inl.h @@ -0,0 +1,5548 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when +// compiling for that target. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +// Must come before HWY_COMPILER_CLANGCL +#include // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +#include +// avxintrin defines __m256i and must come before avx2intrin. +#include +#include // _pext_u64 +#include +#include +#include +#endif // HWY_COMPILER_CLANGCL + +#include +#include +#include // memcpy + +#if HWY_IS_MSAN +#include +#endif + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/x86_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw256 { + using type = __m256i; +}; +template <> +struct Raw256 { + using type = __m256; +}; +template <> +struct Raw256 { + using type = __m256d; +}; + +} // namespace detail + +template +class Vec256 { + using Raw = typename detail::Raw256::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
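// Consequence (illustrative): v += Set(d, T{1}) compiles for every lane type,
// whereas v /= w only compiles for float and double lanes.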
+ HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +#if HWY_TARGET <= HWY_AVX3 + +namespace detail { + +// Template arg: sizeof(lane type) +template +struct RawMask256 {}; +template <> +struct RawMask256<1> { + using type = __mmask32; +}; +template <> +struct RawMask256<2> { + using type = __mmask16; +}; +template <> +struct RawMask256<4> { + using type = __mmask8; +}; +template <> +struct RawMask256<8> { + using type = __mmask8; +}; + +} // namespace detail + +template +struct Mask256 { + using Raw = typename detail::RawMask256::type; + + static Mask256 FromBits(uint64_t mask_bits) { + return Mask256{static_cast(mask_bits)}; + } + + Raw raw; +}; + +#else // AVX2 + +// FF..FF or 0. +template +struct Mask256 { + typename detail::Raw256::type raw; +}; + +#endif // HWY_TARGET <= HWY_AVX3 + +template +using Full256 = Simd; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; } +HWY_INLINE __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); } +HWY_INLINE __m256i BitCastToInteger(__m256d v) { + return _mm256_castpd_si256(v); +} + +template +HWY_INLINE Vec256 BitCastToByte(Vec256 v) { + return Vec256{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger256 { + HWY_INLINE __m256i operator()(__m256i v) { return v; } +}; +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); } +}; +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); } +}; + +template +HWY_INLINE Vec256 BitCastFromByte(Full256 /* tag */, Vec256 v) { + return Vec256{BitCastFromInteger256()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 BitCast(Full256 d, Vec256 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns an all-zero vector. +template +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{_mm256_setzero_si256()}; +} +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{_mm256_setzero_ps()}; +} +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{_mm256_setzero_pd()}; +} + +// Returns a vector with all lanes set to "t". 
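// Illustrative usage (names are examples only):
//   const Full256<float> d;            // descriptor: 8 lanes of f32
//   const auto ones = Set(d, 1.0f);    // all lanes = 1.0f
//   const auto zero = Zero(d);         // all lanes = 0.0f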
+HWY_API Vec256 Set(Full256 /* tag */, const uint8_t t) { + return Vec256{_mm256_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const uint16_t t) { + return Vec256{_mm256_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const uint32_t t) { + return Vec256{_mm256_set1_epi32(static_cast(t))}; +} +HWY_API Vec256 Set(Full256 /* tag */, const uint64_t t) { + return Vec256{ + _mm256_set1_epi64x(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const int8_t t) { + return Vec256{_mm256_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const int16_t t) { + return Vec256{_mm256_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const int32_t t) { + return Vec256{_mm256_set1_epi32(t)}; +} +HWY_API Vec256 Set(Full256 /* tag */, const int64_t t) { + return Vec256{ + _mm256_set1_epi64x(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const float t) { + return Vec256{_mm256_set1_ps(t)}; +} +HWY_API Vec256 Set(Full256 /* tag */, const double t) { + return Vec256{_mm256_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec256 Undefined(Full256 /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec256{_mm256_undefined_si256()}; +} +HWY_API Vec256 Undefined(Full256 /* tag */) { + return Vec256{_mm256_undefined_ps()}; +} +HWY_API Vec256 Undefined(Full256 /* tag */) { + return Vec256{_mm256_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + return Vec256{_mm256_and_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 And(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_and_ps(a.raw, b.raw)}; +} +HWY_API Vec256 And(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
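// Note the operand order follows x86 ANDN semantics: the FIRST operand is the
// one that is complemented. Illustrative use: AndNot(SignBit(d), v) clears the
// sign bit of every lane, i.e. a float absolute value.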
+template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + return Vec256{_mm256_andnot_si256(not_mask.raw, mask.raw)}; +} +HWY_API Vec256 AndNot(const Vec256 not_mask, + const Vec256 mask) { + return Vec256{_mm256_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec256 AndNot(const Vec256 not_mask, + const Vec256 mask) { + return Vec256{_mm256_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + return Vec256{_mm256_or_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_or_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + return Vec256{_mm256_xor_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not +template +HWY_API Vec256 Not(const Vec256 v) { + using TU = MakeUnsigned; +#if HWY_TARGET <= HWY_AVX3 + const __m256i vu = BitCast(Full256(), v).raw; + return BitCast(Full256(), + Vec256{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(Full256(), Vec256{_mm256_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Xor3 +template +HWY_API Vec256 Xor3(Vec256 x1, Vec256 x2, Vec256 x3) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require 
VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<1> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<2> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<4> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<8> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 PopulationCount(Vec256 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ CopySign + +template +HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const Full256 d; + const auto msb = SignBit(d); + +#if HWY_TARGET <= HWY_AVX3 + const Rebind, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m256i out = _mm256_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template +HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(Full256()), sign)); +#endif +} + +// ================================================== MASK + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. 
+template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return Vec256{_mm256_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec256 IfThenElseZero(Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
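// A masked subtraction of a lane from itself writes 0 to the selected lanes
// and leaves the remaining lanes at `no`, which is exactly IfThenZeroElse.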
+ return Vec256{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return Vec256{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return Vec256{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec256 ZeroIfNegative(const Vec256 v) { + static_assert(IsSigned(), "Only for float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenZeroElse(MaskFromVec(v), v); +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask256 And(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Or(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask32(a.raw, b.raw)}; +#else + 
return Mask256{static_cast<__mmask32>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Not(const Mask256 m) { + // Flip only the valid bits. 
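// A plain bitwise NOT would also set the unused upper bits of the k-register
// (e.g. only 4 of the 8 bits of an __mmask8 are valid for 64-bit lanes), so
// XOR with the (1 << N) - 1 pattern of valid bits instead.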
+ constexpr size_t N = 32 / sizeof(T); + return Xor(m, Mask256::FromBits((1ull << N) - 1)); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +#else // AVX2 + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { + return Vec256{v.raw}; +} + +// ------------------------------ IfThenElse + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(const Mask256 mask, + const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(const Mask256 mask, + const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(Full256(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(Full256(), mask), no); +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + static_assert(IsSigned(), "Only for float"); + const auto zero = Zero(Full256()); + // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes + return IfThenElse(MaskFromVec(v), zero, v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + return MaskFromVec(Not(VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. 
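// That is, with AVX-512 the comparison result is a compact k-register (one bit
// per lane) rather than a vector; use VecFromMask/MaskFromVec to convert when
// a full-width mask vector is required.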
+ +template +HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask256{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<1> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<2> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<4> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<8> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API 
Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{MaskFromVec(BitCast(Full256(), v)).raw}; +} +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{MaskFromVec(BitCast(Full256(), v)).raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi8(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi16(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi32(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi64(v.raw)}; +} + +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))}; +} + +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))}; +} + +template +HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { + return VecFromMask(v); +} + +#else // AVX2 + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
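// That is, without AVX-512 a Mask256 is simply a vector whose lanes are
// all-ones or all-zero, so MaskFromVec and VecFromMask are free
// re-interpretations of the same bits.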
+ +template +HWY_API Mask256 RebindMask(Full256 d_to, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return MaskFromVec(BitCast(d_to, VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi8(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi16(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi32(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi64(a.raw, b.raw)}; +} + +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Not(a == b); +} +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)}; +} +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +// Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8 +// to perform an unsigned comparison instead of the intended signed. Workaround +// is to cast to an explicitly signed type. 
See https://godbolt.org/z/PL7Ujy +#if HWY_COMPILER_GCC != 0 && HWY_COMPILER_GCC < 930 +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1 +#else +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0 +#endif + +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { +#if HWY_AVX2_GCC_CMPGT8_WORKAROUND + using i8x32 = signed char __attribute__((__vector_size__(32))); + return Mask256{static_cast<__m256i>(reinterpret_cast(a.raw) > + reinterpret_cast(b.raw))}; +#else + return Mask256{_mm256_cmpgt_epi8(a.raw, b.raw)}; +#endif +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi16(a.raw, b.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi32(a.raw, b.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi64(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, Vec256 b) { + const Full256 du; + const RebindToSigned di; + const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); + return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); +} + +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)}; +} + +} // namespace detail + +template +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { + return b > a; +} + +template +HWY_API Mask256 operator<=(const Vec256 a, const Vec256 b) { + return b >= a; +} + +// ------------------------------ Min (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_min_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +#endif +} + +// Signed +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + 
return Vec256{_mm256_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_max_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +#endif +} + +// Signed +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask256 FirstN(const Full256 d, size_t n) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + constexpr size_t N = 32 / sizeof(T); +#if HWY_ARCH_X86_64 + const uint64_t all = (1ull << N) - 1; + // BZHI only looks at the lower 8 bits of n! + return Mask256::FromBits((n > 255) ? all : _bzhi_u64(all, n)); +#else + const uint32_t all = static_cast((1ull << N) - 1); + // BZHI only looks at the lower 8 bits of n! + return Mask256::FromBits( + (n > 255) ? all : _bzhi_u32(all, static_cast(n))); +#endif // HWY_ARCH_X86_64 +#else + const RebindToSigned di; // Signed comparisons are cheaper. 
+ return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(n))); +#endif +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_add_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + return Vec256{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+ +// Unsigned +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec256 Abs(const Vec256 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (wrong result) + const auto zero = Zero(Full256()); + return Vec256{_mm256_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec256{_mm256_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi16(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi32(v.raw)}; +} +// i64 is implemented after BroadcastSignBit. + +HWY_API Vec256 Abs(const Vec256 v) { + const Vec256 mask{_mm256_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(Full256(), mask); +} +HWY_API Vec256 Abs(const Vec256 v) { + const Vec256 mask{_mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(Full256(), mask); +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec256 MulFixedPoint15(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
+HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeft + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. + const Vec256 shifted{ShiftRight(Vec256{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. 
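Editor's note (not part of the patch): the signed 8-bit ShiftRight above emulates an arithmetic shift with a logical (unsigned) shift followed by the fixup (shifted ^ shifted_sign) - shifted_sign, where shifted_sign is the shifted-down sign bit 0x80 >> kBits. The following standalone scalar sketch checks that identity exhaustively; the helper name ArithmeticShiftRightViaLogical is made up for illustration and does not appear in Highway.

// Scalar demonstration of the sign-fixup identity used by the int8_t
// ShiftRight above. Hypothetical illustration only.
#include <cstdint>
#include <cstdio>

static uint8_t ArithmeticShiftRightViaLogical(uint8_t bits, int k) {
  const uint8_t shifted = static_cast<uint8_t>(bits >> k);  // logical shift
  const uint8_t sign = static_cast<uint8_t>(0x80 >> k);     // old sign bit
  // XOR-then-subtract sign-extends the upper k bits.
  return static_cast<uint8_t>((shifted ^ sign) - sign);
}

int main() {
  for (int k = 1; k < 8; ++k) {
    for (int v = -128; v < 128; ++v) {
      const uint8_t bits = static_cast<uint8_t>(v);
      // v >> k is an arithmetic shift on mainstream compilers (and is
      // guaranteed to be one since C++20).
      const int8_t expected = static_cast<int8_t>(v >> k);
      const int8_t actual =
          static_cast<int8_t>(ArithmeticShiftRightViaLogical(bits, k));
      if (expected != actual) {
        std::printf("mismatch at v=%d k=%d\n", v, k);
        return 1;
      }
    }
  }
  std::printf("sign-fixup identity holds for all int8 inputs\n");
  return 0;
}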
+ +// ------------------------------ RotateRight + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return VecFromMask(v < Zero(Full256())); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<31>(v); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { +#if HWY_TARGET == HWY_AVX2 + return VecFromMask(v < Zero(Full256())); +#else + return Vec256{_mm256_srai_epi64(v.raw, 63)}; +#endif +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srai_epi64(v.raw, kBits)}; +#else + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +HWY_API Vec256 Abs(const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_abs_epi64(v.raw)}; +#else + const auto zero = Zero(Full256()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, + Vec256 no) { + // int8: AVX2 IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Full256 d; + const RebindToSigned di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Full256 d; + const RebindToFloat df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. 
+ const MFromD msb = MaskFromVec(BitCast(df, v)); + return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no))); +} + +// ------------------------------ ShiftLeftSame + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Neg (Xor, Sub) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec256 Neg(hwy::FloatTag /*tag*/, const Vec256 v) { + return Xor(v, SignBit(Full256())); +} + +// Not floating-point +template +HWY_INLINE Vec256 Neg(hwy::NonFloatTag /*tag*/, const Vec256 v) { + return Zero(Full256()) - v; +} + +} // namespace detail + +template +HWY_API Vec256 Neg(const Vec256 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec256 
operator*(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mul_pd(a.raw, b.raw)}; +} + +HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_div_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { + return Vec256{_mm256_rcp_ps(v.raw)}; +} + +// Absolute value of difference. +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 NegMulAdd(const Vec256 mul, + const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 NegMulSub(const Vec256 mul, + const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +HWY_API Vec256 Sqrt(const Vec256 v) { + return Vec256{_mm256_sqrt_ps(v.raw)}; +} +HWY_API Vec256 Sqrt(const Vec256 v) { + return Vec256{_mm256_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { + return Vec256{_mm256_rsqrt_ps(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, tie to even +HWY_API Vec256 Round(const Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Round(const Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +HWY_API Vec256 Trunc(const Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Trunc(const Vec256 v) { + return Vec256{ + 
_mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +HWY_API Vec256 Ceil(const Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Ceil(const Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +HWY_API Vec256 Floor(const Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Floor(const Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +// ------------------------------ Floating-point classification + +HWY_API Mask256 IsNaN(const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_fpclass_ps_mask(v.raw, 0x81)}; +#else + return Mask256{_mm256_cmp_ps(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} +HWY_API Mask256 IsNaN(const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_fpclass_pd_mask(v.raw, 0x81)}; +#else + return Mask256{_mm256_cmp_pd(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +HWY_API Mask256 IsInf(const Vec256 v) { + return Mask256{_mm256_fpclass_ps_mask(v.raw, 0x18)}; +} +HWY_API Mask256 IsInf(const Vec256 v) { + return Mask256{_mm256_fpclass_pd_mask(v.raw, 0x18)}; +} + +HWY_API Mask256 IsFinite(const Vec256 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask256{_mm256_fpclass_ps_mask(v.raw, 0x99)}); +} +HWY_API Mask256 IsFinite(const Vec256 v) { + return Not(Mask256{_mm256_fpclass_pd_mask(v.raw, 0x99)}); +} + +#else + +template +HWY_API Mask256 IsInf(const Vec256 v) { + static_assert(IsFloat(), "Only for float"); + const Full256 d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask256 IsFinite(const Vec256 v) { + static_assert(IsFloat(), "Only for float"); + const Full256 d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // Shift left to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). MSVC seems to generate + // incorrect code if we instead add vu + vu. 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(ShiftLeft<1>(vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec256 Load(Full256 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec256{ + _mm256_load_si256(reinterpret_cast(aligned))}; +} +HWY_API Vec256 Load(Full256 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_ps(aligned)}; +} +HWY_API Vec256 Load(Full256 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_pd(aligned)}; +} + +template +HWY_API Vec256 LoadU(Full256 /* tag */, const T* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_si256(reinterpret_cast(p))}; +} +HWY_API Vec256 LoadU(Full256 /* tag */, + const float* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_ps(p)}; +} +HWY_API Vec256 LoadU(Full256 /* tag */, + const double* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_epi16(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_epi64(m.raw, p)}; +} + +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const float* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_ps(m.raw, p)}; +} + +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const double* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_pd(m.raw, p)}; +} + +#else // AVX2 + +// There is no maskload_epi8/16, so blend instead. +template * = nullptr> +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + return Vec256{_mm256_maskload_epi32(pi, m.raw)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + return Vec256{_mm256_maskload_epi64(pi, m.raw)}; +} + +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const float* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec256{_mm256_maskload_ps(p, mi.raw)}; +} + +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const double* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec256{_mm256_maskload_pd(p, mi.raw)}; +} + +#endif + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. +template +HWY_API Vec256 LoadDup128(Full256 /* tag */, const T* HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + // Workaround for incorrect results with _mm256_broadcastsi128_si256. 
Note + // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the + // upper half undefined) is fine because we're overwriting that anyway. + // This workaround seems in turn to generate incorrect code in MSVC 2022 + // (19.31), so use broadcastsi128 there. + const __m128i v128 = LoadU(Full128(), p).raw; + return Vec256{ + _mm256_inserti128_si256(_mm256_castsi128_si256(v128), v128, 1)}; +#else + return Vec256{_mm256_broadcastsi128_si256(LoadU(Full128(), p).raw)}; +#endif +} +HWY_API Vec256 LoadDup128(Full256 /* tag */, + const float* const HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const __m128 v128 = LoadU(Full128(), p).raw; + return Vec256{ + _mm256_insertf128_ps(_mm256_castps128_ps256(v128), v128, 1)}; +#else + return Vec256{_mm256_broadcast_ps(reinterpret_cast(p))}; +#endif +} +HWY_API Vec256 LoadDup128(Full256 /* tag */, + const double* const HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const __m128d v128 = LoadU(Full128(), p).raw; + return Vec256{ + _mm256_insertf128_pd(_mm256_castpd128_pd256(v128), v128, 1)}; +#else + return Vec256{ + _mm256_broadcast_pd(reinterpret_cast(p))}; +#endif +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { + _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +HWY_API void Store(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT aligned) { + _mm256_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT p) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); +} +HWY_API void StoreU(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT p) { + _mm256_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT p) { + _mm256_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi8(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi16(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi32(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi64(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256 v, Mask256 m, + Full256 /* tag */, float* HWY_RESTRICT p) { + _mm256_mask_storeu_ps(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256 v, Mask256 m, + Full256 /* tag */, double* HWY_RESTRICT p) { + _mm256_mask_storeu_pd(p, m.raw, v.raw); +} + +#else // AVX2 + +// Intel SDM says "No AC# reported for any mask bit combinations". However, AMD +// allows AC# if "Alignment checking enabled and: 256-bit memory operand not +// 32-byte aligned". Fortunately AC# is not enabled by default and requires both +// OS support (CR0) and the application to set rflags.AC. We assume these remain +// disabled because x86/x64 code and compiler output often contain misaligned +// scalar accesses, which would also fault. +// +// Caveat: these are slow on AMD Jaguar/Bulldozer. 
+ +template * = nullptr> +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT p) { + // There is no maskload_epi8/16. Blending is also unsafe because loading a + // full vector that crosses the array end causes asan faults. Resort to scalar + // code; the caller should instead use memcpy, assuming m is FirstN(d, n). + const RebindToUnsigned du; + using TU = TFromD; + alignas(32) TU buf[32 / sizeof(T)]; + alignas(32) TU mask[32 / sizeof(T)]; + Store(BitCast(du, v), du, buf); + Store(BitCast(du, VecFromMask(d, m)), du, mask); + for (size_t i = 0; i < 32 / sizeof(T); ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm256_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm256_maskstore_epi64(pi, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, + float* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm256_maskstore_ps(p, mi.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256 v, Mask256 m, + Full256 d, double* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm256_maskstore_pd(p, mi.raw, v.raw); +} + +#endif + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(Vec256 v, Full256 /* tag */, + T* HWY_RESTRICT aligned) { + _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +HWY_API void Stream(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT aligned) { + _mm256_stream_ps(aligned, v.raw); +} +HWY_API void Stream(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +HWY_API void ScatterOffset(Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i32scatter_ps(base, index.raw, v.raw, 4); +} + +HWY_API void ScatterOffset(Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i64scatter_pd(base, index.raw, v.raw, 8); +} + +#else + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + constexpr size_t N = 32 / sizeof(T); + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + constexpr size_t N = 32 / sizeof(T); + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Index index_lanes[N]; + Store(index, Full256(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather + +namespace detail { + +template +HWY_INLINE Vec256 GatherOffset(hwy::SizeTag<4> /* tag */, + Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec256 GatherIndex(hwy::SizeTag<4> /* tag */, + Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 index) { + return 
Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), index.raw, 4)}; +} + +template +HWY_INLINE Vec256 GatherOffset(hwy::SizeTag<8> /* tag */, + Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec256 GatherIndex(hwy::SizeTag<8> /* tag */, + Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), index.raw, 8)}; +} + +} // namespace detail + +template +HWY_API Vec256 GatherOffset(Full256 d, const T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec256 GatherIndex(Full256 d, const T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +HWY_API Vec256 GatherOffset(Full256 /* tag */, + const float* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i32gather_ps(base, offset.raw, 1)}; +} +HWY_API Vec256 GatherIndex(Full256 /* tag */, + const float* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i32gather_ps(base, index.raw, 4)}; +} + +HWY_API Vec256 GatherOffset(Full256 /* tag */, + const double* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i64gather_pd(base, offset.raw, 1)}; +} +HWY_API Vec256 GatherIndex(Full256 /* tag */, + const double* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i64gather_pd(base, index.raw, 8)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_castsi256_si128(v.raw)}; +} +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_castps256_ps128(v.raw)}; +} +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_castpd256_pd128(v.raw)}; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + return LowerHalf(Full128(), v); +} + +// ------------------------------ UpperHalf + +template +HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_extracti128_si256(v.raw, 1)}; +} +HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_extractf128_ps(v.raw, 1)}; +} +HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_extractf128_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + const Full256 d; + HWY_DASSERT(i < Lanes(d)); + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { + const Full256 d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ZeroExtendVector + +// Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper +// bits undefined. 
Although it makes sense for them to be zero (VEX encoded +// 128-bit instructions zero the upper lanes to avoid large penalties), a +// compiler could decide to optimize out code that relies on this. +// +// The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the +// zeroing, but it is not available on MSVC until 15.7 nor GCC until 10.1. For +// older GCC, we can still obtain the desired code thanks to pattern +// recognition; note that the expensive insert instruction is not actually +// generated, see https://gcc.godbolt.org/z/1MKGaP. + +#if !defined(HWY_HAVE_ZEXT) +#if (HWY_COMPILER_MSVC && HWY_COMPILER_MSVC >= 1915) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 500) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1000) +#define HWY_HAVE_ZEXT 1 +#else +#define HWY_HAVE_ZEXT 0 +#endif +#endif // defined(HWY_HAVE_ZEXT) + +template +HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, Vec128 lo) { +#if HWY_HAVE_ZEXT +return Vec256{_mm256_zextsi128_si256(lo.raw)}; +#else + return Vec256{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; +#endif +} +HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, + Vec128 lo) { +#if HWY_HAVE_ZEXT + return Vec256{_mm256_zextps128_ps256(lo.raw)}; +#else + return Vec256{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)}; +#endif +} +HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, + Vec128 lo) { +#if HWY_HAVE_ZEXT + return Vec256{_mm256_zextpd128_pd256(lo.raw)}; +#else + return Vec256{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)}; +#endif +} + +// ------------------------------ Combine + +template +HWY_API Vec256 Combine(Full256 d, Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_inserti128_si256(lo256.raw, hi.raw, 1)}; +} +HWY_API Vec256 Combine(Full256 d, Vec128 hi, + Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)}; +} +HWY_API Vec256 Combine(Full256 d, Vec128 hi, + Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_insertf128_pd(lo256.raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec256 ShiftLeftBytes(Full256 /* tag */, const Vec256 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bslli_epi128. + return Vec256{_mm256_slli_si256(v.raw, kBytes)}; +} + +template +HWY_API Vec256 ShiftLeftBytes(const Vec256 v) { + return ShiftLeftBytes(Full256(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { + return ShiftLeftLanes(Full256(), v); +} + +// ------------------------------ ShiftRightBytes + +template +HWY_API Vec256 ShiftRightBytes(Full256 /* tag */, const Vec256 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bsrli_epi128. + return Vec256{_mm256_srli_si256(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ CombineShiftRightBytes + +// Extracts 128 bits from by skipping the least-significant kBytes. 
+template > +HWY_API V CombineShiftRightBytes(Full256 d, V hi, V lo) { + const Repartition d8; + return BitCast(d, Vec256{_mm256_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec256{_mm256_unpacklo_epi64(lo, lo)}; + } else { + const __m256i hi = + _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec256{_mm256_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Signed +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec256{_mm256_unpacklo_epi64(lo, lo)}; + } else { + const __m256i hi = + _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec256{_mm256_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Float +template +HWY_API Vec256 Broadcast(Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)}; +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec256 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. 
+template +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// Used by generic_ops-inl.h +namespace detail { + +template +HWY_API Vec256 Shuffle2301(const Vec256 a, const Vec256 b) { + const Full256 d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template +HWY_API Vec256 Shuffle1230(const Vec256 a, const Vec256 b) { + const Full256 d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template +HWY_API Vec256 Shuffle3012(const Vec256 a, const Vec256 b) { + const Full256 d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + // Shorter encoding than _mm256_permute_ps. + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + // Shorter encoding than _mm256_permute_pd. + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 5)}; +} + +// Rotate right 32 bits +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
+template +struct Indices256 { + __m256i raw; +}; + +// Native 8x32 instruction: indices remain unchanged +template +HWY_API Indices256 IndicesFromVec(Full256 /* tag */, Vec256 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full256 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(32 / sizeof(T)))))); +#endif + return Indices256{vec.raw}; +} + +// 64-bit lanes: convert indices to 8x32 unless AVX3 is available +template +HWY_API Indices256 IndicesFromVec(Full256 d, Vec256 idx64) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + const Rebind di; + (void)di; // potentially unused +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllFalse(di, Lt(idx64, Zero(di))) && + AllTrue(di, Lt(idx64, Set(di, static_cast(32 / sizeof(T)))))); +#endif + +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Indices256{idx64.raw}; +#else + const Repartition df; // 32-bit! + // Replicate 64-bit index into upper 32 bits + const Vec256 dup = + BitCast(di, Vec256{_mm256_moveldup_ps(BitCast(df, idx64).raw)}); + // For each idx64 i, idx32 are 2*i and 2*i+1. + const Vec256 idx32 = dup + dup + Set(di, TI(1) << 32); + return Indices256{idx32.raw}; +#endif +} + +template +HWY_API Indices256 SetTableIndices(const Full256 d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_epi64(idx.raw, v.raw)}; +#else + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +#endif +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { + return Vec256{_mm256_permutevar8x32_ps(v.raw, idx.raw)}; +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_pd(idx.raw, v.raw)}; +#else + const Full256 df; + const Full256 du; + return BitCast(df, Vec256{_mm256_permutevar8x32_epi32( + BitCast(du, v).raw, idx.raw)}); +#endif +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return Vec256{_mm256_permute2x128_si256(v.raw, v.raw, 0x01)}; +} + +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return Vec256{_mm256_permute2f128_ps(v.raw, v.raw, 0x01)}; +} + +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return Vec256{_mm256_permute2f128_pd(v.raw, v.raw, 0x01)}; +} + +// ------------------------------ Reverse (RotateRight) + +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { + alignas(32) constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { + alignas(32) constexpr int64_t kReverse[4] = {3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec256 idx = Load(di, kReverse); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide> du32; + const Vec256 rev32 = Reverse(du32, BitCast(du32, v)); + return 
BitCast(d, RotateRight<16>(rev32)); +#endif +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec256 Reverse2(Full256 d, const Vec256 v) { + const Full256 du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 (SwapAdjacentBlocks) + +template +HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse4[16] = {3, 2, 1, 0, 7, 6, 5, 4, + 11, 10, 9, 8, 15, 14, 13, 12}; + const Vec256 idx = Load(di, kReverse4); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { + // Could also use _mm256_permute4x64_epi64. + return SwapAdjacentBlocks(Shuffle01(v)); +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec256 idx = Load(di, kReverse8); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { + return Reverse(d, v); +} + +template +HWY_API Vec256 Reverse8(Full256 /* tag */, const Vec256 /* v */) { + HWY_ASSERT(0); // AVX2 does not have 8 64-bit lanes +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). 
+ +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template > +HWY_API V InterleaveUpper(Full256 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template > +HWY_API Vec256 ZipLower(Vec256 a, Vec256 b) { + return BitCast(Full256(), InterleaveLower(a, b)); +} +template > +HWY_API Vec256 ZipLower(Full256 dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveLower(a, b)); +} + +template > +HWY_API Vec256 ZipUpper(Full256 dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveUpper(Full256(), a, b)); +} + +// ------------------------------ Blocks (LowerHalf, ZeroExtendVector) + +// _mm256_broadcastsi128_si256 has 7 cycle latency on ICL. 
+// _mm256_permute2x128_si256 is slow on Zen1 (8 uops), so we avoid it (at no +// extra cost) for LowerLower and UpperLower. + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, + const Vec256 lo) { + const Half d2; + return Vec256{_mm256_inserti128_si256(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} +HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, + const Vec256 lo) { + const Half d2; + return Vec256{_mm256_insertf128_ps(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} +HWY_API Vec256 ConcatLowerLower(Full256 d, + const Vec256 hi, + const Vec256 lo) { + const Half d2; + return Vec256{_mm256_insertf128_pd(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x21)}; +} +HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)}; +} +HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_blend_epi32(hi.raw, lo.raw, 0x0F)}; +} +HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)}; +} +HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_blend_pd(hi.raw, lo.raw, 3)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x31)}; +} +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x31)}; +} +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x31)}; +} + +// ------------------------------ ConcatOdd + +template +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(32) constexpr uint8_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Unsigned 8-bit shift so we can pack. 
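+  // (How this works: the logical >>8 moves each odd byte into the low byte of
+  // its u16 lane with a zero high byte, so the packus below packs without
+  // saturating; the final permute4x64 fixes the block order, because packus
+  // interleaves the 128-bit blocks of its two inputs.)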
+ const Vec256 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec256 uL = ShiftRight<8>(BitCast(dw, lo)); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return Vec256{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint16_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Unsigned 16-bit shift so we can pack. + const Vec256 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec256 uL = ShiftRight<16>(BitCast(dw, lo)); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return Vec256{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v3131{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return Vec256{_mm256_permute4x64_epi64(BitCast(du, v3131).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, + Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return Vec256{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + const Vec256 v3131{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return BitCast(d, Vec256{_mm256_permute4x64_epi64( + BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))}); +#endif +} + +template +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v31{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 15)}; + return Vec256{ + _mm256_permute4x64_epi64(BitCast(du, v31).raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, + Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return Vec256{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + (void)d; + const Vec256 v31{_mm256_shuffle_pd(lo.raw, hi.raw, 15)}; + return Vec256{ + _mm256_permute4x64_pd(v31.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ ConcatEven + +template +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, 
BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec256 mask = Set(dw, 0x00FF); + const Vec256 uH = And(BitCast(dw, hi), mask); + const Vec256 uL = And(BitCast(dw, lo), mask); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return Vec256{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint16_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 16 bits per u32 so we can pack. + const Vec256 mask = Set(dw, 0x0000FFFF); + const Vec256 uH = And(BitCast(dw, hi), mask); + const Vec256 uL = And(BitCast(dw, lo), mask); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return Vec256{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v2020{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return Vec256{_mm256_permute4x64_epi64(BitCast(du, v2020).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, + Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return Vec256{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + const Vec256 v2020{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast(d, Vec256{_mm256_permute4x64_epi64( + BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); + +#endif +} + +template +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v20{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; + return Vec256{ + _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, + Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return Vec256{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + (void)d; + const Vec256 v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; + return Vec256{ + _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; 
+} + +template +HWY_API Vec256 DupEven(const Vec256 v) { + return InterleaveLower(Full256(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec256 DupOdd(const Vec256 v) { + return InterleaveUpper(Full256(), v, v); +} + +// ------------------------------ OddEven + +namespace detail { + +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<1> /* tag */, const Vec256 a, + const Vec256 b) { + const Full256 d; + const Full256 d8; + alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, LoadDup128(d8, mask))), b, a); +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<2> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi16(a.raw, b.raw, 0x55)}; +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<4> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x55)}; +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<8> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x33)}; +} + +} // namespace detail + +template +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_blend_ps(a.raw, b.raw, 0x55)}; +} + +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_blend_pd(a.raw, b.raw, 5)}; +} + +// ------------------------------ OddEvenBlocks + +template +Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_epi32(odd.raw, even.raw, 0xFu)}; +} + +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_ps(odd.raw, even.raw, 0xFu)}; +} + +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_pd(odd.raw, even.raw, 0x3u)}; +} + +// ------------------------------ ReverseBlocks (ConcatLowerUpper) + +template +HWY_API Vec256 ReverseBlocks(Full256 d, Vec256 v) { + return ConcatLowerUpper(d, v, v); +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, + const Vec256 from) { + return Vec256{_mm256_shuffle_epi8(bytes.raw, from.raw)}; +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(const Vec256 bytes, + const Vec128 from) { + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(Full256(), Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(Full128(), tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(const Vec128 bytes, + const Vec256 from) { + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(Full256(), Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by x86_128. + +// ------------------------------ Shl (Mul, ZipLower) + +namespace detail { + +#if HWY_TARGET > HWY_AVX3 && !HWY_IDE // AVX2 or older + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. 
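+// (Sketch of the idea: for v in [0, 16), writing v + 127 into the f32 exponent
+// field -- i.e. reinterpreting (v + 127) << 23 as float -- yields exactly 2^v.
+// The code below builds that f32 from u16 halves (0x3F80 is the upper half of
+// 1.0f), converts it back to integer and re-packs to 16 bits; multiplying x by
+// 2^v then equals x << v in the low 16 bits.)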
+template +HWY_INLINE Vec256> Pow2(const Vec256 v) { + static_assert(sizeof(T) == 2, "Only for 16-bit"); + const Full256 d; + const RepartitionToWide dw; + const Rebind df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // Do not use ConvertTo because it checks for overflow, which is redundant + // because we only care about v in [0, 16). + const Vec256 bits0{_mm256_cvttps_epi32(BitCast(df, f0).raw)}; + const Vec256 bits1{_mm256_cvttps_epi32(BitCast(df, f1).raw)}; + return Vec256>{_mm256_packus_epi32(bits0.raw, bits1.raw)}; +} + +#endif // HWY_TARGET > HWY_AVX3 + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + return Vec256{_mm256_sllv_epi16(v.raw, bits.raw)}; +#else + return v * Pow2(bits); +#endif +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{_mm256_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{_mm256_sllv_epi64(v.raw, bits.raw)}; +} + +template +HWY_INLINE Vec256 Shl(hwy::SignedTag /*tag*/, Vec256 v, Vec256 bits) { + // Signed left shifts are the same as unsigned. + const Full256 di; + const Full256> du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec256 operator<<(Vec256 v, Vec256 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (MulHigh, IfThenElse, Not) + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + return Vec256{_mm256_srlv_epi16(v.raw, bits.raw)}; +#else + Full256 d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + auto out = MulHigh(v, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), v, out); +#endif +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Full256(), v, bits); +#endif +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Full256(), v, bits); +#endif +} + +HWY_INLINE Vec256 MulEven(const Vec256 a, + const Vec256 b) { + const Full256 du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. 
See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveLower(mulL, mulH); +} + +HWY_INLINE Vec256 MulOdd(const Vec256 a, + const Vec256 b) { + const Full256 du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need bits [95:64] (= upper half of input) + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Same as above, but we're using the odd results (upper 64 bits per block). + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveUpper(du64, mulL, mulH); +} + +// ------------------------------ ReorderWidenMulAccumulate +HWY_API Vec256 ReorderWidenMulAccumulate(Full256 /*d32*/, + Vec256 a, + Vec256 b, + const Vec256 sum0, + Vec256& /*sum1*/) { + return sum0 + Vec256{_mm256_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +HWY_API Vec256 RearrangeToOddPlusEven(const Vec256 sum0, + Vec256 /*sum1*/) { + return sum0; // invariant already holds +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{_mm256_cvtps_pd(v.raw)}; +} + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{_mm256_cvtepi32_pd(v.raw)}; +} + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu32_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. 
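+// (Illustrative sketch of that alternative, not used here; assumes a
+// Vec256<int16_t> v whose lanes of interest are already in the lower 64 bits
+// of each 128-bit block:
+//   const Full256<int32_t> d32;
+//   // Each i32 lane of InterleaveLower(v, v) holds the same int16_t in both
+//   // halves; an arithmetic >>16 recovers the sign-extended value.
+//   const auto lo32 = ShiftRight<16>(BitCast(d32, InterleaveLower(v, v)));
+// )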
+HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi32_epi64(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw); + // Concatenating lower halves of both 128-bit blocks afterward is more + // efficient than an extra input with low block = high block of v. + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Full64 /* tag */, + const Vec256 v) { + const __m256i u16_blocks = _mm256_packus_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i u16_concat = _mm256_permute4x64_epi64(u16_blocks, 0x88); + const __m128i u16 = _mm256_castsi256_si128(u16_concat); + // packus treats the input as signed; we want unsigned. Clear the MSB to get + // unsigned saturation to u8. + const __m128i i16 = _mm_and_si128(u16, _mm_set1_epi16(0x7FFF)); + return Vec128{_mm_packus_epi16(i16, i16)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Full64 /* tag */, + const Vec256 v) { + const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); + const __m128i i16 = _mm256_castsi256_si128(i16_concat); + return Vec128{_mm_packs_epi16(i16, i16)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))}; +} + + // Avoid "value of intrinsic immediate argument '8' is out of range '0 - 7'". + // 8 is the correct value of _MM_FROUND_NO_EXC, which is allowed here. 
+HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") + +HWY_API Vec128 DemoteTo(Full128 df16, + const Vec256 v) { +#ifdef HWY_DISABLE_F16C + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return BitCast(df16, DemoteTo(du16, bits16)); +#else + (void)df16; + return Vec128{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +#endif +} + +HWY_DIAGNOSTICS(pop) + +HWY_API Vec128 DemoteTo(Full128 dbf16, + const Vec256 v) { + // TODO(janwas): _mm256_cvtneps_pbh once we have avx512bf16. + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +HWY_API Vec256 ReorderDemote2To(Full256 dbf16, + Vec256 a, Vec256 b) { + // TODO(janwas): _mm256_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned du16; + const Repartition du32; + const Vec256 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec256 ReorderDemote2To(Full256 /*d16*/, + Vec256 a, Vec256 b) { + return Vec256{_mm256_packs_epi32(a.raw, b.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{_mm256_cvtpd_ps(v.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const auto clamped = detail::ClampF64ToI32Max(Full256(), v); + return Vec128{_mm256_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec256 v) { + const Full256 d32; + alignas(32) static constexpr uint32_t k8From32[8] = { + 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; + // Place first four bytes in lo[0], remaining 4 in hi[1]. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Full128(), quad); + const auto pair = LowerHalf(lo | hi); + return BitCast(Full64(), pair); +} + +// ------------------------------ Truncations + +namespace detail { + +// LO and HI each hold four indices of bytes within a 128-bit block. 
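+// (For example, the instantiation with LO=0x05040100 and HI=0x0D0C0908 used
+// below selects bytes {0,1,4,5} and {8,9,12,13} of each block -- the low two
+// bytes of every 32-bit lane -- which is how the u32 -> u16 truncation is
+// built on this helper.)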
+template +HWY_INLINE Vec128 LookupAndConcatHalves(Vec256 v) { + const Full256 d32; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint32_t kMap[8] = { + LO, HI, 0x10101010 + LO, 0x10101010 + HI, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d32, kMap).raw); +#else + alignas(32) static constexpr uint32_t kMap[8] = {LO, HI, ~0u, ~0u, + ~0u, ~0u, LO, HI}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto result = _mm256_permute4x64_epi64(quad.raw, 0xCC); + // Possible alternative: + // const auto lo = LowerHalf(quad); + // const auto hi = UpperHalf(Full128(), quad); + // const auto result = lo | hi; +#endif + + return Vec128{_mm256_castsi256_si128(result)}; +} + +// LO and HI each hold two indices of bytes within a 128-bit block. +template +HWY_INLINE Vec128 LookupAndConcatQuarters(Vec256 v) { + const Full256 d16; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint16_t kMap[16] = { + LO, HI, 0x1010 + LO, 0x1010 + HI, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d16, kMap).raw); + return LowerHalf(Vec128{_mm256_castsi256_si128(result)}); +#else + constexpr uint16_t ff = static_cast(~0u); + alignas(32) static constexpr uint16_t kMap[16] = { + LO, ff, HI, ff, ff, ff, ff, ff, ff, ff, ff, ff, LO, ff, HI, ff}; + const auto quad = TableLookupBytes(v, Load(d16, kMap)); + const auto mixed = _mm256_permute4x64_epi64(quad.raw, 0xCC); + const auto half = _mm256_castsi256_si128(mixed); + return LowerHalf(Vec128{_mm_packus_epi32(half, half)}); +#endif +} + +} // namespace detail + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const Full256 d32; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint32_t kMap[8] = {0x18100800u, 0, 0, 0, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d32, kMap).raw); + return LowerHalf(LowerHalf(LowerHalf(Vec256{result}))); +#else + alignas(32) static constexpr uint32_t kMap[8] = {0xFFFF0800u, ~0u, ~0u, ~0u, + 0x0800FFFFu, ~0u, ~0u, ~0u}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Full128(), quad); + const auto result = lo | hi; + return LowerHalf(LowerHalf(Vec128{result.raw})); +#endif +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const auto result = detail::LookupAndConcatQuarters<0x100, 0x908>(v); + return Vec128{result.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const Full256 d32; + alignas(32) constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto v32 = + TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven)); + return LowerHalf(Vec256{v32.raw}); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const auto full = detail::LookupAndConcatQuarters<0x400, 0xC08>(v); + return Vec128{full.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const auto full = detail::LookupAndConcatHalves<0x05040100, 0x0D0C0908>(v); + return Vec128{full.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const auto full = detail::LookupAndConcatHalves<0x06040200, 0x0E0C0A08>(v); + return Vec128{full.raw}; +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +HWY_API Vec256 ConvertTo(Full256 /* tag */, + const Vec256 v) { + return Vec256{_mm256_cvtepi32_ps(v.raw)}; +} + +HWY_API Vec256 ConvertTo(Full256 dd, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dd; + return 
Vec256{_mm256_cvtepi64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +#endif +} + +HWY_API Vec256 ConvertTo(HWY_MAYBE_UNUSED Full256 df, + const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_cvtepu32_ps(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned du32; + const RebindToSigned d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +#endif +} + +HWY_API Vec256 ConvertTo(HWY_MAYBE_UNUSED Full256 dd, + const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_cvtepu64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFFULL); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + auto uint64_to_double256_fast = [&dd](Vec256 w) HWY_ATTR { + w = Or(w, Vec256{ + detail::BitCastToInteger(Set(dd, 0x0010000000000000).raw)}); + return BitCast(dd, w) - Set(dd, 0x0010000000000000); + }; + + const auto v_lo_dbl = uint64_to_double256_fast(v_lo); + return MulAdd(cnst2_32_dbl, uint64_to_double256_fast(v_hi), v_lo_dbl); +#endif +} + +// Truncates (rounds toward zero). +HWY_API Vec256 ConvertTo(Full256 d, const Vec256 v) { + return detail::FixConversionOverflow(d, v, _mm256_cvttps_epi32(v.raw)); +} + +HWY_API Vec256 ConvertTo(Full256 di, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return detail::FixConversionOverflow(di, v, _mm256_cvttpd_epi64(v.raw)); +#else + using VI = decltype(Zero(di)); + const VI k0 = Zero(di); + const VI k1 = Set(di, 1); + const VI k51 = Set(di, 51); + + // Exponent indicates whether the number can be represented as int64_t. + const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); + const VI exp = biased_exp - Set(di, 0x3FF); + const auto in_range = exp < Set(di, 63); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + const VI shift_mnt = Max(k51 - exp, k0); + const VI shift_int = Max(exp - k51, k0); + const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); + // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. 
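+  // (Illustrative: for |v| in [1, 2^52), shift_int is 0 and the line below
+  // already yields floor(|v|): (mantissa | 2^52) is the significand
+  // 1.f * 2^52, and shifting it right by (51 - exp) + 1 = 52 - exp leaves
+  // floor(1.f * 2^exp). For larger |v|, the subsequent << shift_int restores
+  // the magnitude.)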
+ const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); + // For inputs larger than 2^52, insert zeros at the bottom. + const VI shifted = int52 << shift_int; + // Restore the one bit lost when shifting in the implicit 1-bit. + const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, restored, limit); + + // If the input was negative, negate the integer (two's complement). + return (magnitude ^ sign_mask) - sign_mask; +#endif +} + +HWY_API Vec256 NearestInt(const Vec256 v) { + const Full256 di; + return detail::FixConversionOverflow(di, v, _mm256_cvtps_epi32(v.raw)); +} + + +HWY_API Vec256 PromoteTo(Full256 df32, + const Vec128 v) { +#ifdef HWY_DISABLE_F16C + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +#else + (void)df32; + return Vec256{_mm256_cvtph_ps(v.raw)}; +#endif +} + +HWY_API Vec256 PromoteTo(Full256 df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. 
+#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec256 AESRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_aesenc_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 AESLastRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 CLMulLower(Vec256 a, Vec256 b) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)}; +#else + const Full256 d; + const Half d2; + return Combine(d, CLMulLower(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulLower(LowerHalf(a), LowerHalf(b))); +#endif +} + +HWY_API Vec256 CLMulUpper(Vec256 a, Vec256 b) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x11)}; +#else + const Full256 d; + const Half d2; + return Combine(d, CLMulUpper(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulUpper(LowerHalf(a), LowerHalf(b))); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +HWY_API Vec256 Iota(const Full256 d, const T2 first) { + HWY_ALIGN T lanes[32 / sizeof(T)]; + for (size_t i = 0; i < 32 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + } + return Load(d, lanes); +} + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadMaskBits + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask256 LoadMaskBits(const Full256 /* tag */, + const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return Mask256::FromBits(mask_bits); +} + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, + uint8_t* bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (N < 8) { + const int mask_bits = static_cast((1ull << N) - 1); + bits[0] = static_cast(bits[0] & mask_bits); + } + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API size_t CountTrue(const Full256 /* tag */, const Mask256 mask) { + return PopCount(static_cast(mask.raw)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full256 /* tag */, + const Mask256 mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API intptr_t FindFirstTrue(const Full256 d, const Mask256 mask) { + return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +// Beware: the suffix indicates the number of mask bits, not lane size! 
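+// (E.g. the _kortestz_mask32_u8 calls below operate on 32-bit mask registers,
+// as produced by the 32 one-byte lanes of a 256-bit vector.)
+//
+// Usage sketch for the mask queries above (illustrative only; the helper name
+// and its arguments are examples, not part of this header):
+//   size_t CountMatches(Full256<int32_t> d, const int32_t* HWY_RESTRICT p,
+//                       int32_t key) {
+//     return CountTrue(d, Load(d, p) == Set(d, key));
+//   }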
+ +namespace detail { + +template +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + return (uint64_t{mask.raw} & 0xF) == 0; +} + +} // namespace detail + +template +HWY_API bool AllFalse(const Full256 /* tag */, const Mask256 mask) { + return detail::AllFalse(hwy::SizeTag(), mask); +} + +namespace detail { + +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + // Cannot use _kortestc because we have less than 8 mask bits. + return mask.raw == 0xFu; +} + +} // namespace detail + +template +HWY_API bool AllTrue(const Full256 /* tag */, const Mask256 mask) { + return detail::AllTrue(hwy::SizeTag(), mask); +} + +// ------------------------------ Compress + +// 16-bit is defined in x86_512 so we can use 512-bit vectors. + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} + +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + return Vec256{_mm256_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + // See CompressIsPartition. + alignas(16) constexpr uint64_t packed_array[16] = { + // PrintCompress64x4NibbleTables + 0x00003210, 0x00003210, 0x00003201, 0x00003210, 0x00003102, 0x00003120, + 0x00003021, 0x00003210, 0x00002103, 0x00002130, 0x00002031, 0x00002310, + 0x00001032, 0x00001320, 0x00000321, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const Full256 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressNot (Compress) + +// Implemented in x86_512 for lane size != 8. + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 mask) { + // See CompressIsPartition. 
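+  // (I.e. the result is a permutation of the input: for CompressNot, the lanes
+  // whose mask bit is clear come first, followed by the mask=true lanes. E.g.
+  // for mask bits 0b0001 the entry 0x00000321 below decodes, LSB nibble first,
+  // to input lanes 1,2,3,0.)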
+ alignas(16) constexpr uint64_t packed_array[16] = { + // PrintCompressNot64x4NibbleTables + 0x00003210, 0x00000321, 0x00001320, 0x00001032, 0x00002310, 0x00002031, + 0x00002130, 0x00002103, 0x00003210, 0x00003021, 0x00003120, 0x00003102, + 0x00003210, 0x00003201, 0x00003210, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const Full256 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(32) constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressStore + +// 8-16 bit Compress, CompressStore defined in x86_512 because they use Vec512. + +template +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, + T* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, + T* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & 0xFull); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, + Full256 /* tag */, + float* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, + Full256 /* tag */, + double* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & 0xFull); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore (CompressStore) + +template +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) > 2) { + // Native (32 or 64-bit) AVX-512 instruction already does the blending at no + // extra cost (latency 11, rthroughput 2 - same as compress plus store). + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + BlendedStore(Compress(v, m), FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) + +template +HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, + Full256 d, T* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#else // AVX2 + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// 256 suffix avoids ambiguity with x86_128 without needing HWY_IF_LE128 there. +template +HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned du; + const Repartition du32; + const auto vbits = BitCast(du, Set(du32, static_cast(mask_bits))); + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
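+  // (I.e. byte i of the replicated vector holds mask byte i/8, and the kBit
+  // pattern below then tests bit i%8 of it, so lane i of the result is true
+  // exactly when bit i of mask_bits is set.)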
+ const Repartition du64; + alignas(32) constexpr uint64_t kRep8[4] = { + 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull, + 0x0303030303030303ull}; + const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); + + alignas(32) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint16_t kBit[16] = { + 1, 2, 4, 8, 16, 32, 64, 128, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint64_t kBit[8] = {1, 2, 4, 8}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask256 LoadMaskBits(Full256 d, + const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::LoadMaskBits256(d, mask_bits); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { + const Full256 d; + const Full256 d8; + const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw; + // Prevent sign-extension of 32-bit masks because the intrinsic returns int. + return static_cast(_mm256_movemask_epi8(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { +#if HWY_ARCH_X86_64 + const Full256 d; + const Full256 d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + const uint64_t sign_bits8 = BitsFromMask(mask8); + // Skip the bits from the lower byte of each u16 (better not to use the + // same packs_epi16 as SSE4, because that requires an extra swizzle here). + return _pext_u64(sign_bits8, 0xAAAAAAAAull); +#else + // Slow workaround for 32-bit builds, which lack _pext_u64. + // Remove useless lower half of each u16 while preserving the sign bit. + // Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes. + const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256()); + // Move odd qwords (value zero) to top so they don't affect the mask value. 
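+  // (After packs_epi16, the 16 sign bytes of interest occupy bytes [0, 8) and
+  // [16, 24); the permute below moves those two qwords into the lower 128 bits
+  // so movemask_epi8 returns the 16 mask bits contiguously in bits [0, 16),
+  // with zeros above.)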
+ const auto compressed = + _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0)); + return static_cast(_mm256_movemask_epi8(compressed)); +#endif // HWY_ARCH_X86_64 +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { + const Full256 d; + const Full256 df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_ps(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { + const Full256 d; + const Full256 df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_pd(sign_bits)); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, + uint8_t* bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + const uint64_t mask_bits = detail::BitsFromMask(mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Specialize for 16-bit lanes to avoid unnecessary pext. This assumes each mask +// lane is 0 or ~0. +template +HWY_API bool AllFalse(const Full256 d, const Mask256 mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return detail::BitsFromMask(mask8) == 0; +} + +template +HWY_API bool AllFalse(const Full256 /* tag */, const Mask256 mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return detail::BitsFromMask(mask) == 0; +} + +template +HWY_API bool AllTrue(const Full256 d, const Mask256 mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return detail::BitsFromMask(mask8) == (1ull << 32) - 1; +} +template +HWY_API bool AllTrue(const Full256 /* tag */, const Mask256 mask) { + constexpr uint64_t kAllBits = (1ull << (32 / sizeof(T))) - 1; + return detail::BitsFromMask(mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(const Full256 d, const Mask256 mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return PopCount(detail::BitsFromMask(mask8)) >> 1; +} +template +HWY_API size_t CountTrue(const Full256 /* tag */, const Mask256 mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full256 /* tag */, + const Mask256 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, + const Mask256 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +template +HWY_INLINE Vec256 IndicesFromBits(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. 
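+  // (Decode example: for mask_bits = 0x02 -- only lane 1 selected -- the entry
+  // 0x76543209 below unpacks, LSB nibble first, to 9,0,2,3,4,5,6,7; since the
+  // permute ignores all but the low 3 bits, that is 1,0,2,3,4,5,6,7: the single
+  // true lane moves to the front and the remaining lanes follow in order.)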
+ alignas(16) constexpr uint32_t packed_array[256] = { + // PrintCompress32x8Tables + 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8, + 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98, + 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8, + 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98, + 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8, + 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98, + 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8, + 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98, + 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8, + 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98, + 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8, + 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98, + 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8, + 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98, + 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8, + 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98, + 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8, + 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98, + 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8, + 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98, + 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8, + 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98, + 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8, + 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98, + 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 0x65310fca, 0x6531fca8, + 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 0x6520fcb9, 0x652fcb98, + 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8, + 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98, + 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8, + 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98, + 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8, + 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98, + 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8, + 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98, + 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8, + 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98, + 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8, + 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98, + 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8, + 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98, + 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8, + 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98, + 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98}; + + // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. 
+ const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromBits(Full256 d, uint64_t mask_bits) { + const Repartition d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) constexpr uint32_t u32_indices[128] = { + // PrintCompress64x4PairTables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, + 10, 11, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7, + 12, 13, 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 2, 3, 6, 7, + 10, 11, 12, 13, 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 6, 7, + 14, 15, 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 2, 3, 4, 5, + 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 4, 5, + 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 2, 3, + 10, 11, 12, 13, 14, 15, 0, 1, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits(Full256 d, + uint64_t mask_bits) { + const RebindToUnsigned d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintCompressNot32x8Tables + 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9, + 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca, + 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9, + 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb, + 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9, + 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba, + 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9, + 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec, + 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9, + 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea, + 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9, + 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb, + 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9, + 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba, + 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9, + 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd, + 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 0xed98fcba, 0xedafcb98, 0xeda8fcb9, + 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca, + 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9, + 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb, + 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9, + 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba, + 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9, + 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc, + 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 0xfc98edba, 0xfcaedb98, 0xfca8edb9, + 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 0xfcb9eda8, 0xfcb98eda, + 
0xfcbaed98, 0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9, + 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 0xfda9ecb8, 0xfda98ecb, + 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9, + 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba, + 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9, + 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e, + 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9, + 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca, + 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9, + 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db, + 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9, + 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba, + 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9, + 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c, + 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9, + 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a, + 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98}; + + // No need to mask because <_mm256_permutevar8x32_epi32> ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. + const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits(Full256 d, + uint64_t mask_bits) { + const Repartition d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) constexpr uint32_t u32_indices[128] = { + // PrintCompressNot64x4PairTables + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, + 8, 9, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, + 8, 9, 10, 11, 14, 15, 12, 13, 10, 11, 14, 15, 8, 9, 12, 13, + 8, 9, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 8, 9, 14, 15, + 8, 9, 12, 13, 10, 11, 14, 15, 12, 13, 8, 9, 10, 11, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 8, 9, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} +template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const Full256 d; + const Repartition du32; + + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256 indices{IndicesFromBits(d, mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. 
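// (A direct LUT for 16 lanes would need 2^16 = 65,536 entries of 16 nibbles
// = 8 bytes each, i.e. 512 KiB versus the 1 KiB used above, hence the split
// into two 8-lane Compress calls below.)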
+template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + const auto vu16 = BitCast(du, v); // (required for float16_t inputs) + const Half duh; + const auto half0 = LowerHalf(duh, vu16); + const auto half1 = UpperHalf(duh, vu16); + + const uint64_t mask_bits0 = mask_bits & 0xFF; + const uint64_t mask_bits1 = mask_bits >> 8; + const auto compressed0 = detail::CompressBits(half0, mask_bits0); + const auto compressed1 = detail::CompressBits(half1, mask_bits1); + + alignas(32) uint16_t all_true[16] = {}; + // Store mask=true lanes, left to right. + const size_t num_true0 = PopCount(mask_bits0); + Store(compressed0, duh, all_true); + StoreU(compressed1, duh, all_true + num_true0); + + if (hwy::HWY_NAMESPACE::CompressIsPartition::value) { + // Store mask=false lanes, right to left. The second vector fills the upper + // half with right-aligned false lanes. The first vector is shifted + // rightwards to overwrite the true lanes of the second. + alignas(32) uint16_t all_false[16] = {}; + const size_t num_true1 = PopCount(mask_bits1); + Store(compressed1, duh, all_false + 8); + StoreU(compressed0, duh, all_false + num_true1); + + const auto mask = FirstN(du, num_true0 + num_true1); + return BitCast(d, + IfThenElse(mask, Load(du, all_true), Load(du, all_false))); + } else { + // Only care about the mask=true lanes. + return BitCast(d, Load(du, all_true)); + } +} + +template // 4 or 8 bytes +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + const Full256 d; + const Repartition du32; + + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256 indices{IndicesFromNotBits(d, mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. +template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + // Compress ensures only the lower 16 bits are set, so flip those. 
+ return Compress(v, mask_bits ^ 0xFFFF); +} + +} // namespace detail + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 m) { + return detail::Compress(v, detail::BitsFromMask(m)); +} + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 m) { + return detail::CompressNot(v, detail::BitsFromMask(m)); +} + +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + return CompressNot(v, mask); +} + +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template // 4 or 8 bytes +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + + const Repartition du32; + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). Nibble MSB encodes FirstN. + const Vec256 idx_and_mask = detail::IndicesFromBits(d, mask_bits); + // Shift nibble MSB into MSB + const Mask256 mask32 = MaskFromVec(ShiftLeft<28>(idx_and_mask)); + // First cast to unsigned (RebindMask cannot change lane size) + const Mask256> mask_u{mask32.raw}; + const Mask256 mask = RebindMask(d, mask_u); + const Vec256 compressed = + BitCast(d, TableLookupLanes(BitCast(du32, v), + Indices256{idx_and_mask.raw})); + + BlendedStore(compressed, mask, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Vec256 compressed = detail::Compress(v, mask_bits); + +#if HWY_MEM_OPS_MIGHT_FAULT // true if HWY_IS_MSAN + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. + alignas(32) T buf[16]; + Store(compressed, d, buf); + memcpy(unaligned, buf, count * sizeof(T)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + return count; +} + +template +HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, + Full256 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. 
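// Caller-level sketch (illustrative; assumes the generic_ops LoadInterleaved3
// signature, which is not defined in this file, and a hypothetical pointer
// `rgb` to packed bytes):
//
//   const Full256<uint8_t> d;
//   Vec256<uint8_t> r, g, b;
//   LoadInterleaved3(d, rgb, r, g, b);  // rgb holds R0 G0 B0 R1 G1 B1 ...
//
// The generic implementation performs the three wide loads and block
// rearrangement via LoadTransposedBlocks3 below, then deinterleaves within
// each 128-bit block.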
+ +namespace detail { + +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template +HWY_API void LoadTransposedBlocks3(Full256 d, + const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C) { + constexpr size_t N = 32 / sizeof(T); + const Vec256 v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const Vec256 v32 = LoadU(d, unaligned + 1 * N); + const Vec256 v54 = LoadU(d, unaligned + 2 * N); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of A) +// 5 1 +// 6 2 +// 7 3 +template +HWY_API void LoadTransposedBlocks4(Full256 d, + const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C, + Vec256& D) { + constexpr size_t N = 32 / sizeof(T); + const Vec256 v10 = LoadU(d, unaligned + 0 * N); + const Vec256 v32 = LoadU(d, unaligned + 1 * N); + const Vec256 v54 = LoadU(d, unaligned + 2 * N); + const Vec256 v76 = LoadU(d, unaligned + 3 * N); + + A = ConcatLowerLower(d, v54, v10); + B = ConcatUpperUpper(d, v54, v10); + C = ConcatLowerLower(d, v76, v32); + D = ConcatUpperUpper(d, v76, v32); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template +HWY_API void StoreTransposedBlocks2(const Vec256 i, const Vec256 j, + const Full256 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template +HWY_API void StoreTransposedBlocks3(const Vec256 i, const Vec256 j, + const Vec256 k, Full256 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template +HWY_API void StoreTransposedBlocks4(const Vec256 i, const Vec256 j, + const Vec256 k, const Vec256 l, + Full256 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + // Write lower halves, then upper. + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ Reductions + +namespace detail { + +// Returns sum{lane[i]} in each lane. "v3210" is a replicated 128-bit block. +// Same logic as x86/128.h, but with Vec256 arguments. 
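// Annotation: writing a block's lanes as (v3, v2, v1, v0), Shuffle1032 yields
// (v1, v0, v3, v2), so the first add gives (v3+v1, v2+v0, v1+v3, v0+v2).
// Shuffle0321 rotates that by one lane and the second add leaves
// v0+v1+v2+v3 in every lane; the Min/Max reductions follow the same pattern.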
+template +HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = v3210 + v1032; + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = Min(v3210, v1032); + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = Max(v3210, v1032); + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +template +HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return Max(v10, v01); +} + +HWY_API Vec256 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +HWY_API Vec256 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. 
+ return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for {uif}{32,64},{ui}16. Returns the broadcasted result. +template +HWY_API Vec256 SumOfLanes(Full256 d, const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::SumOfLanes(hwy::SizeTag(), vLH + vHL); +} +template +HWY_API Vec256 MinOfLanes(Full256 d, const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::MinOfLanes(hwy::SizeTag(), Min(vLH, vHL)); +} +template +HWY_API Vec256 MaxOfLanes(Full256 d, const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::MaxOfLanes(hwy::SizeTag(), Max(vLH, vHL)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/third_party/highway/hwy/ops/x86_512-inl.h b/third_party/highway/hwy/ops/x86_512-inl.h new file mode 100644 index 0000000000..5f3b34c357 --- /dev/null +++ b/third_party/highway/hwy/ops/x86_512-inl.h @@ -0,0 +1,4605 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 512-bit AVX512 vectors and operations. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +#include // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +// clang-format off +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// clang-format on +#endif // HWY_COMPILER_CLANGCL + +#include +#include + +#if HWY_IS_MSAN +#include +#endif + +// For half-width vectors. 
Already includes base.h and shared-inl.h. +#include "hwy/ops/x86_256-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { + +template +struct Raw512 { + using type = __m512i; +}; +template <> +struct Raw512 { + using type = __m512; +}; +template <> +struct Raw512 { + using type = __m512d; +}; + +// Template arg: sizeof(lane type) +template +struct RawMask512 {}; +template <> +struct RawMask512<1> { + using type = __mmask64; +}; +template <> +struct RawMask512<2> { + using type = __mmask32; +}; +template <> +struct RawMask512<4> { + using type = __mmask16; +}; +template <> +struct RawMask512<8> { + using type = __mmask8; +}; + +} // namespace detail + +template +class Vec512 { + using Raw = typename detail::Raw512::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 64 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec512& operator*=(const Vec512 other) { + return *this = (*this * other); + } + HWY_INLINE Vec512& operator/=(const Vec512 other) { + return *this = (*this / other); + } + HWY_INLINE Vec512& operator+=(const Vec512 other) { + return *this = (*this + other); + } + HWY_INLINE Vec512& operator-=(const Vec512 other) { + return *this = (*this - other); + } + HWY_INLINE Vec512& operator&=(const Vec512 other) { + return *this = (*this & other); + } + HWY_INLINE Vec512& operator|=(const Vec512 other) { + return *this = (*this | other); + } + HWY_INLINE Vec512& operator^=(const Vec512 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +// Mask register: one bit per lane. +template +struct Mask512 { + using Raw = typename detail::RawMask512::type; + Raw raw; +}; + +template +using Full512 = Simd; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; } +HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); } +HWY_INLINE __m512i BitCastToInteger(__m512d v) { + return _mm512_castpd_si512(v); +} + +template +HWY_INLINE Vec512 BitCastToByte(Vec512 v) { + return Vec512{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger512 { + HWY_INLINE __m512i operator()(__m512i v) { return v; } +}; +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); } +}; +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); } +}; + +template +HWY_INLINE Vec512 BitCastFromByte(Full512 /* tag */, Vec512 v) { + return Vec512{BitCastFromInteger512()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 BitCast(Full512 d, Vec512 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns an all-zero vector. +template +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_si512()}; +} +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_ps()}; +} +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_pd()}; +} + +// Returns a vector with all lanes set to "t". 
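// For example, Set(Full512<float>(), 1.5f) broadcasts 1.5f into all 16 float
// lanes, and Zero(Full512<uint32_t>()) above likewise yields sixteen zero
// u32 lanes.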
+HWY_API Vec512 Set(Full512 /* tag */, const uint8_t t) { + return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const uint16_t t) { + return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const uint32_t t) { + return Vec512{_mm512_set1_epi32(static_cast(t))}; +} +HWY_API Vec512 Set(Full512 /* tag */, const uint64_t t) { + return Vec512{ + _mm512_set1_epi64(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const int8_t t) { + return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const int16_t t) { + return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const int32_t t) { + return Vec512{_mm512_set1_epi32(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const int64_t t) { + return Vec512{ + _mm512_set1_epi64(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const float t) { + return Vec512{_mm512_set1_ps(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const double t) { + return Vec512{_mm512_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec512 Undefined(Full512 /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec512{_mm512_undefined_epi32()}; +} +HWY_API Vec512 Undefined(Full512 /* tag */) { + return Vec512{_mm512_undefined_ps()}; +} +HWY_API Vec512 Undefined(Full512 /* tag */) { + return Vec512{_mm512_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec512 Not(const Vec512 v) { + using TU = MakeUnsigned; + const __m512i vu = BitCast(Full512(), v).raw; + return BitCast(Full512(), + Vec512{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)}); +} + +// ------------------------------ And + +template +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_ps(a.raw, b.raw)}; +} +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
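// Example (illustrative): AndNot(SignBit(d), v) clears each lane's sign bit,
// i.e. computes the absolute value of floating-point lanes.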
+template +HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { + return Vec512{_mm512_andnot_si512(not_mask.raw, mask.raw)}; +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor3 +template +HWY_API Vec512 Xor3(Vec512 x1, Vec512 x2, Vec512 x3) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +} + +// ------------------------------ Or3 +template +HWY_API Vec512 Or3(Vec512 o1, Vec512 o2, Vec512 o3) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +} + +// ------------------------------ OrAnd +template +HWY_API Vec512 OrAnd(Vec512 o, Vec512 a1, Vec512 a2) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec512 IfVecThenElse(Vec512 mask, Vec512 yes, Vec512 no) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec512 operator&(const Vec512 a, const Vec512 b) { + return And(a, b); +} + +template +HWY_API Vec512 operator|(const Vec512 a, const Vec512 b) { + return Or(a, b); +} + +template +HWY_API Vec512 operator^(const Vec512 a, const Vec512 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. 
+#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<1> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<2> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<4> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<8> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 PopulationCount(Vec512 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ CopySign + +template +HWY_API Vec512 CopySign(const Vec512 magn, const Vec512 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const Full512 d; + const auto msb = SignBit(d); + + const Rebind, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m512i out = _mm512_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +} + +template +HWY_API Vec512 CopySignToAbs(const Vec512 abs, const Vec512 sign) { + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== MASK + +// ------------------------------ FirstN + +// Possibilities for constructing a bitmask of N ones: +// - kshift* only consider the lowest byte of the shift count, so they would +// not correctly handle large n. +// - Scalar shifts >= 64 are UB. +// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However, +// we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds. + +#if HWY_ARCH_X86_32 +namespace detail { + +// 32 bit mask is sufficient for lane size >= 2. +template +HWY_INLINE Mask512 FirstN(size_t n) { + Mask512 m; + const uint32_t all = ~uint32_t{0}; + // BZHI only looks at the lower 8 bits of n! + m.raw = static_cast((n > 255) ? all : _bzhi_u32(all, n)); + return m; +} + +template +HWY_INLINE Mask512 FirstN(size_t n) { + const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0}; + return Mask512{static_cast<__mmask64>(bits)}; +} + +} // namespace detail +#endif // HWY_ARCH_X86_32 + +template +HWY_API Mask512 FirstN(const Full512 /*tag*/, size_t n) { +#if HWY_ARCH_X86_64 + Mask512 m; + const uint64_t all = ~uint64_t{0}; + // BZHI only looks at the lower 8 bits of n! + m.raw = static_cast((n > 255) ? all : _bzhi_u64(all, n)); + return m; +#else + return detail::FirstN(n); +#endif // HWY_ARCH_X86_64 +} + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. 
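// Example (illustrative): IfThenElse(v > Zero(d), v, Zero(d)) keeps positive
// lanes and zeroes the rest; each lane size maps to a single masked move
// (_mm512_mask_mov_*) as defined below.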
+template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<1> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<2> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<4> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<8> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} +HWY_API Vec512 IfThenElse(const Mask512 mask, + const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} +HWY_API Vec512 IfThenElse(const Mask512 mask, + const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<1> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<2> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<4> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<8> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} +HWY_API Vec512 IfThenElseZero(const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec512 IfThenElseZero(const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + const Mask512 mask, const Vec512 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
+ return Vec512{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} +HWY_API Vec512 IfThenZeroElse(const Mask512 mask, + const Vec512 no) { + return Vec512{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec512 IfThenZeroElse(const Mask512 mask, + const Vec512 no) { + return Vec512{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec512 IfNegativeThenElse(Vec512 v, Vec512 yes, Vec512 no) { + static_assert(IsSigned(), "Only works for signed/float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec512 ZeroIfNegative(const Vec512 v) { + // AVX3 MaskFromVec only looks at the MSB + return IfThenZeroElse(MaskFromVec(v), v); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_add_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 
operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec512 SumsOf8(const Vec512 v) { + return Vec512{_mm512_sad_epu8(v.raw, _mm512_setzero_si512())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec512 SaturatedAdd(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedAdd(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 SaturatedAdd(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedAdd(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec512 SaturatedSub(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedSub(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 SaturatedSub(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedSub(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec512 AverageRound(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 AverageRound(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec512 Abs(const Vec512 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (untested due to internal compiler error) + const auto zero = Zero(Full512()); + return Vec512{_mm512_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec512{_mm512_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi16(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi32(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi64(v.raw)}; +} + +// These aren't native instructions, they also involve AND with constant. 
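// That is, the float/double Abs below clear the sign bit with a bitwise AND
// against 0x7FFFFFFF / 0x7FFFFFFFFFFFFFFF, equivalent to AndNot(SignBit(d), v).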
+HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_ps(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_pd(v.raw)}; +} +// ------------------------------ ShiftLeft + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + const Full512 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + const Full512 d8; + // Use raw instead of BitCast to support N=1. + const Vec512 shifted{ShiftRight(Vec512{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + const Full512 di; + const Full512 du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + return Vec512{_mm512_ror_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + return Vec512{_mm512_ror_epi64(v.raw, kBits)}; +} + +// ------------------------------ ShiftLeftSame + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + return 
Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + const Full512 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { + const Full512 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { + const Full512 di; + const Full512 du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Shl + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. 
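// Example (illustrative): each lane of v is shifted by the corresponding lane
// of bits, so v << Set(d, 3) multiplies every lane by 8, and per-lane shift
// counts may differ.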
+template +HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { + const Full512 di; + const Full512> du; + return BitCast(di, BitCast(du, v) << BitCast(du, bits)); +} + +// ------------------------------ Shr + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi64(v.raw, bits.raw)}; +} + +// ------------------------------ Minimum + +// Unsigned +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_min_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_max_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 
operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; +} + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// Signed +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; +} +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec512 MulFixedPoint15(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ Neg (Sub) + +template +HWY_API Vec512 Neg(const Vec512 v) { + return Xor(v, SignBit(Full512())); +} + +template +HWY_API Vec512 Neg(const Vec512 v) { + return Zero(Full512()) - v; +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_mul_pd(a.raw, b.raw)}; +} + +HWY_API Vec512 operator/(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_div_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator/(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +HWY_API Vec512 ApproximateReciprocal(const Vec512 v) { + return Vec512{_mm512_rcp14_ps(v.raw)}; +} + +// Absolute value of difference. 
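// Example (illustrative): AbsDiff(Set(d, 3.0f), Set(d, 5.0f)) is 2.0f in
// every lane.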
+HWY_API Vec512 AbsDiff(const Vec512 a, const Vec512 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec512 MulAdd(const Vec512 mul, const Vec512 x, + const Vec512 add) { + return Vec512{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512 MulAdd(const Vec512 mul, const Vec512 x, + const Vec512 add) { + return Vec512{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns add - mul * x +HWY_API Vec512 NegMulAdd(const Vec512 mul, const Vec512 x, + const Vec512 add) { + return Vec512{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512 NegMulAdd(const Vec512 mul, + const Vec512 x, + const Vec512 add) { + return Vec512{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns mul * x - sub +HWY_API Vec512 MulSub(const Vec512 mul, const Vec512 x, + const Vec512 sub) { + return Vec512{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512 MulSub(const Vec512 mul, const Vec512 x, + const Vec512 sub) { + return Vec512{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +// Returns -mul * x - sub +HWY_API Vec512 NegMulSub(const Vec512 mul, const Vec512 x, + const Vec512 sub) { + return Vec512{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512 NegMulSub(const Vec512 mul, + const Vec512 x, + const Vec512 sub) { + return Vec512{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_ps(v.raw)}; +} +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +HWY_API Vec512 ApproximateReciprocalSqrt(const Vec512 v) { + return Vec512{_mm512_rsqrt14_ps(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Toward nearest integer, tie to even +HWY_API Vec512 Round(const Vec512 v) { + return Vec512{_mm512_roundscale_ps( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Round(const Vec512 v) { + return Vec512{_mm512_roundscale_pd( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +HWY_API Vec512 Trunc(const Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Trunc(const Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +HWY_API Vec512 Ceil(const Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Ceil(const Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +HWY_API Vec512 Floor(const Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Floor(const Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. 
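// Example (illustrative): (a > b) yields a Mask512 whose bit i is set iff
// lane i of a exceeds lane i of b; such masks feed IfThenElse above and the
// mask logic / CountTrue-style ops defined further below.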
+ +template +HWY_API Mask512 RebindMask(Full512 /*tag*/, Mask512 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask512{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<1> /*tag*/, const Vec512 v, + const Vec512 bit) { + return Mask512{_mm512_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<2> /*tag*/, const Vec512 v, + const Vec512 bit) { + return Mask512{_mm512_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<4> /*tag*/, const Vec512 v, + const Vec512 bit) { + return Mask512{_mm512_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<8> /*tag*/, const Vec512 v, + const Vec512 bit) { + return Mask512{_mm512_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask512 TestBit(const Vec512 v, const Vec512 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 
operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask512 operator<(Vec512 a, Vec512 b) { + return b > a; +} + +template +HWY_API Mask512 operator<=(Vec512 a, Vec512 b) { + return b >= a; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec512 v) { + return Mask512{_mm512_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec512 v) { + return Mask512{_mm512_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec512 v) { + return Mask512{_mm512_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec512 v) { + return Mask512{_mm512_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask512 MaskFromVec(const Vec512 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +HWY_API Mask512 MaskFromVec(const Vec512 v) { + return Mask512{MaskFromVec(BitCast(Full512(), v)).raw}; +} +HWY_API Mask512 MaskFromVec(const Vec512 v) { + return Mask512{MaskFromVec(BitCast(Full512(), v)).raw}; +} + +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi8(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi8(v.raw)}; +} + +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi16(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi16(v.raw)}; +} + +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi32(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi32(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_castsi512_ps(_mm512_movm_epi32(v.raw))}; +} + +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi64(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi64(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_castsi512_pd(_mm512_movm_epi64(v.raw))}; +} + +template +HWY_API Vec512 VecFromMask(Full512 /* tag */, const Mask512 v) { + return VecFromMask(v); +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask512 Not(hwy::SizeTag<1> /*tag*/, const Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask64(m.raw)}; +#else + return Mask512{~m.raw}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<2> /*tag*/, const Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask32(m.raw)}; +#else + return Mask512{~m.raw}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<4> /*tag*/, const Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask16(m.raw)}; +#else + return Mask512{static_cast(~m.raw & 0xFFFF)}; +#endif +} +template 
+HWY_INLINE Mask512 Not(hwy::SizeTag<8> /*tag*/, const Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask8(m.raw)}; +#else + return Mask512{static_cast(~m.raw & 0xFF)}; +#endif +} + +template +HWY_INLINE Mask512 And(hwy::SizeTag<1> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<2> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<4> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<8> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<1> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask64(a.raw, b.raw)}; +#else + return Mask512{~a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<2> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask512{~a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<4> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<8> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 Or(hwy::SizeTag<1> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw | b.raw}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<2> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw | b.raw}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<4> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<8> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<1> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw ^ b.raw}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<2> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw ^ b.raw}; +#endif +} 
+template +HWY_INLINE Mask512 Xor(hwy::SizeTag<4> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<8> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask512 a, const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask64(a.raw, b.raw)}; +#else + return Mask512{~(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask512 a, const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask512 a, const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask512 a, const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask512 Not(const Mask512 m) { + return detail::Not(hwy::SizeTag(), m); +} + +template +HWY_API Mask512 And(const Mask512 a, Mask512 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 AndNot(const Mask512 a, Mask512 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 Or(const Mask512 a, Mask512 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 Xor(const Mask512 a, Mask512 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 ExclusiveNeither(const Mask512 a, Mask512 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec512 BroadcastSignBit(const Vec512 v) { + return VecFromMask(v < Zero(Full512())); +} + +HWY_API Vec512 BroadcastSignBit(const Vec512 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec512 BroadcastSignBit(const Vec512 v) { + return ShiftRight<31>(v); +} + +HWY_API Vec512 BroadcastSignBit(const Vec512 v) { + return Vec512{_mm512_srai_epi64(v.raw, 63)}; +} + +// ------------------------------ Floating-point classification (Not) + +HWY_API Mask512 IsNaN(const Vec512 v) { + return Mask512{_mm512_fpclass_ps_mask(v.raw, 0x81)}; +} +HWY_API Mask512 IsNaN(const Vec512 v) { + return Mask512{_mm512_fpclass_pd_mask(v.raw, 0x81)}; +} + +HWY_API Mask512 IsInf(const Vec512 v) { + return Mask512{_mm512_fpclass_ps_mask(v.raw, 0x18)}; +} +HWY_API Mask512 IsInf(const Vec512 v) { + return Mask512{_mm512_fpclass_pd_mask(v.raw, 0x18)}; +} + +// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for +// positive, so we have to check for inf/NaN and negate. 
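The fpclass immediates used here and in the IsFinite overloads just below are category bit sets: 0x01 is QNaN, 0x08 is +inf, 0x10 is -inf and 0x80 is SNaN, so 0x81 matches any NaN, 0x18 matches either infinity, and 0x99 matches NaN or infinity (whose negation is "finite", including subnormals and zero). A small standalone sketch (not part of the library) is shown below; it assumes an AVX-512DQ toolchain (-mavx512dq) for _mm512_fpclass_ps_mask.

// Standalone illustration of the fpclass category bits used above.
#include <immintrin.h>
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  // Lane 4 is NaN, lanes 2/3 are +/-inf, lane 6 is subnormal (still finite).
  const __m512 v =
      _mm512_setr_ps(1.0f, 0.0f, inf, -inf, std::nanf(""), -2.5f, 1e-40f, 3.0f,
                     0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
  // 0x81 = QNaN | SNaN, 0x18 = +inf | -inf, 0x99 = either of the above.
  std::printf("NaN lanes: 0x%04X\n",   // 0x0010
              static_cast<unsigned>(_mm512_fpclass_ps_mask(v, 0x81)));
  std::printf("Inf lanes: 0x%04X\n",   // 0x000C
              static_cast<unsigned>(_mm512_fpclass_ps_mask(v, 0x18)));
  std::printf("Finite   : 0x%04X\n",   // 0xFFE3 (negation of NaN/inf lanes)
              static_cast<unsigned>(
                  static_cast<__mmask16>(~_mm512_fpclass_ps_mask(v, 0x99))));
  return 0;
}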
+HWY_API Mask512 IsFinite(const Vec512 v) { + return Not(Mask512{_mm512_fpclass_ps_mask(v.raw, 0x99)}); +} +HWY_API Mask512 IsFinite(const Vec512 v) { + return Not(Mask512{_mm512_fpclass_pd_mask(v.raw, 0x99)}); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec512 Load(Full512 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_si512(aligned)}; +} +HWY_API Vec512 Load(Full512 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_ps(aligned)}; +} +HWY_API Vec512 Load(Full512 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_pd(aligned)}; +} + +template +HWY_API Vec512 LoadU(Full512 /* tag */, const T* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_si512(p)}; +} +HWY_API Vec512 LoadU(Full512 /* tag */, + const float* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_ps(p)}; +} +HWY_API Vec512 LoadU(Full512 /* tag */, + const double* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +template +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const T* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const T* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_epi16(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const T* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const T* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_epi64(m.raw, p)}; +} + +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const float* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_ps(m.raw, p)}; +} + +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const double* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_pd(m.raw, p)}; +} + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. 
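Note that for the 512-bit vectors in this file, the 16-byte block is replicated into all four 128-bit blocks (the "both halves" wording above carries over from the 256-bit header). A minimal standalone sketch of the underlying broadcast follows; it is not part of the library and assumes an AVX-512F toolchain (-mavx512f).

// Standalone illustration: _mm512_broadcast_i32x4 repeats one 16-byte block
// into every 128-bit block of a 512-bit register.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  const std::uint32_t pattern[4] = {0, 1, 2, 3};
  const __m128i block =
      _mm_loadu_si128(reinterpret_cast<const __m128i*>(pattern));
  const __m512i dup = _mm512_broadcast_i32x4(block);
  alignas(64) std::uint32_t out[16];
  _mm512_store_si512(reinterpret_cast<__m512i*>(out), dup);
  for (int i = 0; i < 16; ++i) std::printf("%u ", out[i]);  // 0 1 2 3, four times
  std::printf("\n");
  return 0;
}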
+template +HWY_API Vec512 LoadDup128(Full512 /* tag */, + const T* const HWY_RESTRICT p) { + const auto x4 = LoadU(Full128(), p); + return Vec512{_mm512_broadcast_i32x4(x4.raw)}; +} +HWY_API Vec512 LoadDup128(Full512 /* tag */, + const float* const HWY_RESTRICT p) { + const __m128 x4 = _mm_loadu_ps(p); + return Vec512{_mm512_broadcast_f32x4(x4)}; +} + +HWY_API Vec512 LoadDup128(Full512 /* tag */, + const double* const HWY_RESTRICT p) { + const __m128d x2 = _mm_loadu_pd(p); + return Vec512{_mm512_broadcast_f64x2(x2)}; +} + +// ------------------------------ Store + +template +HWY_API void Store(const Vec512 v, Full512 /* tag */, + T* HWY_RESTRICT aligned) { + _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); +} +HWY_API void Store(const Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT aligned) { + _mm512_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec512 v, Full512 /* tag */, + double* HWY_RESTRICT aligned) { + _mm512_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(const Vec512 v, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); +} +HWY_API void StoreU(const Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT p) { + _mm512_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec512 v, Full512, + double* HWY_RESTRICT p) { + _mm512_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi8(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi16(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi32(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi64(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec512 v, Mask512 m, + Full512 /* tag */, float* HWY_RESTRICT p) { + _mm512_mask_storeu_ps(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec512 v, Mask512 m, + Full512 /* tag */, double* HWY_RESTRICT p) { + _mm512_mask_storeu_pd(p, m.raw, v.raw); +} + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(const Vec512 v, Full512 /* tag */, + T* HWY_RESTRICT aligned) { + _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), v.raw); +} +HWY_API void Stream(const Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT aligned) { + _mm512_stream_ps(aligned, v.raw); +} +HWY_API void Stream(const Vec512 v, Full512, + double* HWY_RESTRICT aligned) { + _mm512_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +namespace detail { + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec512 v, + Full512 /* tag */, T* HWY_RESTRICT base, + const Vec512 offset) { + _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec512 v, + Full512 /* tag */, T* HWY_RESTRICT base, + const Vec512 index) { + _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec512 v, + Full512 /* tag */, T* HWY_RESTRICT base, + const Vec512 offset) { + _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec512 v, + Full512 /* tag */, T* HWY_RESTRICT base, + const Vec512 index) { + _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec512 v, Full512 d, T* HWY_RESTRICT base, + const Vec512 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec512 v, Full512 d, T* HWY_RESTRICT base, + const Vec512 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +HWY_API void ScatterOffset(Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT base, + const Vec512 offset) { + _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT base, + const Vec512 index) { + _mm512_i32scatter_ps(base, index.raw, v.raw, 4); +} + +HWY_API void ScatterOffset(Vec512 v, Full512 /* tag */, + double* HWY_RESTRICT base, + const Vec512 offset) { + _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec512 v, Full512 /* tag */, + double* HWY_RESTRICT base, + const Vec512 index) { + _mm512_i64scatter_pd(base, index.raw, v.raw, 8); +} + +// ------------------------------ Gather + +namespace detail { + +template +HWY_INLINE Vec512 GatherOffset(hwy::SizeTag<4> /* tag */, + Full512 /* tag */, + const T* HWY_RESTRICT base, + const Vec512 offset) { + return Vec512{_mm512_i32gather_epi32(offset.raw, base, 1)}; +} +template +HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<4> /* tag */, + Full512 /* tag */, + const T* HWY_RESTRICT base, + const Vec512 index) { + return Vec512{_mm512_i32gather_epi32(index.raw, base, 4)}; +} + +template +HWY_INLINE Vec512 GatherOffset(hwy::SizeTag<8> /* tag */, + Full512 /* tag */, + const T* HWY_RESTRICT base, + const Vec512 offset) { + return Vec512{_mm512_i64gather_epi64(offset.raw, base, 1)}; +} +template +HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<8> /* tag */, + Full512 /* tag */, + const T* HWY_RESTRICT base, + const Vec512 index) { + return Vec512{_mm512_i64gather_epi64(index.raw, base, 8)}; +} + +} // namespace detail + +template +HWY_API Vec512 GatherOffset(Full512 d, const T* HWY_RESTRICT base, + const Vec512 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec512 GatherIndex(Full512 d, const T* HWY_RESTRICT base, + const Vec512 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +HWY_API Vec512 
GatherOffset(Full512 /* tag */, + const float* HWY_RESTRICT base, + const Vec512 offset) { + return Vec512{_mm512_i32gather_ps(offset.raw, base, 1)}; +} +HWY_API Vec512 GatherIndex(Full512 /* tag */, + const float* HWY_RESTRICT base, + const Vec512 index) { + return Vec512{_mm512_i32gather_ps(index.raw, base, 4)}; +} + +HWY_API Vec512 GatherOffset(Full512 /* tag */, + const double* HWY_RESTRICT base, + const Vec512 offset) { + return Vec512{_mm512_i64gather_pd(offset.raw, base, 1)}; +} +HWY_API Vec512 GatherIndex(Full512 /* tag */, + const double* HWY_RESTRICT base, + const Vec512 index) { + return Vec512{_mm512_i64gather_pd(index.raw, base, 8)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template +HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_castsi512_si256(v.raw)}; +} +HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_castps512_ps256(v.raw)}; +} +HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_castpd512_pd256(v.raw)}; +} + +template +HWY_API Vec256 LowerHalf(Vec512 v) { + return LowerHalf(Full256(), v); +} + +// ------------------------------ UpperHalf + +template +HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_extracti32x8_epi32(v.raw, 1)}; +} +HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_extractf32x8_ps(v.raw, 1)}; +} +HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_extractf64x4_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T ExtractLane(const Vec512 v, size_t i) { + const Full512 d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec512 InsertLane(const Vec512 v, size_t i, T t) { + const Full512 d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec512 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ZeroExtendVector + +template +HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, Vec256 lo) { +#if HWY_HAVE_ZEXT // See definition/comment in x86_256-inl.h. 
+ return Vec512{_mm512_zextsi256_si512(lo.raw)}; +#else + return Vec512{_mm512_inserti32x8(_mm512_setzero_si512(), lo.raw, 0)}; +#endif +} +HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, + Vec256 lo) { +#if HWY_HAVE_ZEXT + return Vec512{_mm512_zextps256_ps512(lo.raw)}; +#else + return Vec512{_mm512_insertf32x8(_mm512_setzero_ps(), lo.raw, 0)}; +#endif +} +HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, + Vec256 lo) { +#if HWY_HAVE_ZEXT + return Vec512{_mm512_zextpd256_pd512(lo.raw)}; +#else + return Vec512{_mm512_insertf64x4(_mm512_setzero_pd(), lo.raw, 0)}; +#endif +} + +// ------------------------------ Combine + +template +HWY_API Vec512 Combine(Full512 d, Vec256 hi, Vec256 lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512{_mm512_inserti32x8(lo512.raw, hi.raw, 1)}; +} +HWY_API Vec512 Combine(Full512 d, Vec256 hi, + Vec256 lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512{_mm512_insertf32x8(lo512.raw, hi.raw, 1)}; +} +HWY_API Vec512 Combine(Full512 d, Vec256 hi, + Vec256 lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512{_mm512_insertf64x4(lo512.raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec512 ShiftLeftBytes(Full512 /* tag */, const Vec512 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec512{_mm512_bslli_epi128(v.raw, kBytes)}; +} + +template +HWY_API Vec512 ShiftLeftBytes(const Vec512 v) { + return ShiftLeftBytes(Full512(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec512 ShiftLeftLanes(Full512 d, const Vec512 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec512 ShiftLeftLanes(const Vec512 v) { + return ShiftLeftLanes(Full512(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec512 ShiftRightBytes(Full512 /* tag */, const Vec512 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec512{_mm512_bsrli_epi128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec512 ShiftRightLanes(Full512 d, const Vec512 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full512 d, V hi, V lo) { + const Repartition d8; + return BitCast(d, Vec512{_mm512_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec512{_mm512_unpacklo_epi64(lo, lo)}; + } else { + const __m512i hi = + _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec512{_mm512_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = kLane ? 
_MM_PERM_DCDC : _MM_PERM_BABA; + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} + +// Signed +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec512{_mm512_unpacklo_epi64(lo, lo)}; + } else { + const __m512i hi = + _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec512{_mm512_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} + +// Float +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, perm)}; +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); + return Vec512{_mm512_shuffle_pd(v.raw, v.raw, perm)}; +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec512 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec512 Shuffle2301(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; +} +HWY_API Vec512 Shuffle2301(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +namespace detail { + +template +HWY_API Vec512 Shuffle2301(const Vec512 a, const Vec512 b) { + const Full512 d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_CDAB)}); +} +template +HWY_API Vec512 Shuffle1230(const Vec512 a, const Vec512 b) { + const Full512 d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_BCDA)}); +} +template +HWY_API Vec512 Shuffle3012(const Vec512 a, const Vec512 b) { + const Full512 d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_DABC)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec512 Shuffle1032(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle1032(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle1032(const Vec512 v) { + // Shorter encoding than _mm512_permute_ps. 
+ return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + // Shorter encoding than _mm512_permute_pd. + return Vec512{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; +} + +// Rotate right 32 bits +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; +} +// Rotate left 32 bits +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; +} + +// Reverse +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template +struct Indices512 { + __m512i raw; +}; + +template +HWY_API Indices512 IndicesFromVec(Full512 /* tag */, Vec512 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full512 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(64 / sizeof(T)))))); +#endif + return Indices512{vec.raw}; +} + +template +HWY_API Indices512 SetTableIndices(const Full512 d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi32(idx.raw, v.raw)}; +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi64(idx.raw, v.raw)}; +} + +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_ps(idx.raw, v.raw)}; +} + +HWY_API Vec512 TableLookupLanes(Vec512 v, + Indices512 idx) { + return Vec512{_mm512_permutexvar_pd(idx.raw, v.raw)}; +} + +// ------------------------------ Reverse + +template +HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int16_t kReverse[32] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec512 idx = Load(di, kReverse); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { + alignas(64) constexpr int32_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { + alignas(64) constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, 
SetTableIndices(d, kReverse)); +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec512 Reverse2(Full512 d, const Vec512 v) { + const Full512 du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec512 Reverse4(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int16_t kReverse4[32] = { + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, + 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; + const Vec512 idx = Load(di, kReverse4); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { + return Vec512{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} +HWY_API Vec512 Reverse4(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int16_t kReverse8[32] = { + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int32_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + return Reverse(d, v); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). 
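A key detail of the interleave ops defined below is that the underlying unpack instructions operate within each 128-bit block rather than across the whole 512-bit vector. The following standalone sketch (not part of the library; assumes -mavx512f) shows the per-block lane ordering.

// Standalone illustration: unpacklo interleaves the low halves of each
// 128-bit block independently.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  const __m512i a = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7,
                                      8, 9, 10, 11, 12, 13, 14, 15);
  const __m512i b = _mm512_setr_epi32(100, 101, 102, 103, 104, 105, 106, 107,
                                      108, 109, 110, 111, 112, 113, 114, 115);
  alignas(64) std::int32_t out[16];
  _mm512_store_si512(reinterpret_cast<__m512i*>(out),
                     _mm512_unpacklo_epi32(a, b));
  // Per block: 0 100 1 101 | 4 104 5 105 | 8 108 9 109 | 12 112 13 113
  for (int i = 0; i < 16; ++i) std::printf("%d ", out[i]);
  std::printf("\n");
  return 0;
}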
+ +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template > +HWY_API V InterleaveUpper(Full512 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
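To see why the Zip* ops below can return double-width lanes, note that interleaving two narrow vectors and reinterpreting the result pairs each a[i] with b[i] inside one wider lane, with a[i] in the lower half on little-endian x86. A standalone sketch (not part of the library; needs AVX-512BW, e.g. -mavx512bw) follows.

// Standalone illustration: interleaved u8 lanes viewed as u16 lanes.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  const __m512i a = _mm512_set1_epi8(0x01);
  const __m512i b = _mm512_set1_epi8(0x02);
  const __m512i zipped = _mm512_unpacklo_epi8(a, b);  // bytes: 01 02 01 02 ...
  alignas(64) std::uint16_t out[32];
  _mm512_store_si512(reinterpret_cast<__m512i*>(out), zipped);
  std::printf("first u16 lane = 0x%04X\n", out[0]);  // 0x0201 = (b << 8) | a
  return 0;
}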
+template > +HWY_API Vec512 ZipLower(Vec512 a, Vec512 b) { + return BitCast(Full512(), InterleaveLower(a, b)); +} +template > +HWY_API Vec512 ZipLower(Full512 /* d */, Vec512 a, Vec512 b) { + return BitCast(Full512(), InterleaveLower(a, b)); +} + +template > +HWY_API Vec512 ZipUpper(Full512 d, Vec512 a, Vec512 b) { + return BitCast(Full512(), InterleaveUpper(d, a, b)); +} + +// ------------------------------ Concat* halves + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +} +HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +} +HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} +HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} +HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BADC)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, const Vec512 hi, + const Vec512 lo) { + // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks + // are efficiently loaded from 32-bit regs. 
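The idea used by the overloads below is a k-register blend: with mask 0x0000FFFF on 16-bit lanes, the lower 16 lanes (the low 256 bits) come from one operand and the upper lanes from the other. A standalone sketch (not part of the library; requires AVX-512BW, -mavx512bw) is shown here.

// Standalone illustration: selecting outer halves via a 16-bit-lane blend.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  const __m512i lo = _mm512_set1_epi16(1);
  const __m512i hi = _mm512_set1_epi16(2);
  // Bit i selects the *second* operand for lane i; 0x0000FFFF covers the
  // lower 16 lanes, i.e. the lower 256 bits.
  const __m512i outer = _mm512_mask_blend_epi16(__mmask32{0x0000FFFFu}, hi, lo);
  alignas(64) std::int16_t out[32];
  _mm512_store_si512(reinterpret_cast<__m512i*>(out), outer);
  std::printf("lane0=%d lane31=%d\n", out[0], out[31]);  // 1 (from lo), 2 (from hi)
  return 0;
}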
+ const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); + return Vec512{_mm512_mask_blend_epi16(mask, hi.raw, lo.raw)}; +} +HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); + return Vec512{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; +} +HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); + return Vec512{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; +} + +// ------------------------------ ConcatOdd + +template +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[64] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, + 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, + 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, + 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, + 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; + return BitCast(d, + Vec512{_mm512_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask64{0xFFFFFFFFFFFFFFFFull}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec512 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec512 uL = ShiftRight<8>(BitCast(dw, lo)); + const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. + const Full512 du64; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint16_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +} + +template +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +} + +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, + Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return Vec512{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, hi.raw)}; +} + +template +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +} + +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, + Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return Vec512{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +} + +// ------------------------------ ConcatEven + +template +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; 
+#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, + 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, + 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, + 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, + 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126}; + return BitCast(d, + Vec512{_mm512_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask64{0xFFFFFFFFFFFFFFFFull}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec512 mask = Set(dw, 0x00FF); + const Vec512 uH = And(BitCast(dw, hi), mask); + const Vec512 uL = And(BitCast(dw, lo), mask); + const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. + const Full512 du64; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint16_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +} + +template +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +} + +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, + Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return Vec512{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, hi.raw)}; +} + +template +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +} + +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, + Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return Vec512{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; +} +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; +} + +template +HWY_API Vec512 DupEven(const Vec512 v) { + return InterleaveLower(Full512(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)}; +} +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; +} + +template +HWY_API Vec512 DupOdd(const Vec512 v) { + return InterleaveUpper(Full512(), v, v); +} + +// ------------------------------ OddEven + 
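Before the implementation, a short usage sketch of OddEven: it keeps odd-indexed lanes of the first argument and even-indexed lanes of the second, so lane 0 comes from b. This is a hypothetical standalone program against the public hwy/highway.h API (static dispatch), not part of this patch.

// Standalone usage sketch of OddEven via the public Highway API.
#include <cstdio>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::ScalableTag<int32_t> d;
  const auto a = hn::Set(d, 1);      // 1 1 1 1 ...
  const auto b = hn::Set(d, 2);      // 2 2 2 2 ...
  const auto v = hn::OddEven(a, b);  // 2 1 2 1 ... (lane 0 taken from b)
  std::printf("lane0=%d\n", hn::GetLane(v));            // 2
  if (hn::Lanes(d) >= 2) {                              // guard for 1-lane targets
    std::printf("lane1=%d\n", hn::ExtractLane(v, 1));   // 1
  }
  return 0;
}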
+template +HWY_API Vec512 OddEven(const Vec512 a, const Vec512 b) { + constexpr size_t s = sizeof(T); + constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; + return IfThenElse(Mask512{0x5555555555555555ull >> shift}, b, a); +} + +// ------------------------------ OddEvenBlocks + +template +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{_mm512_mask_blend_epi64(__mmask8{0x33u}, odd.raw, even.raw)}; +} + +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{ + _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)}; +} + +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{ + _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)}; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +// ------------------------------ ReverseBlocks + +template +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, + Vec512 v) { + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec512 TableLookupBytes(Vec512 bytes, Vec512 indices) { + return Vec512{_mm512_shuffle_epi8(bytes.raw, indices.raw)}; +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec512 bytes, Vec128 from) { + const Full512 d512; + const Half d256; + const Half d128; + // First expand to full 128, then 256, then 512. + const Vec128 from_full{from.raw}; + const auto from_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, from_full)); + const auto tbl_full = TableLookupBytes(bytes, from_512); + // Shrink to 256, then 128, then partial. + return Vec128{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw}; +} +template +HWY_API Vec256 TableLookupBytes(Vec512 bytes, Vec256 from) { + const auto from_512 = ZeroExtendVector(Full512(), from); + return LowerHalf(Full256(), TableLookupBytes(bytes, from_512)); +} + +// Partial table vector +template +HWY_API Vec512 TableLookupBytes(Vec128 bytes, Vec512 from) { + const Full512 d512; + const Half d256; + const Half d128; + // First expand to full 128, then 256, then 512. + const Vec128 bytes_full{bytes.raw}; + const auto bytes_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); + return TableLookupBytes(bytes_512, from); +} +template +HWY_API Vec512 TableLookupBytes(Vec256 bytes, Vec512 from) { + const auto bytes_512 = ZeroExtendVector(Full512(), bytes); + return TableLookupBytes(bytes_512, from); +} + +// Partial both are handled by x86_128/256. + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. 
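The u8 -> u16 promotion below wraps a zero-extending conversion: each narrow unsigned lane keeps its value in the wider lane, with the upper bits cleared. A standalone sketch (not part of the library; requires AVX-512BW, -mavx512bw) follows.

// Standalone illustration: zero-extension of 32 u8 lanes to 32 u16 lanes.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  const __m256i narrow = _mm256_set1_epi8(static_cast<char>(0xFF));  // 255 per byte
  const __m512i wide = _mm512_cvtepu8_epi16(narrow);
  alignas(64) std::uint16_t out[32];
  _mm512_store_si512(reinterpret_cast<__m512i*>(out), wide);
  std::printf("lane0 = %u\n", out[0]);  // 255, not 65535: zero- not sign-extension
  return 0;
}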
+HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec128 v) { + return Vec512{_mm512_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec128 v) { + return Vec512{_mm512_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu32_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepi8_epi16(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec128 v) { + return Vec512{_mm512_cvtepi8_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepi16_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepi32_epi64(v.raw)}; +} + +// Float +HWY_API Vec512 PromoteTo(Full512 /* tag */, + const Vec256 v) { + return Vec512{_mm512_cvtph_ps(v.raw)}; +} + +HWY_API Vec512 PromoteTo(Full512 df32, + const Vec256 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { + return Vec512{_mm512_cvtps_pd(v.raw)}; +} + +HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { + return Vec512{_mm512_cvtepi32_pd(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512(), kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; + return LowerHalf(even); +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512(), kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec512 v) { + const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; + // packus treats the input as signed; we want unsigned. Clear the MSB to get + // unsigned saturation to u8. + const Vec512 i16{ + _mm512_and_si512(u16.raw, _mm512_set1_epi16(0x7FFF))}; + const Vec512 u8{_mm512_packus_epi16(i16.raw, i16.raw)}; + + alignas(16) static constexpr uint32_t kLanes[4] = {0, 4, 8, 12}; + const auto idx32 = LoadDup128(Full512(), kLanes); + const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const Vec512 u8{_mm512_packus_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. 
+ alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512(), kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec512 v) { + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + const Vec512 i8{_mm512_packs_epi16(i16.raw, i16.raw)}; + + alignas(16) static constexpr uint32_t kLanes[16] = {0, 4, 8, 12, 0, 4, 8, 12, + 0, 4, 8, 12, 0, 4, 8, 12}; + const auto idx32 = LoadDup128(Full512(), kLanes); + const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const Vec512 u8{_mm512_packs_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512(), kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Vec256{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; + HWY_DIAGNOSTICS(pop) +} + +HWY_API Vec256 DemoteTo(Full256 dbf16, + const Vec512 v) { + // TODO(janwas): _mm512_cvtneps_pbh once we have avx512bf16. + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +HWY_API Vec512 ReorderDemote2To(Full512 dbf16, + Vec512 a, Vec512 b) { + // TODO(janwas): _mm512_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned du16; + const Repartition du32; + const Vec512 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec512 ReorderDemote2To(Full512 /*d16*/, + Vec512 a, Vec512 b) { + return Vec512{_mm512_packs_epi32(a.raw, b.raw)}; +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + return Vec256{_mm512_cvtpd_ps(v.raw)}; +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const auto clamped = detail::ClampF64ToI32Max(Full512(), v); + return Vec256{_mm512_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec512 v) { + const Full512 d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, + ~0u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); + // Gather the lowest 4 bytes of 4 128-bit blocks. 
+ alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +// ------------------------------ Truncations + +HWY_API Vec128 TruncateTo(Simd d, + const Vec512 v) { +#if HWY_TARGET == HWY_AVX3_DL + (void)d; + const Full512 d8; + alignas(16) static constexpr uint8_t k8From64[16] = { + 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56}; + const Vec512 bytes{ + _mm512_permutexvar_epi8(LoadDup128(d8, k8From64).raw, v.raw)}; + return LowerHalf(LowerHalf(LowerHalf(bytes))); +#else + const Full512 d32; + alignas(64) constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512 even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return TruncateTo(d, LowerHalf(even)); +#endif +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec512 v) { + const Full512 d16; + alignas(16) static constexpr uint16_t k16From64[8] = { + 0, 4, 8, 12, 16, 20, 24, 28}; + const Vec512 bytes{ + _mm512_permutexvar_epi16(LoadDup128(d16, k16From64).raw, v.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec512 v) { + const Full512 d32; + alignas(64) constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512 even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec512 v) { +#if HWY_TARGET == HWY_AVX3_DL + const Full512 d8; + alignas(16) static constexpr uint8_t k8From32[16] = { + 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(LoadDup128(d8, k8From32).raw, v.raw)}; +#else + const Full512 d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, + ~0u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); + // Gather the lowest 4 bytes of 4 128-bit blocks. 
+ alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; +#endif + return LowerHalf(LowerHalf(bytes)); +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec512 v) { + const Full512 d16; + alignas(64) static constexpr uint16_t k16From32[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + const Vec512 bytes{ + _mm512_permutexvar_epi16(Load(d16, k16From32).raw, v.raw)}; + return LowerHalf(bytes); +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec512 v) { +#if HWY_TARGET == HWY_AVX3_DL + const Full512 d8; + alignas(64) static constexpr uint8_t k8From16[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + const Vec512 bytes{ + _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; +#else + const Full512 d32; + alignas(16) static constexpr uint32_t k16From32[4] = { + 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k16From32)); + alignas(64) static constexpr uint32_t kIndex32[16] = { + 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(Load(d32, kIndex32).raw, quads.raw)}; +#endif + return LowerHalf(bytes); +} + +// ------------------------------ Convert integer <=> floating point + +HWY_API Vec512 ConvertTo(Full512 /* tag */, + const Vec512 v) { + return Vec512{_mm512_cvtepi32_ps(v.raw)}; +} + +HWY_API Vec512 ConvertTo(Full512 /* tag */, + const Vec512 v) { + return Vec512{_mm512_cvtepi64_pd(v.raw)}; +} + +HWY_API Vec512 ConvertTo(Full512 /* tag*/, + const Vec512 v) { + return Vec512{_mm512_cvtepu32_ps(v.raw)}; +} + +HWY_API Vec512 ConvertTo(Full512 /* tag*/, + const Vec512 v) { + return Vec512{_mm512_cvtepu64_pd(v.raw)}; +} + +// Truncates (rounds toward zero). +HWY_API Vec512 ConvertTo(Full512 d, const Vec512 v) { + return detail::FixConversionOverflow(d, v, _mm512_cvttps_epi32(v.raw)); +} +HWY_API Vec512 ConvertTo(Full512 di, const Vec512 v) { + return detail::FixConversionOverflow(di, v, _mm512_cvttpd_epi64(v.raw)); +} + +HWY_API Vec512 NearestInt(const Vec512 v) { + const Full512 di; + return detail::FixConversionOverflow(di, v, _mm512_cvtps_epi32(v.raw)); +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. 
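+// The #ifdef/#undef/#else/#define block that follows flips whether
+// HWY_NATIVE_AES is defined each time this header is re-included for a
+// target; that toggle is the per-target signal mentioned above.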
+#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec512 AESRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_aesenc_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESLastRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 CLMulLower(Vec512 va, Vec512 vb) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const Full512 d; + const Full128 d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +HWY_API Vec512 CLMulUpper(Vec512 va, Vec512 vb) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const Full512 d; + const Full128 d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +Vec512 Iota(const Full512 d, const T2 first) { + HWY_ALIGN T lanes[64 / sizeof(T)]; + for (size_t i = 0; i < 64 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! 
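+// For example, _kortestz_mask64_u8 below tests the 64-bit mask used for 64
+// one-byte lanes, while _kortestz_mask8_u8 handles the mask for eight 8-byte
+// lanes.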
+ +namespace detail { + +template +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} + +} // namespace detail + +template +HWY_API bool AllFalse(const Full512 /* tag */, const Mask512 mask) { + return detail::AllFalse(hwy::SizeTag(), mask); +} + +namespace detail { + +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFFFFFFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFull; +#endif +} + +} // namespace detail + +template +HWY_API bool AllTrue(const Full512 /* tag */, const Mask512 mask) { + return detail::AllTrue(hwy::SizeTag(), mask); +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask512 LoadMaskBits(const Full512 /* tag */, + const uint8_t* HWY_RESTRICT bits) { + Mask512 mask; + CopyBytes<8 / sizeof(T)>(bits, &mask.raw); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return mask; +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Full512 /* tag */, const Mask512 mask, + uint8_t* bits) { + const size_t kNumBytes = 8 / sizeof(T); + CopyBytes(&mask.raw, bits); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return kNumBytes; +} + +template +HWY_API size_t CountTrue(const Full512 /* tag */, const Mask512 mask) { + return PopCount(static_cast(mask.raw)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full512 /* tag */, + const Mask512 mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full512 /* tag */, + const Mask512 mask) { + return Num0BitsBelowLS1Bit_Nonzero64(mask.raw); +} + +template +HWY_API intptr_t FindFirstTrue(const Full512 d, const Mask512 mask) { + return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +// ------------------------------ Compress + +// Always implement 8-bit here even if we lack VBMI2 because we can do better +// than generic_ops (8 at a time) via the native 32-bit compress (16 at a time). 
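+// (Promoting to 32-bit lanes and using _mm512_maskz_compress_epi32 processes
+// 16 lanes per step, hence the "16 at a time" above; see EmuCompress below.)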
+#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +namespace detail { + +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 +template +HWY_INLINE Vec128 NativeCompress(const Vec128 v, + const Mask128 mask) { + return Vec128{_mm_maskz_compress_epi8(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(const Vec256 v, + const Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi8(mask.raw, v.raw)}; +} +HWY_INLINE Vec512 NativeCompress(const Vec512 v, + const Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi8(mask.raw, v.raw)}; +} + +template +HWY_INLINE Vec128 NativeCompress(const Vec128 v, + const Mask128 mask) { + return Vec128{_mm_maskz_compress_epi16(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(const Vec256 v, + const Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi16(mask.raw, v.raw)}; +} +HWY_INLINE Vec512 NativeCompress(const Vec512 v, + const Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi16(mask.raw, v.raw)}; +} + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + Simd /* d */, + uint8_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + Full256 /* d */, + uint8_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + Full512 /* d */, + uint8_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + Simd /* d */, + uint16_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + Full256 /* d */, + uint16_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + Full512 /* d */, + uint16_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +template +HWY_INLINE Vec128 NativeCompress(const Vec128 v, + const Mask128 mask) { + return Vec128{_mm_maskz_compress_epi32(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} +HWY_INLINE Vec512 NativeCompress(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; +} +// We use table-based compress for 64-bit lanes, see CompressIsPartition. 
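+// (That table-based path appears further below in Compress/CompressNot for
+// 8-byte lanes: a 256-entry array of packed 4-bit lane indices drives
+// TableLookupLanes.)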
+ +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + Simd /* d */, + uint32_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + Full256 /* d */, + uint32_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + Full512 /* d */, + uint32_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + Simd /* d */, + uint64_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + Full256 /* d */, + uint64_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + Full512 /* d */, + uint64_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} + +// For u8x16 and <= u16x16 we can avoid store+load for Compress because there is +// only a single compressed vector (u32x16). Other EmuCompress are implemented +// after the EmuCompressStore they build upon. +template +HWY_INLINE Vec128 EmuCompress(Vec128 v, + Mask128 mask) { + const Simd d; + const Rebind d32; + const auto v0 = PromoteTo(d32, v); + + const uint64_t mask_bits{mask.raw}; + // Mask type is __mmask16 if v is full 128, else __mmask8. + using M32 = MFromD; + const M32 m0{static_cast(mask_bits)}; + return TruncateTo(d, Compress(v0, m0)); +} + +template +HWY_INLINE Vec128 EmuCompress(Vec128 v, + Mask128 mask) { + const Simd d; + const Rebind di32; + const RebindToUnsigned du32; + const MFromD mask32{static_cast<__mmask8>(mask.raw)}; + // DemoteTo is 2 ops, but likely lower latency than TruncateTo on SKX. + // Only i32 -> u16 is supported, whereas NativeCompress expects u32. + const VFromD v32 = BitCast(du32, PromoteTo(di32, v)); + return DemoteTo(d, BitCast(di32, NativeCompress(v32, mask32))); +} + +HWY_INLINE Vec256 EmuCompress(Vec256 v, + Mask256 mask) { + const Full256 d; + const Rebind di32; + const RebindToUnsigned du32; + const Mask512 mask32{static_cast<__mmask16>(mask.raw)}; + const Vec512 v32 = BitCast(du32, PromoteTo(di32, v)); + return DemoteTo(d, BitCast(di32, NativeCompress(v32, mask32))); +} + +// See above - small-vector EmuCompressStore are implemented via EmuCompress. +template +HWY_INLINE void EmuCompressStore(Vec128 v, Mask128 mask, + Simd d, T* HWY_RESTRICT unaligned) { + StoreU(EmuCompress(v, mask), d, unaligned); +} + +HWY_INLINE void EmuCompressStore(Vec256 v, Mask256 mask, + Full256 d, + uint16_t* HWY_RESTRICT unaligned) { + StoreU(EmuCompress(v, mask), d, unaligned); +} + +// Main emulation logic for wider vector, starting with EmuCompressStore because +// it is most convenient to merge pieces using memory (concatenating vectors at +// byte offsets is difficult). 
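+// Each compressed piece is stored at an offset equal to the number of lanes
+// already written (CountTrue of the preceding masks), so the output is densely
+// packed even though the pieces are compressed independently.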
+HWY_INLINE void EmuCompressStore(Vec256 v, Mask256 mask, + Full256 d, + uint8_t* HWY_RESTRICT unaligned) { + const uint64_t mask_bits{mask.raw}; + const Half dh; + const Rebind d32; + const Vec512 v0 = PromoteTo(d32, LowerHalf(v)); + const Vec512 v1 = PromoteTo(d32, UpperHalf(dh, v)); + const Mask512 m0{static_cast<__mmask16>(mask_bits & 0xFFFFu)}; + const Mask512 m1{static_cast<__mmask16>(mask_bits >> 16)}; + const Vec128 c0 = TruncateTo(dh, NativeCompress(v0, m0)); + const Vec128 c1 = TruncateTo(dh, NativeCompress(v1, m1)); + uint8_t* HWY_RESTRICT pos = unaligned; + StoreU(c0, dh, pos); + StoreU(c1, dh, pos + CountTrue(d32, m0)); +} + +HWY_INLINE void EmuCompressStore(Vec512 v, Mask512 mask, + Full512 d, + uint8_t* HWY_RESTRICT unaligned) { + const uint64_t mask_bits{mask.raw}; + const Half> dq; + const Rebind d32; + HWY_ALIGN uint8_t lanes[64]; + Store(v, d, lanes); + const Vec512 v0 = PromoteTo(d32, LowerHalf(LowerHalf(v))); + const Vec512 v1 = PromoteTo(d32, Load(dq, lanes + 16)); + const Vec512 v2 = PromoteTo(d32, Load(dq, lanes + 32)); + const Vec512 v3 = PromoteTo(d32, Load(dq, lanes + 48)); + const Mask512 m0{static_cast<__mmask16>(mask_bits & 0xFFFFu)}; + const Mask512 m1{ + static_cast((mask_bits >> 16) & 0xFFFFu)}; + const Mask512 m2{ + static_cast((mask_bits >> 32) & 0xFFFFu)}; + const Mask512 m3{static_cast<__mmask16>(mask_bits >> 48)}; + const Vec128 c0 = TruncateTo(dq, NativeCompress(v0, m0)); + const Vec128 c1 = TruncateTo(dq, NativeCompress(v1, m1)); + const Vec128 c2 = TruncateTo(dq, NativeCompress(v2, m2)); + const Vec128 c3 = TruncateTo(dq, NativeCompress(v3, m3)); + uint8_t* HWY_RESTRICT pos = unaligned; + StoreU(c0, dq, pos); + pos += CountTrue(d32, m0); + StoreU(c1, dq, pos); + pos += CountTrue(d32, m1); + StoreU(c2, dq, pos); + pos += CountTrue(d32, m2); + StoreU(c3, dq, pos); +} + +HWY_INLINE void EmuCompressStore(Vec512 v, Mask512 mask, + Full512 d, + uint16_t* HWY_RESTRICT unaligned) { + const Repartition di32; + const RebindToUnsigned du32; + const Half dh; + const Vec512 promoted0 = + BitCast(du32, PromoteTo(di32, LowerHalf(dh, v))); + const Vec512 promoted1 = + BitCast(du32, PromoteTo(di32, UpperHalf(dh, v))); + + const uint64_t mask_bits{mask.raw}; + const uint64_t maskL = mask_bits & 0xFFFF; + const uint64_t maskH = mask_bits >> 16; + const Mask512 mask0{static_cast<__mmask16>(maskL)}; + const Mask512 mask1{static_cast<__mmask16>(maskH)}; + const Vec512 compressed0 = NativeCompress(promoted0, mask0); + const Vec512 compressed1 = NativeCompress(promoted1, mask1); + + const Vec256 demoted0 = DemoteTo(dh, BitCast(di32, compressed0)); + const Vec256 demoted1 = DemoteTo(dh, BitCast(di32, compressed1)); + + // Store 256-bit halves + StoreU(demoted0, dh, unaligned); + StoreU(demoted1, dh, unaligned + PopCount(maskL)); +} + +// Finally, the remaining EmuCompress for wide vectors, using EmuCompressStore. 
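+// The scratch buffers below are twice the vector size, presumably to leave
+// slack for the full-vector StoreU calls inside EmuCompressStore, which write
+// past the compressed count.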
+template // 1 or 2 bytes +HWY_INLINE Vec512 EmuCompress(Vec512 v, Mask512 mask) { + const Full512 d; + HWY_ALIGN T buf[2 * 64 / sizeof(T)]; + EmuCompressStore(v, mask, d, buf); + return Load(d, buf); +} + +HWY_INLINE Vec256 EmuCompress(Vec256 v, + const Mask256 mask) { + const Full256 d; + HWY_ALIGN uint8_t buf[2 * 32 / sizeof(uint8_t)]; + EmuCompressStore(v, mask, d, buf); + return Load(d, buf); +} + +} // namespace detail + +template // 1 or 2 bytes +HWY_API V Compress(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); +#else + return BitCast(d, detail::EmuCompress(BitCast(du, v), mu)); +#endif +} + +template +HWY_API V Compress(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); +} + +template +HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) constexpr uint64_t packed_array[256] = { + // From PrintCompress32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). + 0x76543210, 0x76543210, 0x76543201, 0x76543210, 0x76543102, 0x76543120, + 0x76543021, 0x76543210, 0x76542103, 0x76542130, 0x76542031, 0x76542310, + 0x76541032, 0x76541320, 0x76540321, 0x76543210, 0x76532104, 0x76532140, + 0x76532041, 0x76532410, 0x76531042, 0x76531420, 0x76530421, 0x76534210, + 0x76521043, 0x76521430, 0x76520431, 0x76524310, 0x76510432, 0x76514320, + 0x76504321, 0x76543210, 0x76432105, 0x76432150, 0x76432051, 0x76432510, + 0x76431052, 0x76431520, 0x76430521, 0x76435210, 0x76421053, 0x76421530, + 0x76420531, 0x76425310, 0x76410532, 0x76415320, 0x76405321, 0x76453210, + 0x76321054, 0x76321540, 0x76320541, 0x76325410, 0x76310542, 0x76315420, + 0x76305421, 0x76354210, 0x76210543, 0x76215430, 0x76205431, 0x76254310, + 0x76105432, 0x76154320, 0x76054321, 0x76543210, 0x75432106, 0x75432160, + 0x75432061, 0x75432610, 0x75431062, 0x75431620, 0x75430621, 0x75436210, + 0x75421063, 0x75421630, 0x75420631, 0x75426310, 0x75410632, 0x75416320, + 0x75406321, 0x75463210, 0x75321064, 0x75321640, 0x75320641, 0x75326410, + 0x75310642, 0x75316420, 0x75306421, 0x75364210, 0x75210643, 0x75216430, + 0x75206431, 0x75264310, 0x75106432, 0x75164320, 0x75064321, 0x75643210, + 0x74321065, 0x74321650, 0x74320651, 0x74326510, 0x74310652, 0x74316520, + 0x74306521, 0x74365210, 0x74210653, 0x74216530, 0x74206531, 0x74265310, + 0x74106532, 0x74165320, 0x74065321, 0x74653210, 0x73210654, 0x73216540, + 0x73206541, 0x73265410, 0x73106542, 0x73165420, 0x73065421, 0x73654210, + 0x72106543, 0x72165430, 0x72065431, 0x72654310, 0x71065432, 0x71654320, + 0x70654321, 0x76543210, 0x65432107, 0x65432170, 0x65432071, 0x65432710, + 0x65431072, 0x65431720, 0x65430721, 0x65437210, 0x65421073, 0x65421730, + 0x65420731, 0x65427310, 0x65410732, 0x65417320, 0x65407321, 0x65473210, + 0x65321074, 0x65321740, 0x65320741, 0x65327410, 0x65310742, 0x65317420, + 0x65307421, 0x65374210, 0x65210743, 0x65217430, 0x65207431, 0x65274310, + 0x65107432, 0x65174320, 0x65074321, 0x65743210, 0x64321075, 0x64321750, + 0x64320751, 0x64327510, 0x64310752, 0x64317520, 0x64307521, 0x64375210, + 0x64210753, 0x64217530, 0x64207531, 0x64275310, 0x64107532, 0x64175320, + 0x64075321, 0x64753210, 0x63210754, 
0x63217540, 0x63207541, 0x63275410, + 0x63107542, 0x63175420, 0x63075421, 0x63754210, 0x62107543, 0x62175430, + 0x62075431, 0x62754310, 0x61075432, 0x61754320, 0x60754321, 0x67543210, + 0x54321076, 0x54321760, 0x54320761, 0x54327610, 0x54310762, 0x54317620, + 0x54307621, 0x54376210, 0x54210763, 0x54217630, 0x54207631, 0x54276310, + 0x54107632, 0x54176320, 0x54076321, 0x54763210, 0x53210764, 0x53217640, + 0x53207641, 0x53276410, 0x53107642, 0x53176420, 0x53076421, 0x53764210, + 0x52107643, 0x52176430, 0x52076431, 0x52764310, 0x51076432, 0x51764320, + 0x50764321, 0x57643210, 0x43210765, 0x43217650, 0x43207651, 0x43276510, + 0x43107652, 0x43176520, 0x43076521, 0x43765210, 0x42107653, 0x42176530, + 0x42076531, 0x42765310, 0x41076532, 0x41765320, 0x40765321, 0x47653210, + 0x32107654, 0x32176540, 0x32076541, 0x32765410, 0x31076542, 0x31765420, + 0x30765421, 0x37654210, 0x21076543, 0x21765430, 0x20765431, 0x27654310, + 0x10765432, 0x17654320, 0x07654321, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. + const Full512 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressNot + +template +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +template +HWY_API Vec512 CompressNot(Vec512 v, Mask512 mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) constexpr uint64_t packed_array[256] = { + // From PrintCompressNot32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). 
+ 0x76543210, 0x07654321, 0x17654320, 0x10765432, 0x27654310, 0x20765431, + 0x21765430, 0x21076543, 0x37654210, 0x30765421, 0x31765420, 0x31076542, + 0x32765410, 0x32076541, 0x32176540, 0x32107654, 0x47653210, 0x40765321, + 0x41765320, 0x41076532, 0x42765310, 0x42076531, 0x42176530, 0x42107653, + 0x43765210, 0x43076521, 0x43176520, 0x43107652, 0x43276510, 0x43207651, + 0x43217650, 0x43210765, 0x57643210, 0x50764321, 0x51764320, 0x51076432, + 0x52764310, 0x52076431, 0x52176430, 0x52107643, 0x53764210, 0x53076421, + 0x53176420, 0x53107642, 0x53276410, 0x53207641, 0x53217640, 0x53210764, + 0x54763210, 0x54076321, 0x54176320, 0x54107632, 0x54276310, 0x54207631, + 0x54217630, 0x54210763, 0x54376210, 0x54307621, 0x54317620, 0x54310762, + 0x54327610, 0x54320761, 0x54321760, 0x54321076, 0x67543210, 0x60754321, + 0x61754320, 0x61075432, 0x62754310, 0x62075431, 0x62175430, 0x62107543, + 0x63754210, 0x63075421, 0x63175420, 0x63107542, 0x63275410, 0x63207541, + 0x63217540, 0x63210754, 0x64753210, 0x64075321, 0x64175320, 0x64107532, + 0x64275310, 0x64207531, 0x64217530, 0x64210753, 0x64375210, 0x64307521, + 0x64317520, 0x64310752, 0x64327510, 0x64320751, 0x64321750, 0x64321075, + 0x65743210, 0x65074321, 0x65174320, 0x65107432, 0x65274310, 0x65207431, + 0x65217430, 0x65210743, 0x65374210, 0x65307421, 0x65317420, 0x65310742, + 0x65327410, 0x65320741, 0x65321740, 0x65321074, 0x65473210, 0x65407321, + 0x65417320, 0x65410732, 0x65427310, 0x65420731, 0x65421730, 0x65421073, + 0x65437210, 0x65430721, 0x65431720, 0x65431072, 0x65432710, 0x65432071, + 0x65432170, 0x65432107, 0x76543210, 0x70654321, 0x71654320, 0x71065432, + 0x72654310, 0x72065431, 0x72165430, 0x72106543, 0x73654210, 0x73065421, + 0x73165420, 0x73106542, 0x73265410, 0x73206541, 0x73216540, 0x73210654, + 0x74653210, 0x74065321, 0x74165320, 0x74106532, 0x74265310, 0x74206531, + 0x74216530, 0x74210653, 0x74365210, 0x74306521, 0x74316520, 0x74310652, + 0x74326510, 0x74320651, 0x74321650, 0x74321065, 0x75643210, 0x75064321, + 0x75164320, 0x75106432, 0x75264310, 0x75206431, 0x75216430, 0x75210643, + 0x75364210, 0x75306421, 0x75316420, 0x75310642, 0x75326410, 0x75320641, + 0x75321640, 0x75321064, 0x75463210, 0x75406321, 0x75416320, 0x75410632, + 0x75426310, 0x75420631, 0x75421630, 0x75421063, 0x75436210, 0x75430621, + 0x75431620, 0x75431062, 0x75432610, 0x75432061, 0x75432160, 0x75432106, + 0x76543210, 0x76054321, 0x76154320, 0x76105432, 0x76254310, 0x76205431, + 0x76215430, 0x76210543, 0x76354210, 0x76305421, 0x76315420, 0x76310542, + 0x76325410, 0x76320541, 0x76321540, 0x76321054, 0x76453210, 0x76405321, + 0x76415320, 0x76410532, 0x76425310, 0x76420531, 0x76421530, 0x76421053, + 0x76435210, 0x76430521, 0x76431520, 0x76431052, 0x76432510, 0x76432051, + 0x76432150, 0x76432105, 0x76543210, 0x76504321, 0x76514320, 0x76510432, + 0x76524310, 0x76520431, 0x76521430, 0x76521043, 0x76534210, 0x76530421, + 0x76531420, 0x76531042, 0x76532410, 0x76532041, 0x76532140, 0x76532104, + 0x76543210, 0x76540321, 0x76541320, 0x76541032, 0x76542310, 0x76542031, + 0x76542130, 0x76542103, 0x76543210, 0x76543021, 0x76543120, 0x76543102, + 0x76543210, 0x76543201, 0x76543210, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. 
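+ // Example: mask.raw == 0b00000101 selects packed_array[5] == 0x20765431,
+ // i.e. lanes 1,3,4,5,6,7 (where the mask is false) come first, followed by
+ // lanes 0 and 2.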
+ const Full512 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// uint64_t lanes. Only implement for 256 and 512-bit vectors because this is a +// no-op for 128-bit. +template 16)>* = nullptr> +HWY_API V CompressBlocksNot(V v, M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressBits +template +HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +// ------------------------------ CompressStore + +template // 1 or 2 bytes +HWY_API size_t CompressStore(V v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + auto pu = reinterpret_cast * HWY_RESTRICT>(unaligned); +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + detail::NativeCompressStore(BitCast(du, v), mu, du, pu); +#else + detail::EmuCompressStore(BitCast(du, v), mu, du, pu); +#endif + const size_t count = CountTrue(d, mask); + detail::MaybeUnpoison(pu, count); + return count; +} + +template // 4 or 8 +HWY_API size_t CompressStore(V v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + using TU = TFromD; + TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + detail::NativeCompressStore(BitCast(du, v), mu, du, pu); + const size_t count = CountTrue(d, mask); + detail::MaybeUnpoison(pu, count); + return count; +} + +// Additional overloads to avoid casting to uint32_t (delay?). +HWY_API size_t CompressStore(Vec512 v, Mask512 mask, + Full512 /* tag */, + float* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +HWY_API size_t CompressStore(Vec512 v, Mask512 mask, + Full512 /* tag */, + double* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore +template > +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + T* HWY_RESTRICT unaligned) { + // Native CompressStore already does the blending at no extra cost (latency + // 11, rthroughput 2 - same as compress plus store). + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) > 2) { + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + BlendedStore(Compress(v, m), FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ LoadInterleaved4 + +// Actually implemented in generic_ops, we just overload LoadTransposedBlocks4. +namespace detail { + +// Type-safe wrapper. 
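+// Shuffle128 returns two 128-bit blocks selected from `lo` (lower half of the
+// result) and two from `hi` (upper half), as directed by kPerm, wrapping
+// _mm512_shuffle_{i64x2,f32x4,f64x2} so that the lane type is preserved.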
+template <_MM_PERM_ENUM kPerm, typename T> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_i64x2(lo.raw, hi.raw, kPerm)}; +} +template <_MM_PERM_ENUM kPerm> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, kPerm)}; +} +template <_MM_PERM_ENUM kPerm> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, kPerm)}; +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// Output: +// 9 6 3 0 (LSB of A) +// a 7 4 1 +// b 8 5 2 +template +HWY_API void LoadTransposedBlocks3(Full512 d, + const T* HWY_RESTRICT unaligned, + Vec512& A, Vec512& B, Vec512& C) { + constexpr size_t N = 64 / sizeof(T); + const Vec512 v3210 = LoadU(d, unaligned + 0 * N); + const Vec512 v7654 = LoadU(d, unaligned + 1 * N); + const Vec512 vba98 = LoadU(d, unaligned + 2 * N); + + const Vec512 v5421 = detail::Shuffle128<_MM_PERM_BACB>(v3210, v7654); + const Vec512 va976 = detail::Shuffle128<_MM_PERM_CBDC>(v7654, vba98); + + A = detail::Shuffle128<_MM_PERM_CADA>(v3210, va976); + B = detail::Shuffle128<_MM_PERM_DBCA>(v5421, va976); + C = detail::Shuffle128<_MM_PERM_DADB>(v5421, vba98); +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// f e d c +// Output: +// c 8 4 0 (LSB of A) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +template +HWY_API void LoadTransposedBlocks4(Full512 d, + const T* HWY_RESTRICT unaligned, + Vec512& A, Vec512& B, Vec512& C, + Vec512& D) { + constexpr size_t N = 64 / sizeof(T); + const Vec512 v3210 = LoadU(d, unaligned + 0 * N); + const Vec512 v7654 = LoadU(d, unaligned + 1 * N); + const Vec512 vba98 = LoadU(d, unaligned + 2 * N); + const Vec512 vfedc = LoadU(d, unaligned + 3 * N); + + const Vec512 v5410 = detail::Shuffle128<_MM_PERM_BABA>(v3210, v7654); + const Vec512 vdc98 = detail::Shuffle128<_MM_PERM_BABA>(vba98, vfedc); + const Vec512 v7632 = detail::Shuffle128<_MM_PERM_DCDC>(v3210, v7654); + const Vec512 vfeba = detail::Shuffle128<_MM_PERM_DCDC>(vba98, vfedc); + A = detail::Shuffle128<_MM_PERM_CACA>(v5410, vdc98); + B = detail::Shuffle128<_MM_PERM_DBDB>(v5410, vdc98); + C = detail::Shuffle128<_MM_PERM_CACA>(v7632, vfeba); + D = detail::Shuffle128<_MM_PERM_DBDB>(v7632, vfeba); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2 + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. 
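+// Usage sketch of the public API built on these overloads (LoadInterleaved3 /
+// StoreInterleaved3 live in generic_ops-inl.h; `rgb` is a hypothetical caller
+// buffer of at least 3 * 64 bytes):
+//   const Full512<uint8_t> d;
+//   Vec512<uint8_t> r, g, b;
+//   LoadInterleaved3(d, rgb, r, g, b);
+//   StoreInterleaved3(r, g, b, d, rgb);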
+ +namespace detail { + +// Input (128-bit blocks): +// 6 4 2 0 (LSB of i) +// 7 5 3 1 +// Output: +// 3 2 1 0 +// 7 6 5 4 +template +HWY_API void StoreTransposedBlocks2(const Vec512 i, const Vec512 j, + const Full512 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const auto j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const auto j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const auto j1_i1_j0_i0 = + detail::Shuffle128<_MM_PERM_DBCA>(j1_j0_i1_i0, j1_j0_i1_i0); + const auto j3_i3_j2_i2 = + detail::Shuffle128<_MM_PERM_DBCA>(j3_j2_i3_i2, j3_j2_i3_i2); + StoreU(j1_i1_j0_i0, d, unaligned + 0 * N); + StoreU(j3_i3_j2_i2, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 9 6 3 0 (LSB of i) +// a 7 4 1 +// b 8 5 2 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +template +HWY_API void StoreTransposedBlocks3(const Vec512 i, const Vec512 j, + const Vec512 k, Full512 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const Vec512 j2_j0_i2_i0 = detail::Shuffle128<_MM_PERM_CACA>(i, j); + const Vec512 i3_i1_k2_k0 = detail::Shuffle128<_MM_PERM_DBCA>(k, i); + const Vec512 j3_j1_k3_k1 = detail::Shuffle128<_MM_PERM_DBDB>(k, j); + + const Vec512 out0 = // i1 k0 j0 i0 + detail::Shuffle128<_MM_PERM_CACA>(j2_j0_i2_i0, i3_i1_k2_k0); + const Vec512 out1 = // j2 i2 k1 j1 + detail::Shuffle128<_MM_PERM_DBAC>(j3_j1_k3_k1, j2_j0_i2_i0); + const Vec512 out2 = // k3 j3 i3 k2 + detail::Shuffle128<_MM_PERM_BDDB>(i3_i1_k2_k0, j3_j1_k3_k1); + + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// c 8 4 0 (LSB of i) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +// f e d c +template +HWY_API void StoreTransposedBlocks4(const Vec512 i, const Vec512 j, + const Vec512 k, const Vec512 l, + Full512 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const Vec512 j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const Vec512 l1_l0_k1_k0 = detail::Shuffle128<_MM_PERM_BABA>(k, l); + const Vec512 j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const Vec512 l3_l2_k3_k2 = detail::Shuffle128<_MM_PERM_DCDC>(k, l); + const Vec512 out0 = + detail::Shuffle128<_MM_PERM_CACA>(j1_j0_i1_i0, l1_l0_k1_k0); + const Vec512 out1 = + detail::Shuffle128<_MM_PERM_DBDB>(j1_j0_i1_i0, l1_l0_k1_k0); + const Vec512 out2 = + detail::Shuffle128<_MM_PERM_CACA>(j3_j2_i3_i2, l3_l2_k3_k2); + const Vec512 out3 = + detail::Shuffle128<_MM_PERM_DBDB>(j3_j2_i3_i2, l3_l2_k3_k2); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ MulEven/Odd (Shuffle2301, InterleaveLower) + +HWY_INLINE Vec512 MulEven(const Vec512 a, + const Vec512 b) { + const Full512 du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. 
See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveLower(mulL, mulH); +} + +HWY_INLINE Vec512 MulOdd(const Vec512 a, + const Vec512 b) { + const Full512 du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need bits [95:64] (= upper half of input) + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Same as above, but we're using the odd results (upper 64 bits per block). + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveUpper(du64, mulL, mulH); +} + +// ------------------------------ ReorderWidenMulAccumulate +HWY_API Vec512 ReorderWidenMulAccumulate(Full512 /*d32*/, + Vec512 a, + Vec512 b, + const Vec512 sum0, + Vec512& /*sum1*/) { + return sum0 + Vec512{_mm512_madd_epi16(a.raw, b.raw)}; +} + +HWY_API Vec512 RearrangeToOddPlusEven(const Vec512 sum0, + Vec512 /*sum1*/) { + return sum0; // invariant already holds +} + +// ------------------------------ Reductions + +// Returns the sum in each lane. +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_add_epi32(v.raw)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_add_epi64(v.raw)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, static_cast(_mm512_reduce_add_epi32(v.raw))); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, static_cast(_mm512_reduce_add_epi64(v.raw))); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_add_ps(v.raw)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_add_pd(v.raw)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(d32, even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(d32, even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +// Returns the minimum in each lane. 
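+// As with SumOfLanes above, no 16-bit reduce intrinsic is used: the u16/i16
+// overloads below widen even/odd lanes to 32 bit, reduce there, and broadcast
+// the result back to both 16-bit halves of each 32-bit lane.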
+HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_epi32(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_epi64(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_epu32(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_epu64(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_ps(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_pd(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(d32, Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(d32, Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +// Returns the maximum in each lane. +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_epi32(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_epi64(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_epu32(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_epu64(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_ps(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_pd(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(d32, Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(d32, Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) -- cgit v1.2.3